from huggingface_hub import login
import os
import gc

os.environ['HF_HOME'] = './cache_huggingface'  # or just "." for directly in current folder
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Log into the Hugging Face Hub
hf_token = os.environ.get("HF_TOKEN")  # or userdata.get('gemma3') if you are running inside a Google Colab
print("Logging into Hugging Face Hub...")
login(hf_token)
print("Logged in.")

from datasets import load_dataset
from PIL import Image

# System message for the assistant
system_message = "You are an expert product description writer for Amazon."
print("System message set.")

# User prompt that combines the user query and the schema
user_prompt = """Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.
Only return description. The description should be SEO optimized and for a better mobile search experience.

<PRODUCT>
{product}
</PRODUCT>

<CATEGORY>
{category}
</CATEGORY>
"""

# Convert dataset samples to OAI-style messages
def format_data(sample):
    return {
        "messages": [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],  # note: the Unsloth example has no system message
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": user_prompt.format(
                            product=sample["Product Name"],
                            category=sample["Category"],
                        ),
                    },
                    {
                        "type": "image",
                        "image": sample["image"],
                    },
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": sample["description"]}],  # the assistant turn holds the expected response
            },
        ],
    }

def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    image_inputs = []
    # Iterate through each message in the conversation
    for msg in messages:
        # Get content (ensure it's a list)
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        # Check each content element for images
        for element in content:
            if isinstance(element, dict) and (
                "image" in element or element.get("type") == "image"
            ):
                # Get the image and convert it to RGB
                if "image" in element:
                    image = element["image"]
                else:
                    image = element
                image_inputs.append(image.convert("RGB"))  # convert to RGB
    return image_inputs
print("Loading dataset...")
|
|
# Load dataset from the hub
|
|
dataset = load_dataset("philschmid/amazon-product-descriptions-vlm", split="train",cache_dir="./dataset_cache")
|
|
|
|
# Convert dataset to OAI messages
|
|
# need to use list comprehension to keep Pil.Image type, .mape convert image to bytes
|
|
dataset = [format_data(sample) for sample in dataset]
|
|
|
|
print(dataset[345]["messages"])
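
# Optional sanity check (a sketch, not part of the original script): verify that
# process_vision_info() extracts exactly one PIL image from the sample printed above.
_sample_images = process_vision_info(dataset[345]["messages"])
print(f"Images extracted from sample 345: {len(_sample_images)}, size: {_sample_images[0].size}")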

import torch
print("CUDA device capability:", torch.cuda.get_device_capability())

print("Freeing up memory...")
torch.cuda.empty_cache()
gc.collect()

# Get free and total GPU memory in bytes
free_memory = torch.cuda.mem_get_info()[0]
total_memory = torch.cuda.mem_get_info()[1]

# Convert to GB for readability
free_gb = free_memory / (1024**3)
total_gb = total_memory / (1024**3)

print(f"Free: {free_gb:.2f} GB / Total: {total_gb:.2f} GB")

from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

# Hugging Face model id
model_id = "google/gemma-3-4b-pt"  # or `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`

# Check whether the GPU benefits from bfloat16
#if torch.cuda.get_device_capability()[0] < 8:
#    raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")

# Define model init arguments
model_kwargs = dict(
    attn_implementation="eager",  # use "flash_attention_2" when running on an Ampere or newer GPU
    torch_dtype=torch.bfloat16,   # which torch dtype to use, defaults to auto
    device_map="auto",            # let torch decide how to load the model
)

# BitsAndBytesConfig int-4 config
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)
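
# Rough memory sketch (back-of-the-envelope, not an exact figure): 4-bit NF4 weights
# take about half a byte per parameter, so a ~4B-parameter Gemma 3 checkpoint needs
# on the order of 2 GB for the weights alone, before activations, LoRA adapters,
# optimizer state and CUDA overhead.
approx_params = 4e9  # assumption: the 4b checkpoint selected above
approx_weight_gb = approx_params * 0.5 / (1024**3)
print(f"Approximate 4-bit weight footprint: {approx_weight_gb:.1f} GB (weights only)")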

# Load model and tokenizer
#model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
#processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")

# Set the cache directory to the current folder
cache_dir = "./model_cache"  # or just "." for directly in current folder

print("Loading model... This may take a while.")
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    cache_dir=cache_dir,
    **model_kwargs
)
print("Model loaded.")

proc_cache_dir = "./proc_cache"
print("Loading processor...")
processor = AutoProcessor.from_pretrained(
    "google/gemma-3-4b-it",  # model_id; the original file loads the -it processor rather than -pt (it changes little either way)
    cache_dir=proc_cache_dir
)
print("Processor loaded.")

# Optionally download and save model and processor to the current folder
save_path = "./local_model"
#model.save_pretrained(save_path)
#processor.save_pretrained(save_path)

from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    #modules_to_save=[  # this is what was causing me trouble
    #    "lm_head",
    #    "embed_tokens",
    #],
)
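
# Optional look (a sketch): list the Linear-style leaf module names in the quantized
# model; target_modules="all-linear" targets these projection layers (Linear4bit here)
# while PEFT excludes the output head.
linear_leaf_names = {name.split(".")[-1] for name, module in model.named_modules()
                     if "Linear" in type(module).__name__}
print("Linear-like module leaf names:", sorted(linear_leaf_names))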

from trl import SFTConfig

args = SFTConfig(
    output_dir="./gemma-finetuned",          # directory to save to and repository id
    num_train_epochs=1,                      # number of training epochs
    per_device_train_batch_size=1,           # batch size per device during training
    gradient_accumulation_steps=8,           # number of steps before performing a backward/update pass
    gradient_checkpointing=True,             # use gradient checkpointing to save memory
    optim="adamw_torch_fused",               # use the fused adamw optimizer
    logging_steps=5,                         # log every 5 steps
    save_strategy="epoch",                   # save a checkpoint every epoch
    learning_rate=2e-4,                      # learning rate, based on the QLoRA paper
    bf16=True,                               # use bfloat16 precision
    max_grad_norm=0.3,                       # max gradient norm, based on the QLoRA paper
    warmup_ratio=0.03,                       # warmup ratio, based on the QLoRA paper
    lr_scheduler_type="constant",            # use a constant learning rate scheduler
    push_to_hub=True,                        # push the model to the Hub
    report_to="tensorboard",                 # report metrics to tensorboard
    gradient_checkpointing_kwargs={
        "use_reentrant": False
    },                                       # use non-reentrant checkpointing
    dataset_text_field="",                   # need a dummy field for the collator
    dataset_kwargs={"skip_prepare_dataset": True},  # important for the collator
)
args.remove_unused_columns = False  # important for the collator

# Create a data collator to encode text and image pairs
def collate_fn(examples):
    texts = []
    images = []
    for example in examples:
        image_inputs = process_vision_info(example["messages"])
        text = processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        )
        texts.append(text.strip())
        images.append(image_inputs)

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids; mask the padding and image tokens in the loss computation
    labels = batch["input_ids"].clone()

    # Look up the image token id
    image_token_id = [
        processor.tokenizer.convert_tokens_to_ids(
            processor.tokenizer.special_tokens_map["boi_token"]
        )
    ]
    # Mask tokens so they are not used in the loss computation
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100  # Gemma 3 image soft token id

    batch["labels"] = labels
    return batch
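
# Optional sanity check (a sketch): run the collator on two formatted samples and
# inspect the resulting tensor shapes before handing it to the trainer.
_debug_batch = collate_fn(dataset[:2])
print({k: tuple(v.shape) for k, v in _debug_batch.items() if hasattr(v, "shape")})
del _debug_batch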

from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)

# Start training; the model is automatically saved to the Hub and the output directory
trainer.train()

# Save the final model again to the Hugging Face Hub
trainer.save_model()

# Free the memory again
del model
del trainer
torch.cuda.empty_cache()

from peft import PeftModel

# Load the base model
model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True)

# Merge the LoRA adapter into the base model and save
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")

processor = AutoProcessor.from_pretrained(args.output_dir)
processor.save_pretrained("merged_model")
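
# If you later want to serve the merged weights instead of the adapter checkpoint
# loaded below, a sketch (assuming the "merged_model" directory written above):
#   model = AutoModelForImageTextToText.from_pretrained(
#       "merged_model", device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
#   )
#   processor = AutoProcessor.from_pretrained("merged_model")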

import torch

# Load the model with the PEFT adapter
model = AutoModelForImageTextToText.from_pretrained(
    args.output_dir,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)
processor = AutoProcessor.from_pretrained(args.output_dir)

import requests
from PIL import Image

# Test sample with Product Name, Category and Image
sample = {
    "product_name": "Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur",
    "category": "Toys & Games | Toy Figures & Playsets | Action Figures",
    "image": Image.open(requests.get("https://m.media-amazon.com/images/I/81+7Up7IWyL._AC_SY300_SX300_.jpg", stream=True).raw).convert("RGB")
}

# NB: inference uses the image plus the two text fields (and the same instruction as the fine-tuning)
def generate_description(sample, model, processor):
    # Convert the sample into messages and then apply the chat template
    messages = [
        {"role": "system", "content": [{"type": "text", "text": system_message}]},
        {"role": "user", "content": [
            {"type": "image", "image": sample["image"]},
            {"type": "text", "text": user_prompt.format(product=sample["product_name"], category=sample["category"])},
        ]},
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Process the image and text
    image_inputs = process_vision_info(messages)  # converts the image to RGB, even though the sample above already calls .convert("RGB")
    # Tokenize the text and process the images
    inputs = processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Move the inputs to the device
    inputs = inputs.to(model.device)

    # Generate the output
    stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")]
    generated_ids = model.generate(**inputs, max_new_tokens=256, top_p=1.0, do_sample=True, temperature=0.8, eos_token_id=stop_token_ids, disable_compile=True)
    # Trim the generation and decode the output to text
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]

# Generate the description
description = generate_description(sample, model, processor)
print(description)