import gc

import torch
from transformers import Mistral3ForConditionalGeneration, MistralCommonBackend

cache_dir = "./model_cache"

print("Freeing up memory...")
torch.cuda.empty_cache()
gc.collect()
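# Only useful when a previous model was loaded in the same session; harmless otherwise.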

model_id = "mistralai/Ministral-3-3B-Instruct-2512"

tokenizer = MistralCommonBackend.from_pretrained(model_id, cache_dir=cache_dir)
# Load in bfloat16 so the pixel_values cast further below matches the model dtype.
model = Mistral3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16, cache_dir=cache_dir
)
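# device_map="auto" lets accelerate place the weights on the available GPU(s) automatically.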

print(f"Model device: {next(model.parameters()).device}")
print(f"Model dtype: {model.dtype}")

image_url = "https://static.wikia.nocookie.net/essentialsdocs/images/7/70/Battle.png/revision/latest?cb=20220523172438"

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "What action do you think I should take in this situation? List all the possible actions and explain why you think they are good or bad.",
            },
            {"type": "image_url", "image_url": {"url": image_url}},
        ],
    },
]

tokenized = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True)
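# The returned dict holds the text token ids together with the preprocessed image tensor.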

# Note: the image is passed as a URL inside the messages, not as a PIL image as for Gemma 3.
tokenized["input_ids"] = tokenized["input_ids"].to(device="cuda")
tokenized["pixel_values"] = tokenized["pixel_values"].to(dtype=torch.bfloat16, device="cuda")
image_sizes = [tokenized["pixel_values"].shape[-2:]]
print("Image sizes:", image_sizes)

print(f"model device: {model.device}")
tokenized = tokenized.to(model.device)  # move the whole batch onto the model's device

output = model.generate(
    **tokenized,
    image_sizes=image_sizes,
    max_new_tokens=512,
)[0]
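# `output` is the full sequence: the prompt tokens followed by the newly generated tokens.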

decoded_output = tokenizer.decode(output[len(tokenized["input_ids"][0]):])
print(decoded_output)