wcag_AI_validation/scripts/finetuning_inference_time_s.../ministral3.py

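# Inference-time script: load Ministral 3 (instruct) via transformers, send one
# image + text chat prompt, and print the generated response.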

import torch
from transformers import Mistral3ForConditionalGeneration, MistralCommonBackend
import gc
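# Cache downloaded weights locally and release any GPU memory held from a previous run.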
cache_dir = "./model_cache"
print("Freeing up memory...")
torch.cuda.empty_cache()
gc.collect()
model_id = "mistralai/Ministral-3-3B-Instruct-2512"
# Load the tokenizer backend and the vision-language model; device_map="auto"
# distributes the weights across the available device(s).
tokenizer = MistralCommonBackend.from_pretrained(model_id, cache_dir=cache_dir)
model = Mistral3ForConditionalGeneration.from_pretrained(model_id, device_map="auto", cache_dir=cache_dir)
print(f"Model device: {next(model.parameters()).device}")
print(f"Model dtype: {model.dtype}")
image_url = "https://static.wikia.nocookie.net/essentialsdocs/images/7/70/Battle.png/revision/latest?cb=20220523172438"
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "What action do you think I should take in this situation? List all the possible actions and explain why you think they are good or bad.",
            },
            {"type": "image_url", "image_url": {"url": image_url}},
        ],
    },
]
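# The chat template tokenizes the text and preprocesses the image into pixel_values.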
tokenized = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True)
tokenized["input_ids"] = tokenized["input_ids"].to(device="cuda") #vedi che non gli passa immagine PIL come fa per gemma3
tokenized["pixel_values"] = tokenized["pixel_values"].to(dtype=torch.bfloat16, device="cuda")
image_sizes = [tokenized["pixel_values"].shape[-2:]]
print("Image sizes:", image_sizes)
print(f"model device: {model.device}")
tokenized= tokenized.to(model.device)#sposto tutto sul device del modello
output = model.generate(
    **tokenized,
    image_sizes=image_sizes,
    max_new_tokens=512,
)[0]
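# Decode only the newly generated tokens, skipping the prompt portion of the output.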
decoded_output = tokenizer.decode(output[len(tokenized["input_ids"][0]):])
print(decoded_output)