finetuning

parent f79e9d0667
commit 1706ed3a06
|
|
@@ -91,6 +91,8 @@ class MLLMManager:
|
|||
If I could not use the image content, what words would I use to convey the same function and/or information?
|
||||
|
||||
When image content contains words that are important to understanding the content, the alt text should include those words.
|
||||
Decorative images don’t add information to the content of a page. For example, the information provided by the image might already be given using adjacent text, or the image might be included to make the website more visually attractive.
|
||||
In these cases, a null (empty) alt text should be provided (alt="") so that they can be ignored by assistive technologies, such as screen readers.
|
||||
|
||||
Follow these instructions carefully:
|
||||
1. You will be provided as input with the following:
|
||||
|
|
@@ -103,7 +105,7 @@ class MLLMManager:
|
|||
of the associated image by considering the page context. Check also if the image is, or is associated with, a link or a button,
|
||||
and consider this in your judgment. If the image contains text, use that as part of the context.
|
||||
|
||||
3. Provide a final assessment based on the following:
|
||||
3. Provide a final assessment judgment based on the following:
|
||||
- 'success' if you can assess with 'sufficient certainty' the alt-text is appropriate in relation to the image purpose,
|
||||
- 'failure' if you can assess with 'sufficient certainty' that the alt-text is NOT appropriate,
|
||||
- 'warning' if you cannot determine with 'sufficient certainty'.
|
||||
|
|
@@ -111,14 +113,14 @@ class MLLMManager:
|
|||
|
||||
4. The original alt-text assessment on a scale from 1 to 5, where 5 is the best score. Use an integer number only.
|
||||
|
||||
5. Provide a brief reasoning for your judgment. If the image contains text, write it verbatim. Your response should be in English.
|
||||
5. Provide a brief reasoning for your judgment. If the image contains text, write it verbatim.
|
||||
|
||||
6. Keep your response within 150 words.
|
||||
|
||||
7. Generate the new most appropriate alt-text given the context and the steps before. Keep this within 30 words. Use the same language as the original alt-text.
|
||||
7. Generate the new most appropriate alt-text given the context and the steps before. Keep this within 30 words. Use the same natural language (e.g., English, Spanish, Italian) as the original alt-text.
|
||||
|
||||
8. Here is the JSON format the results must have:
|
||||
{"Original alt-text assessment" : "*your original alt-text assessment*", "Assessment" : "*your assessment*", "EvaluationResult": "*your response*", "New alt-text":"*new alt-text*"}"""
|
||||
{"Original alt-text assessment" : "*your original alt-text assessment*", "Assessment" : "*your assessment judgment*", "EvaluationResult": "*your response*", "New alt-text":"*new alt-text*"}"""
|
||||
|
||||
return system_prompt
|
||||
|
||||
|
|
|
|||
|
|
@@ -7,6 +7,8 @@ import os
|
|||
import requests
|
||||
import base64
|
||||
import sqlite3
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
exception_msg = "Exception: %s"
|
||||
|
||||
|
|
@@ -129,7 +131,21 @@ def create_folder(root_path, directory_separator, next_path):
|
|||
|
||||
def encode_image_from_url(image_url):
|
||||
response = requests.get(image_url)
|
||||
return base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
# Open image and convert to RGB
|
||||
image = Image.open(io.BytesIO(response.content))
|
||||
|
||||
# Convert to RGB (handles RGBA, grayscale, etc.)
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert('RGB')
|
||||
|
||||
# Save to bytes buffer
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format='PNG') # or 'JPEG'
|
||||
buffer.seek(0)
|
||||
|
||||
# Encode to base64
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
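# --- Editor's sketch (not part of the commit; an assumption for illustration) --
# A slightly more defensive variant of the helper above: it fails early on HTTP
# errors and falls back to the raw bytes if Pillow cannot parse the payload.
# The function name and timeout value are illustrative only; requests, Image,
# io and base64 are already imported at the top of this module.
def encode_image_from_url_safe(image_url, timeout=30):
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()  # avoid base64-encoding an HTML error page
    try:
        image = Image.open(io.BytesIO(response.content))
        if image.mode != "RGB":
            image = image.convert("RGB")
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        payload = buffer.getvalue()
    except Exception:
        payload = response.content  # fall back to the original bytes
    return base64.b64encode(payload).decode("utf-8")
# ------------------------------------------------------------------------------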
|
||||
|
||||
|
||||
def db_persistence_startup(
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,519 @@
|
|||
# to launch: python build_dataset_from_folder_full_features.py --ref_path "C:\cartella_condivisa\MachineLearning\HIISlab\accessibility\notebook_miei\LLM_accessibility_validator\out" --push_to_hub --repo_id "nicolaleo/LLM-alt-text-assessment-full-features" --token "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" --dataset_split "train" --dataset_name "alt_text_merged_dataset_full_features"
|
||||
|
||||
# create the dataset based on features aligned to the user_test (aligned to the inference task)
|
||||
"""
|
||||
[
|
||||
"image",
|
||||
"image_url",
|
||||
"original_alt_text",
|
||||
"llm_assessment",
|
||||
"llm_judgment",
|
||||
"llm_evaluation_result",
|
||||
"llm_alt_text",
|
||||
"page_url",
|
||||
"html_context",
|
||||
"page_title",
|
||||
"page_description",
|
||||
"page_keywords"
|
||||
]
|
||||
"""
|
||||
|
||||
from datasets import Dataset, DatasetDict
|
||||
import datasets
|
||||
import json
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
import hashlib
|
||||
import urllib.parse
|
||||
import argparse
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SIMPLE USAGE FUNCTIONS
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def url_to_filename(image_url): # same save step as in the image_extractor dependency
|
||||
"""
|
||||
Convert an image URL to a sanitized filename, following the same logic as the image_extractor step.
|
||||
|
||||
Args:
|
||||
image_url: The image URL
|
||||
|
||||
Returns:
|
||||
Sanitized filename with extension
|
||||
"""
|
||||
|
||||
# Parse the URL to get the path without query parameters
|
||||
parsed_url = urllib.parse.urlparse(image_url)
|
||||
url_path = parsed_url.path
|
||||
|
||||
# Get the filename from the path
|
||||
filename = url_path.split("/")[-1]
|
||||
print(f"Original filename: '{filename}'")
|
||||
|
||||
# Split filename and extension
|
||||
if "." in filename:
|
||||
image_name, ext = filename.rsplit(".", 1)
|
||||
ext = ext.lower()
|
||||
else:
|
||||
image_name = filename
|
||||
ext = "jpg"
|
||||
|
||||
# Validate extension
|
||||
if ext not in ["jpg", "jpeg", "png", "gif", "webp"]:
|
||||
ext = "jpg"
|
||||
|
||||
# Sanitize image name (remove special characters, limit length)
|
||||
image_name = "".join(c for c in image_name if c.isalnum() or c in ("-", "_"))
|
||||
|
||||
image_name = image_name[:50] # Limit filename length
|
||||
|
||||
# If name is empty after sanitization, create a hash-based name
|
||||
if not image_name:
|
||||
image_name = hashlib.md5(image_url.encode()).hexdigest()[:16]
|
||||
|
||||
return f"{image_name}.{ext}"
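# --- Editor's note (not part of the commit): a tiny self-check of the mapping
# above, using a made-up URL. It is expected to return "Logo20v2.png", because
# the query string is dropped, the extension is lowercased and "%" is stripped.
def _demo_url_to_filename():
    sample_url = "https://example.com/img/Logo%20v2.PNG?width=300"
    return url_to_filename(sample_url)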
|
||||
|
||||
|
||||
def push_to_hub_example(dataset_path="alt_text_merged_dataset", repo_id="",token=None, dataset_split="train"):
|
||||
"""
|
||||
Example of how to push dataset to Hugging Face Hub.
|
||||
You need to authenticate first!
|
||||
"""
|
||||
from huggingface_hub import login
|
||||
|
||||
print("\n=== Pushing Dataset to Hugging Face Hub ===")
|
||||
# Method 1: Login interactively (will prompt for token)
|
||||
# login()
|
||||
|
||||
# Method 2: Login with token directly
|
||||
login(token=token)
|
||||
|
||||
# Method 3: Set token as environment variable
|
||||
# export HF_TOKEN="hf_YourTokenHere"
|
||||
# Then login() will use it automatically
|
||||
|
||||
# Load your dataset
|
||||
ds = load_dataset_from_disk(dataset_path)
|
||||
|
||||
# Combine into DatasetDict
|
||||
if dataset_split == "train":
|
||||
ds = DatasetDict(
|
||||
{
|
||||
"train": ds,
|
||||
#"test": test_dataset
|
||||
}
|
||||
)
|
||||
elif dataset_split == "test":
|
||||
ds = DatasetDict(
|
||||
{
|
||||
#"train": train_dataset,
|
||||
"test": ds,
|
||||
}
|
||||
)
|
||||
elif dataset_split == "validation":
|
||||
ds = DatasetDict(
|
||||
{
|
||||
#"train": train_dataset,
|
||||
"validation": ds,
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid dataset_split: {dataset_split}")
|
||||
|
||||
# Push to hub (creates repo if it doesn't exist)
|
||||
ds.push_to_hub( # Automatically converts to Parquet when uploading to Hub
|
||||
repo_id, # Replace with your username
|
||||
private=False, # Set True for private dataset
|
||||
)
|
||||
|
||||
print("Dataset pushed successfully!")
|
||||
print(f"View at: https://huggingface.co/datasets/{repo_id}")
|
||||
|
||||
|
||||
def create_dataset_from_json(json_filepath, json_filepath_images, images_dir="images"):
|
||||
"""
|
||||
Create a Hugging Face Dataset from JSON file with local images.
|
||||
|
||||
Args:
|
||||
json_filepath: Path to JSON file with your data structure
|
||||
images_dir: Directory containing the images (default: "images")
|
||||
|
||||
Returns:
|
||||
datasets.Dataset object with images loaded
|
||||
"""
|
||||
with open(json_filepath, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
with open(json_filepath_images, "r", encoding="utf-8") as f:
|
||||
data_images = json.load(f)
|
||||
|
||||
images_path = Path(images_dir)
|
||||
|
||||
# Flatten the nested structure and load images
|
||||
flattened_data = {
|
||||
"image": [],
|
||||
"image_url": [],
|
||||
"original_alt_text": [],
|
||||
"llm_assessment": [],
|
||||
"llm_judgment": [],
|
||||
"llm_evaluation_result": [],
|
||||
"llm_alt_text": [],
|
||||
"page_url": [],
|
||||
"html_context": [],
|
||||
"page_title": [],
|
||||
"page_description": [],
|
||||
"page_keywords": []
|
||||
}
|
||||
|
||||
count_entry = 0
|
||||
for entry in data:
|
||||
if (
|
||||
entry["mllm_response"]["original_alt_text_assessment"] is None
|
||||
): # important: skip entries with no MLLM response; the data is not usable
|
||||
print(
|
||||
f"Skipping entry with image URL: {entry['image_url']} due to missing MLLM response"
|
||||
)
|
||||
count_entry += 1
|
||||
continue # Skip entries with no MLLM response
|
||||
image_url = entry["image_url"]
|
||||
image_filename = url_to_filename(image_url)
|
||||
image_path = images_path / image_filename
|
||||
|
||||
# Load image if it exists
|
||||
if image_path.exists():
|
||||
img = Image.open(image_path)
|
||||
flattened_data["image"].append(img)
|
||||
else:
|
||||
print(f"Warning: Image not found: {image_path}")
|
||||
flattened_data["image"].append(None)
|
||||
|
||||
flattened_data["image_url"].append(image_url)
|
||||
flattened_data["original_alt_text"].append(entry["alt_text"])
|
||||
flattened_data["llm_assessment"].append(
|
||||
str(entry["mllm_response"]["original_alt_text_assessment"])
|
||||
)
|
||||
flattened_data["llm_judgment"].append(entry["mllm_response"]["assessment"])
|
||||
flattened_data["llm_evaluation_result"].append(
|
||||
entry["mllm_response"]["evaluation_result"]
|
||||
)
|
||||
flattened_data["llm_alt_text"].append(entry["mllm_response"]["new_alt_text"])
|
||||
flattened_data["page_url"].append(data_images[count_entry]["page_url"])
|
||||
flattened_data["html_context"].append(data_images[count_entry]["html_context"])
|
||||
flattened_data["page_title"].append(data_images[count_entry]["page_title"])
|
||||
flattened_data["page_description"].append(data_images[count_entry]["page_description"])
|
||||
flattened_data["page_keywords"].append(data_images[count_entry]["page_keywords"])
|
||||
|
||||
count_entry += 1
|
||||
|
||||
print(f"Total valid entries loaded: {len(flattened_data['image_url'])}")
|
||||
return datasets.Dataset.from_dict(flattened_data)
|
||||
|
||||
|
||||
def create_dataset_from_folders(
|
||||
ref_path,
|
||||
json_filename="mllm_alttext_assessments.json",
|
||||
json_filename_images="extracted_images.json",
|
||||
images_dirname="images",
|
||||
):
|
||||
"""
|
||||
Create a merged dataset from multiple folders under ref_path.
|
||||
Each folder should contain a JSON file and an images subdirectory.
|
||||
|
||||
Args:
|
||||
ref_path: Root path containing multiple folders
|
||||
json_filename: Name of JSON file in each folder (default: "mllm_alttext_assessments.json")
|
||||
images_dirname: Name of images subdirectory (default: "images")
|
||||
|
||||
Returns:
|
||||
datasets.Dataset object with all entries merged
|
||||
"""
|
||||
ref_path = Path(ref_path)
|
||||
all_datasets = []
|
||||
|
||||
# Find all subdirectories containing the JSON file
|
||||
folders_processed = 0
|
||||
|
||||
for folder in ref_path.iterdir():
|
||||
if not folder.is_dir():
|
||||
continue
|
||||
|
||||
json_path = folder / json_filename
|
||||
json_path_images = folder / json_filename_images
|
||||
images_path = folder / images_dirname
|
||||
|
||||
# Check if both JSON and images directory exist
|
||||
if not json_path.exists():
|
||||
print(f"Skipping {folder.name}: no {json_filename} found")
|
||||
continue
|
||||
|
||||
if not json_path_images.exists():
|
||||
print(f"Skipping {folder.name}: no {json_filename_images} found")
|
||||
continue
|
||||
|
||||
if not images_path.exists():
|
||||
print(f"Warning: {folder.name}: images directory not found")
|
||||
# continue
|
||||
# Continue anyway, images might be optional (from urls only)
|
||||
|
||||
print(f"Processing folder: {folder.name}")
|
||||
|
||||
try:
|
||||
# Create dataset for this folder
|
||||
ds = create_dataset_from_json(
|
||||
str(json_path), str(json_path_images), str(images_path)
|
||||
)
|
||||
all_datasets.append(ds)
|
||||
|
||||
folders_processed += 1
|
||||
print(f" -> Loaded {len(ds)} entries")
|
||||
except Exception as e:
|
||||
print(f"Error processing {folder.name}: {e}")
|
||||
continue
|
||||
|
||||
if not all_datasets:
|
||||
raise ValueError(f"No valid folders found in {ref_path}")
|
||||
|
||||
# Merge all datasets
|
||||
print(f"\n=== Merging {folders_processed} folders ===")
|
||||
merged_dataset = datasets.concatenate_datasets(all_datasets)
|
||||
print(f"Total entries: {len(merged_dataset)}")
|
||||
|
||||
return merged_dataset
|
||||
|
||||
|
||||
def verify_images(json_filepath, images_dir="images"):
|
||||
"""
|
||||
Verify that all images referenced in JSON exist in the images directory.
|
||||
|
||||
Args:
|
||||
json_filepath: Path to JSON file
|
||||
images_dir: Directory containing images
|
||||
|
||||
Returns:
|
||||
Dict with 'found', 'missing', and 'details' keys
|
||||
"""
|
||||
with open(json_filepath, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
images_path = Path(images_dir)
|
||||
|
||||
found = []
|
||||
missing = []
|
||||
|
||||
for entry in data:
|
||||
image_url = entry["image_url"]
|
||||
image_filename = url_to_filename(image_url)
|
||||
image_path = images_path / image_filename
|
||||
print(
|
||||
"image_url:",
|
||||
image_url,
|
||||
"image_filename:",
|
||||
image_filename,
|
||||
"image_path:",
|
||||
image_path,
|
||||
)
|
||||
|
||||
if image_path.exists():
|
||||
found.append(
|
||||
{"url": image_url, "filename": image_filename, "path": str(image_path)}
|
||||
)
|
||||
else:
|
||||
missing.append(
|
||||
{
|
||||
"url": image_url,
|
||||
"filename": image_filename,
|
||||
"expected_path": str(image_path),
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"found": len(found),
|
||||
"missing": len(missing),
|
||||
"total": len(data),
|
||||
"details": {"found_images": found, "missing_images": missing},
|
||||
}
|
||||
|
||||
|
||||
def verify_images_in_folders(
|
||||
ref_path, json_filename="mllm_alttext_assessments.json", images_dirname="images"
|
||||
):
|
||||
"""
|
||||
Verify images across all folders under ref_path.
|
||||
|
||||
Args:
|
||||
ref_path: Root path containing multiple folders
|
||||
json_filename: Name of JSON file in each folder
|
||||
images_dirname: Name of images subdirectory
|
||||
|
||||
Returns:
|
||||
Dict with aggregated verification results
|
||||
"""
|
||||
ref_path = Path(ref_path)
|
||||
total_found = 0
|
||||
total_missing = 0
|
||||
total_entries = 0
|
||||
folder_results = {}
|
||||
|
||||
for folder in ref_path.iterdir():
|
||||
if not folder.is_dir():
|
||||
continue
|
||||
|
||||
json_path = folder / json_filename
|
||||
images_path = folder / images_dirname
|
||||
|
||||
if not json_path.exists():
|
||||
continue
|
||||
|
||||
print(f"Verifying folder: {folder.name}")
|
||||
|
||||
try:
|
||||
verification = verify_images(str(json_path), str(images_path))
|
||||
folder_results[folder.name] = verification
|
||||
|
||||
total_found += verification["found"]
|
||||
total_missing += verification["missing"]
|
||||
total_entries += verification["total"]
|
||||
|
||||
print(f" Found: {verification['found']}/{verification['total']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f" Error: {e}")
|
||||
continue
|
||||
|
||||
return {
|
||||
"found": total_found,
|
||||
"missing": total_missing,
|
||||
"total": total_entries,
|
||||
"folders": folder_results,
|
||||
}
|
||||
|
||||
|
||||
def save_dataset(dataset, output_path):
|
||||
"""Save dataset in Arrow format (includes images)."""
|
||||
dataset.save_to_disk(output_path)
|
||||
# print(f"Dataset saved to {output_path}")
|
||||
|
||||
# Or save as JSON
|
||||
# dataset.to_json(f"{output_path}/data.json")
|
||||
|
||||
# Or save as CSV
|
||||
# dataset.to_csv(f"{output_path}/data.csv")
|
||||
|
||||
# Or save as Parquet
|
||||
# dataset.to_parquet(f"{output_path}/data.parquet")
|
||||
|
||||
|
||||
def load_dataset_from_disk(dataset_path):
|
||||
"""Load a previously saved dataset."""
|
||||
return datasets.load_from_disk(dataset_path)
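# --- Editor's sketch (assumption, not part of the commit): once the dataset
# has been pushed, it can also be pulled back from the Hub instead of from
# disk; the repo id mirrors the --repo_id argument used when pushing.
def load_dataset_from_hub(repo_id, split="train"):
    return datasets.load_dataset(repo_id, split=split)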
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# EXAMPLE USAGE
|
||||
# ============================================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--ref_path",
|
||||
type=str,
|
||||
help=("Root path containing multiple folders"),
|
||||
default="C:\\cartella_condivisa\\MachineLearning\\HIISlab\\accessibility\\notebook_miei\\LLM_accessibility_validator\\out",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=("If True push the merged dataset to Hugging Face Hub"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token",
|
||||
type=str,
|
||||
help=("Hugging Face authentication token"),
|
||||
default="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",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo_id",
|
||||
type=str,
|
||||
help=("Hugging Face repository ID"),
|
||||
default="nicolaleo/LLM-alt-text-assessment",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_split",
|
||||
type=str,
|
||||
help=("dataset split type: train, test, validation"),
|
||||
default="train",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
help=("dataset name to save/load"),
|
||||
default="alt_text_merged_dataset_full_features",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Example 1: Verify images across all folders
|
||||
print("=== Verifying Images in All Folders ===")
|
||||
verification = verify_images_in_folders(args.ref_path)
|
||||
print("\n######## Verifier output ################################")
|
||||
print(f"Total Found: {verification['found']}/{verification['total']}")
|
||||
print(f"Total Missing: {verification['missing']}/{verification['total']}")
|
||||
print("########################################")
|
||||
|
||||
# Show per-folder breakdown
|
||||
print("\n=== Per-Folder Breakdown ===")
|
||||
for folder_name, results in verification["folders"].items():
|
||||
print(f"{folder_name}: {results['found']}/{results['total']} images found")
|
||||
|
||||
# Example 2: Create merged dataset from all folders
|
||||
print("\n=== Creating Merged Dataset ===")
|
||||
ds = create_dataset_from_folders(args.ref_path)
|
||||
print("\n######## Merged Dataset output ################################")
|
||||
print(f"Final dataset size: {len(ds)} entries")
|
||||
print("########################################")
|
||||
|
||||
# Example 3: Analyze the merged dataset
|
||||
print("\n=== Dataset Analysis ===")
|
||||
print(ds)
|
||||
|
||||
# Example 4: Access images and data
|
||||
print("\n=== First Example ===")
|
||||
first_example = ds[0]
|
||||
print(f"Image URL: {first_example['image_url']}")
|
||||
print(f"Original Alt text: {first_example['original_alt_text']}")
|
||||
print(f"LLM judgment: {first_example['llm_judgment']}")
|
||||
print(f"LLM alt text: {first_example['llm_alt_text']}")
|
||||
print(f"Image loaded: {first_example['image'] is not None}")
|
||||
|
||||
if first_example["image"] is not None:
|
||||
img = first_example["image"]
|
||||
print(f"Image size: {img.size}")
|
||||
# img.show() # Uncomment to display image
|
||||
|
||||
# Example 5: Filter and work with merged data
|
||||
print("\n=== Filtering Merged Dataset ===")
|
||||
successful = ds.filter(lambda x: x["llm_judgment"] == "success")
|
||||
print(f"Successful assessments: {len(successful)}")
|
||||
|
||||
high_rated = ds.filter(lambda x: int(x["llm_assessment"]) >= 4)
|
||||
print(f"High-rated (>=4): {len(high_rated)}")
|
||||
|
||||
# Example 6: Save merged dataset
|
||||
print("\n=== Saving Merged Dataset ===")
|
||||
save_dataset(ds, args.dataset_name)
|
||||
|
||||
# Example 7: Load dataset
|
||||
print("\n=== Loading Dataset ===")
|
||||
loaded_ds = load_dataset_from_disk(args.dataset_name)
|
||||
print(f"Loaded {len(loaded_ds)} entries")
|
||||
|
||||
if args.push_to_hub:
|
||||
# Push to Hugging Face Hub (optional)
|
||||
push_to_hub_example(dataset_path=args.dataset_name, repo_id=args.repo_id, token=args.token, dataset_split=args.dataset_split)
|
||||
File diff suppressed because one or more lines are too long
|
|
@@ -7,6 +7,9 @@ import os
|
|||
import requests
|
||||
import base64
|
||||
import re
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@@ -54,7 +57,7 @@ def call_API_urlibrequest(
|
|||
|
||||
# Send the request and capture the response
|
||||
|
||||
with urllib.request.urlopen(request) as response:
|
||||
with urllib.request.urlopen(request, timeout=300) as response:
|
||||
# Read and decode the response
|
||||
|
||||
response_json = json.loads(response.read().decode("utf-8"))
|
||||
|
|
@@ -158,9 +161,24 @@ def parse_mllm_alt_text_response(mllm_response):
|
|||
|
||||
|
||||
|
||||
|
||||
def encode_image_from_url(image_url):
|
||||
response = requests.get(image_url)
|
||||
return base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
# Open image and convert to RGB
|
||||
image = Image.open(io.BytesIO(response.content))
|
||||
|
||||
# Convert to RGB (handles RGBA, grayscale, etc.)
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert('RGB')
|
||||
|
||||
# Save to bytes buffer
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format='PNG') # or 'JPEG'
|
||||
buffer.seek(0)
|
||||
|
||||
# Encode to base64
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,341 @@
|
|||
from huggingface_hub import login
|
||||
import os
|
||||
import gc
|
||||
|
||||
os.environ['HF_HOME'] = './cache_huggingface' # or just "." for directly in current folder
|
||||
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
||||
|
||||
# Login into Hugging Face Hub
|
||||
hf_token = "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"#userdata.get('gemma3') # If you are running inside a Google Colab
|
||||
print("Logging into Hugging Face Hub...")
|
||||
login(hf_token)
|
||||
print("Logged in.")
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
|
||||
# System message for the assistant
|
||||
system_message = "You are an expert product description writer for Amazon."
|
||||
print("System message set.")
|
||||
|
||||
# User prompt that combines the user query and the schema
|
||||
user_prompt = """Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.
|
||||
Only return description. The description should be SEO optimized and for a better mobile search experience.
|
||||
|
||||
<PRODUCT>
|
||||
{product}
|
||||
</PRODUCT>
|
||||
|
||||
<CATEGORY>
|
||||
{category}
|
||||
</CATEGORY>
|
||||
"""
|
||||
|
||||
# Convert dataset to OAI messages
|
||||
def format_data(sample):
|
||||
return {
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": system_message}],#esempio unsloth non ha system message
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": user_prompt.format(
|
||||
product=sample["Product Name"],
|
||||
category=sample["Category"],
|
||||
),
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"image": sample["image"],
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": sample["description"]}],#vedi ruolo assistente per la risposta aspettata
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
|
||||
image_inputs = []
|
||||
# Iterate through each conversation
|
||||
for msg in messages:
|
||||
# Get content (ensure it's a list)
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
content = [content]
|
||||
|
||||
# Check each content element for images
|
||||
for element in content:
|
||||
if isinstance(element, dict) and (
|
||||
"image" in element or element.get("type") == "image"
|
||||
):
|
||||
# Get the image and convert to RGB
|
||||
if "image" in element:
|
||||
image = element["image"]
|
||||
else:
|
||||
image = element
|
||||
image_inputs.append(image.convert("RGB"))  # convert to RGB!
|
||||
return image_inputs
|
||||
|
||||
print("Loading dataset...")
|
||||
# Load dataset from the hub
|
||||
dataset = load_dataset("philschmid/amazon-product-descriptions-vlm", split="train",cache_dir="./dataset_cache")
|
||||
|
||||
# Convert dataset to OAI messages
|
||||
# need to use a list comprehension to keep the PIL.Image type; .map converts the image to bytes
|
||||
dataset = [format_data(sample) for sample in dataset]
|
||||
|
||||
print(dataset[345]["messages"])
|
||||
|
||||
import torch
|
||||
torch.cuda.get_device_capability()
|
||||
|
||||
print("Freeing up memory...")
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
# Get free memory in bytes
|
||||
free_memory = torch.cuda.mem_get_info()[0]
|
||||
total_memory = torch.cuda.mem_get_info()[1]
|
||||
|
||||
# Convert to GB for readability
|
||||
free_gb = free_memory / (1024**3)
|
||||
total_gb = total_memory / (1024**3)
|
||||
|
||||
print(f"Free: {free_gb:.2f} GB / Total: {total_gb:.2f} GB")
|
||||
|
||||
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
|
||||
|
||||
# Hugging Face model id
|
||||
model_id = "google/gemma-3-4b-pt" # or `google/gemma-3-12b-pt`, `google/gemma-3-27-pt`
|
||||
|
||||
# Check if GPU benefits from bfloat16
|
||||
#if torch.cuda.get_device_capability()[0] < 8:
|
||||
# raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")
|
||||
|
||||
# Define model init arguments
|
||||
model_kwargs = dict(
|
||||
attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU
|
||||
torch_dtype=torch.bfloat16,#torch.float16,#torch.bfloat16, # What torch dtype to use, defaults to auto
|
||||
device_map="auto", # Let torch decide how to load the model
|
||||
|
||||
)
|
||||
|
||||
# BitsAndBytesConfig int-4 config
|
||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
|
||||
bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
|
||||
)
|
||||
|
||||
# Load model and tokenizer
|
||||
#model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
|
||||
#processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
|
||||
|
||||
|
||||
|
||||
|
||||
# Set the cache directory to current folder
|
||||
cache_dir = "./model_cache" # or just "." for directly in current folder
|
||||
|
||||
print("Loading model... This may take a while.")
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
model_id,
|
||||
cache_dir=cache_dir,
|
||||
**model_kwargs
|
||||
)
|
||||
print("Model loaded.")
|
||||
|
||||
|
||||
proc_cache_dir = "./proc_cache"
|
||||
print("Loading processor...")
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
"google/gemma-3-4b-it",#model_id, # nel file originale prende -it e non -pt (cambia poco comunque)
|
||||
cache_dir=proc_cache_dir
|
||||
)
|
||||
print("Processor loaded.")
|
||||
|
||||
# Download and save to current folder
|
||||
print("Saving model and processor locally...")
|
||||
save_path = "./local_model"
|
||||
#model.save_pretrained(save_path)
|
||||
#processor.save_pretrained(save_path)
|
||||
print("Model and processor saved.")
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
peft_config = LoraConfig(
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.05,
|
||||
r=16,
|
||||
bias="none",
|
||||
target_modules="all-linear",
|
||||
task_type="CAUSAL_LM",
|
||||
#modules_to_save=[ # this is what was using extra memory
|
||||
# "lm_head",
|
||||
# "embed_tokens",
|
||||
#],
|
||||
)
|
||||
|
||||
from trl import SFTConfig
|
||||
|
||||
args = SFTConfig(
|
||||
output_dir="./gemma-finetuned", # directory to save and repository id
|
||||
num_train_epochs=1, # number of training epochs
|
||||
per_device_train_batch_size=1, # batch size per device during training
|
||||
gradient_accumulation_steps=8,#4, # number of steps before performing a backward/update pass
|
||||
gradient_checkpointing=True, # use gradient checkpointing to save memory
|
||||
optim="adamw_torch_fused", # use fused adamw optimizer
|
||||
logging_steps=5, # log every 5 steps
|
||||
save_strategy="epoch", # save checkpoint every epoch
|
||||
learning_rate=2e-4, # learning rate, based on QLoRA paper
|
||||
bf16=True,#False,#True, # use bfloat16 precision
|
||||
max_grad_norm=0.3, # max gradient norm based on QLoRA paper
|
||||
warmup_ratio=0.03, # warmup ratio based on QLoRA paper
|
||||
lr_scheduler_type="constant", # use constant learning rate scheduler
|
||||
push_to_hub=True, # push model to hub
|
||||
report_to="tensorboard", # report metrics to tensorboard
|
||||
gradient_checkpointing_kwargs={
|
||||
"use_reentrant": False
|
||||
}, # use non-reentrant checkpointing
|
||||
dataset_text_field="", # need a dummy field for collator
|
||||
dataset_kwargs={"skip_prepare_dataset": True}, # important for collator
|
||||
)
|
||||
args.remove_unused_columns = False # important for collator
|
||||
|
||||
# Create a data collator to encode text and image pairs
|
||||
def collate_fn(examples):
|
||||
texts = []
|
||||
images = []
|
||||
for example in examples:
|
||||
image_inputs = process_vision_info(example["messages"])
|
||||
text = processor.apply_chat_template(
|
||||
example["messages"], add_generation_prompt=False, tokenize=False
|
||||
)
|
||||
texts.append(text.strip())
|
||||
images.append(image_inputs)
|
||||
|
||||
# Tokenize the texts and process the images
|
||||
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
||||
|
||||
# The labels are the input_ids, and we mask the padding tokens and image tokens in the loss computation
|
||||
labels = batch["input_ids"].clone()
|
||||
|
||||
# Mask image tokens
|
||||
image_token_id = [
|
||||
processor.tokenizer.convert_tokens_to_ids(
|
||||
processor.tokenizer.special_tokens_map["boi_token"]
|
||||
)
|
||||
]
|
||||
# Mask tokens for not being used in the loss computation
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100
|
||||
labels[labels == image_token_id] = -100
|
||||
labels[labels == 262144] = -100
|
||||
|
||||
batch["labels"] = labels
|
||||
return batch
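# --- Editor's sketch (assumption, not part of the commit): when debugging the
# collator above it can help to check what fraction of each batch is excluded
# from the loss (padding and image tokens are set to -100).
def masked_label_fraction(batch):
    labels = batch["labels"]
    return (labels == -100).float().mean().item()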
|
||||
|
||||
from trl import SFTTrainer
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=args,
|
||||
train_dataset=dataset,
|
||||
peft_config=peft_config,
|
||||
processing_class=processor,
|
||||
data_collator=collate_fn,
|
||||
)
|
||||
|
||||
# Start training, the model will be automatically saved to the Hub and the output directory
|
||||
trainer.train()
|
||||
|
||||
# Save the final model again to the Hugging Face Hub
|
||||
trainer.save_model()
|
||||
|
||||
# free the memory again
|
||||
del model
|
||||
del trainer
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
from peft import PeftModel
|
||||
|
||||
# Load the base model
|
||||
model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True)
|
||||
|
||||
# Merge LoRA and base model and save
|
||||
peft_model = PeftModel.from_pretrained(model, args.output_dir)
|
||||
merged_model = peft_model.merge_and_unload()
|
||||
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")
|
||||
|
||||
processor = AutoProcessor.from_pretrained(args.output_dir)
|
||||
processor.save_pretrained("merged_model")
|
||||
|
||||
import torch
|
||||
|
||||
# Load Model with PEFT adapter
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
args.output_dir,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(args.output_dir)
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
# Test sample with Product Name, Category and Image
|
||||
sample = {
|
||||
"product_name": "Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur",
|
||||
"category": "Toys & Games | Toy Figures & Playsets | Action Figures",
|
||||
"image": Image.open(requests.get("https://m.media-amazon.com/images/I/81+7Up7IWyL._AC_SY300_SX300_.jpg", stream=True).raw).convert("RGB")
|
||||
}
|
||||
|
||||
|
||||
# NB: inference is done with the image input and the two text fields (and the same instruction used for fine-tuning)
|
||||
def generate_description(sample, model, processor):
|
||||
# Convert sample into messages and then apply the chat template
|
||||
messages = [
|
||||
{"role": "system", "content": [{"type": "text", "text": system_message}]},
|
||||
{"role": "user", "content": [
|
||||
{"type": "image","image": sample["image"]},
|
||||
{"type": "text", "text": user_prompt.format(product=sample["product_name"], category=sample["category"])},
|
||||
]},
|
||||
]
|
||||
text = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
# Process the image and text
|
||||
image_inputs = process_vision_info(messages)  # converts the image to RGB, even though the sample above already seems to do it with .convert("RGB")
|
||||
# Tokenize the text and process the images
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
# Move the inputs to the device
|
||||
inputs = inputs.to(model.device)
|
||||
|
||||
# Generate the output
|
||||
stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")]
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=256, top_p=1.0, do_sample=True, temperature=0.8, eos_token_id=stop_token_ids, disable_compile=True)
|
||||
# Trim the generation and decode the output to text
|
||||
generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
return output_text[0]
|
||||
|
||||
# generate the description
|
||||
|
||||
description = generate_description(sample, model, processor)
|
||||
print(description)
|
||||
|
|
@@ -1,3 +1,10 @@
|
|||
# reference link: https://ai.google.dev/gemma/docs/core/ (https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora) # uses LoRA
|
||||
# run on the CNR GPU machine
|
||||
|
||||
# other references:
|
||||
#https://ai.google.dev/gemma/docs/core/huggingface_text_full_finetune . Still supervised tuning, no GRPO (but full fine-tuning, not LoRA). Interesting that it passes eval_dataset to SFTTrainer to see the training curves (train and eval loss)
|
||||
# https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl GRPO fine-tuning (but not specific to Gemma 3 and not with image input)
|
||||
|
||||
from huggingface_hub import login
|
||||
import os
|
||||
import gc
|
||||
|
|
@@ -0,0 +1,759 @@
|
|||
# reference link: https://ai.google.dev/gemma/docs/core/ (https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora) # uses LoRA
|
||||
# run on the CNR GPU machine
|
||||
|
||||
# other references:
|
||||
#https://ai.google.dev/gemma/docs/core/huggingface_text_full_finetune . Still supervised tuning, no GRPO (but full fine-tuning, not LoRA). Interesting that it passes eval_dataset to SFTTrainer to see the training curves (train and eval loss)
|
||||
# https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl GRPO fine-tuning (but not specific to Gemma 3 and not with image input)
|
||||
# see the Unsloth notebook on Colab that applies GRPO
|
||||
|
||||
from huggingface_hub import login
|
||||
import os
|
||||
import gc
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from huggingface_hub import snapshot_download
|
||||
import sys
|
||||
import json
|
||||
import re
|
||||
|
||||
os.environ['HF_HOME'] = './cache_huggingface' # or just "." for directly in current folder
|
||||
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
||||
|
||||
# Login into Hugging Face Hub
|
||||
hf_token = "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"#userdata.get('gemma3') # If you are running inside a Google Colab
|
||||
print("Logging into Hugging Face Hub...")
|
||||
login(hf_token)
|
||||
print("Logged in.")
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
|
||||
# System message for the assistant
|
||||
|
||||
system_message = """You are a web accessibility evaluation tool. Your task is to evaluate if alterative text for
|
||||
images on webpages is appropriate according to WCAG guidelines. The alt-text should serve the same purpose and present
|
||||
the same information as the original image content. As a result, it is possible to remove the image content and replace it with the text alternative and no functionality or information would be lost. This text alternative need not necessarily describe the image content.
|
||||
It should serve the same purpose and convey the same information. This may sometimes result in a text alternative that looks like a description of the image content. But this would only be true if that was the best way to serve the same purpose.
|
||||
If possible, the short text alternative should completely convey the purpose and information. If it is not possible to do this in a short phrase or sentence, then the short text alternative should provide a brief overview of the information.
|
||||
The text alternative should be able to substitute for the image content. If the image content were removed from the page and substituted with the text, the page would still provide the same function and information. The text alternative would be brief but as informative as possible.
|
||||
In deciding what text to include in the alternative, it is often a good idea to consider the following questions:
|
||||
Why is this image content here?
|
||||
What information is it presenting?
|
||||
What purpose does it fulfill?
|
||||
If I could not use the image content, what words would I use to convey the same function and/or information?
|
||||
|
||||
When image content contains words that are important to understanding the content, the alt text should include those words.
|
||||
Decorative images don’t add information to the content of a page. For example, the information provided by the image might already be given using adjacent text, or the image might be included to make the website more visually attractive.
|
||||
In these cases, a null (empty) alt text should be provided (alt="") so that they can be ignored by assistive technologies, such as screen readers.
|
||||
|
||||
Follow these instructions carefully:
|
||||
1. You will be provided as input with the following:
|
||||
- The image found on the webpage.
|
||||
- The associated alternative text. When the alt-text is empty or absent, you will be explicitly informed.
|
||||
- The surrounding context of the image.
|
||||
- The page title, headings and the content of the “keywords” and “description” <meta> tags, if found.
|
||||
|
||||
2. Determine the function and purpose of the image by analyzing these elements. Take into account the purpose and function
|
||||
of the associated image by considering the page context. Check also if the image is, or is associated with, a link or a button,
|
||||
and consider this in your judgment. If the image contains text, use that as part of the context.
|
||||
|
||||
3. Provide a final assessment judgment based on the following:
|
||||
- 'success' if you can assess with 'sufficient certainty' the alt-text is appropriate in relation to the image purpose,
|
||||
- 'failure' if you can assess with 'sufficient certainty' that the alt-text is NOT appropriate,
|
||||
- 'warning' if you cannot determine with 'sufficient certainty'.
|
||||
where the level of certainty goes from 1 to 100 and 'sufficient certainty' means > 80
|
||||
|
||||
4. The original alt-text assessment on a scale from 1 to 5, where 5 is the best score. Use an integer number only.
|
||||
|
||||
5. Provide a brief reasoning for your judgment. If the image contains text, write it verbatim.
|
||||
|
||||
6. Keep your response within 150 words.
|
||||
|
||||
7. Generate the new most appropriate alt-text given the context and the steps before. Keep this within 30 words. Use the same natural language (e.g., English, Spanish, Italian) as the original alt-text.
|
||||
|
||||
8. Here is the JSON format the results must have:
|
||||
{"Original alt-text assessment" : "*your original alt-text assessment*", "Assessment" : "*your assessment judgment*", "EvaluationResult": "*your response*", "New alt-text":"*new alt-text*"}"""
|
||||
|
||||
|
||||
def parse_mllm_alt_text_response(mllm_response): # the same function as in utils_API
|
||||
"""
|
||||
Parse an MLLM response string and extract key attributes into a JSON object.
|
||||
|
||||
from mllm response like:
|
||||
```json\n{\n\"Original alt-text assessment\"... etc
|
||||
to a structured dictionary.
|
||||
|
||||
Args:
|
||||
mllm_response (str): The raw MLLM response text containing JSON data
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the extracted attributes, or None if parsing fails
|
||||
"""
|
||||
try:
|
||||
# Handle NaN or None values
|
||||
if mllm_response is None or mllm_response == "":
|
||||
return {
|
||||
"original_alt_text_assessment": None,
|
||||
"assessment": None,
|
||||
"evaluation_result": None,
|
||||
"new_alt_text": None
|
||||
}
|
||||
|
||||
# Extract JSON content between ```json and ``` markers
|
||||
json_match = re.search(r'```json\s*(.*?)\s*```', mllm_response, re.DOTALL)
|
||||
|
||||
if not json_match:
|
||||
# Try to find JSON without markdown code blocks
|
||||
json_match = re.search(r'\{.*\}', mllm_response, re.DOTALL)
|
||||
|
||||
if not json_match:
|
||||
print("No JSON match found in MLLM response.")
|
||||
return {
|
||||
"original_alt_text_assessment": None,
|
||||
"assessment": None,
|
||||
"evaluation_result": None,
|
||||
"new_alt_text": None
|
||||
}
|
||||
|
||||
json_str = json_match.group(1) if '```json' in mllm_response else json_match.group(0)
|
||||
|
||||
# Parse the JSON string
|
||||
parsed_data = json.loads(json_str)
|
||||
|
||||
# Create a structured output with the key attributes
|
||||
result = {
|
||||
"original_alt_text_assessment": parsed_data.get("Original alt-text assessment", ""),
|
||||
"assessment": parsed_data.get("Assessment", ""),
|
||||
"evaluation_result": parsed_data.get("EvaluationResult", ""),
|
||||
"new_alt_text": parsed_data.get("New alt-text", "")
|
||||
}
|
||||
print("Parsed MLLM response:", result)
|
||||
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON parsing error: {e}")
|
||||
return {
|
||||
"original_alt_text_assessment": None,
|
||||
"assessment": None,
|
||||
"evaluation_result": None,
|
||||
"new_alt_text": None
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Error parsing MLLM response: {e}")
|
||||
return {
|
||||
"original_alt_text_assessment": None,
|
||||
"assessment": None,
|
||||
"evaluation_result": None,
|
||||
"new_alt_text": None
|
||||
}
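# --- Editor's sketch (assumption, not part of the commit): a minimal round-trip
# check for the parser above, using a made-up model response.
def _demo_parse_mllm_response():
    example = (
        '```json\n'
        '{"Original alt-text assessment": "4", "Assessment": "success", '
        '"EvaluationResult": "The alt-text matches the logo shown.", '
        '"New alt-text": "Company logo"}\n'
        '```'
    )
    return parse_mllm_alt_text_response(example)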
|
||||
|
||||
def format_user_prompt(
|
||||
original_alt_text,
|
||||
html_context,
|
||||
page_title,
|
||||
page_description,
|
||||
page_keywords,
|
||||
):
|
||||
|
||||
alt_text = "Here is the alt-text of the image: " + str(original_alt_text)
|
||||
|
||||
HTML_context = "Here is the surrounding HTML context of the element: " + str(html_context)
|
||||
|
||||
page_text = "Here is the content of the page: Title of the page: " + str(page_title)
|
||||
|
||||
page_text = page_text + ", content of the <meta name='description'> tag: "+ str(page_description)
|
||||
|
||||
page_text = page_text+ ", content of the <meta name='keywords'> tag: "+ str(page_keywords)
|
||||
|
||||
user_prompt_to_use=alt_text + " " + HTML_context + " " + page_text
|
||||
return user_prompt_to_use
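# --- Editor's sketch (assumption, not part of the commit): an equivalent,
# slightly tidier formulation of the prompt assembly above using one f-string.
def format_user_prompt_compact(original_alt_text, html_context, page_title,
                               page_description, page_keywords):
    return (
        f"Here is the alt-text of the image: {original_alt_text} "
        f"Here is the surrounding HTML context of the element: {html_context} "
        f"Here is the content of the page: Title of the page: {page_title}, "
        f"content of the <meta name='description'> tag: {page_description}, "
        f"content of the <meta name='keywords'> tag: {page_keywords}"
    )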
|
||||
|
||||
def download_hf_model(model_id, output_dir="./hf_model"):
|
||||
"""Download model from Hugging Face"""
|
||||
print(f"Downloading {model_id} from Hugging Face...")
|
||||
model_path = snapshot_download(
|
||||
repo_id=model_id,
|
||||
local_dir=output_dir,
|
||||
local_dir_use_symlinks=False
|
||||
)
|
||||
print(f"Model downloaded to: {model_path}")
|
||||
return model_path
|
||||
|
||||
def convert_to_gguf(model_path, output_path="./model.gguf"):
|
||||
"""
|
||||
Convert model to GGUF format using llama.cpp
|
||||
|
||||
Note: You need llama.cpp installed and convert.py script
|
||||
Clone from: https://github.com/ggerganov/llama.cpp
|
||||
"""
|
||||
print("Converting to GGUF format...")
|
||||
|
||||
# This assumes you have llama.cpp cloned and convert.py available
|
||||
# Adjust the path to your llama.cpp installation
|
||||
convert_script = "./llama.cpp/convert_hf_to_gguf.py" # Path to the llama.cpp conversion script
|
||||
|
||||
cmd = [
|
||||
"python", convert_script,
|
||||
model_path,
|
||||
"--outfile", output_path,
|
||||
"--outtype", "f16" # Use f16 for better quality, q4_0 for smaller size
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.run(cmd, check=True)
|
||||
print(f"GGUF model created: {output_path}")
|
||||
except FileNotFoundError:
|
||||
print("Error: llama.cpp convert.py not found.")
|
||||
print("Please clone llama.cpp: git clone https://github.com/ggerganov/llama.cpp")
|
||||
return None
|
||||
|
||||
return output_path
|
||||
|
||||
def create_modelfile(model_name, gguf_path, template=None):
|
||||
"""Create Ollama Modelfile"""
|
||||
modelfile_content = f"""FROM {gguf_path}
|
||||
|
||||
# Set parameters
|
||||
PARAMETER temperature 0.7
|
||||
PARAMETER top_p 0.9
|
||||
PARAMETER top_k 40
|
||||
|
||||
# Set the prompt template (adjust based on your model)
|
||||
TEMPLATE """
|
||||
|
||||
if template:
|
||||
modelfile_content += f'"""{template}"""'
|
||||
else:
|
||||
# Default template for chat models
|
||||
modelfile_content += '''"""{{ if .System }}System: {{ .System }}
|
||||
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
|
||||
{{ end }}Assistant: """'''
|
||||
|
||||
modelfile_path = model_name + "Modelfile"
|
||||
with open(modelfile_path, "w") as f:
|
||||
f.write(modelfile_content)
|
||||
|
||||
print(f"Modelfile created: {modelfile_path}")
|
||||
return modelfile_path
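# --- Editor's sketch (assumption, not part of the commit): once the Modelfile
# exists, the model can be registered with a local Ollama install. This assumes
# the `ollama` CLI is on PATH; `subprocess` is already imported above.
def register_with_ollama(model_name, modelfile_path):
    subprocess.run(["ollama", "create", model_name, "-f", modelfile_path], check=True)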
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def prepare_model_inputs(sample):
|
||||
# print("Preparing model inputs...")
|
||||
formatted_data = format_data(sample, is_inference=True)  # at inference time the assistant part (the expected response) is not passed
|
||||
messages = formatted_data["messages"]
|
||||
image_inputs = process_vision_info(messages)
|
||||
return messages, image_inputs
|
||||
|
||||
def generate_description(sample, model, processor):
|
||||
print("Generating description...")
|
||||
# Prepare the model inputs
|
||||
messages, image_inputs = prepare_model_inputs(
|
||||
sample
|
||||
) # can be avoided if already prepared
|
||||
# Apply the chat template to get the final text input
|
||||
text = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=image_inputs, # PIL Image or list of images
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
max_length=8192, # Equivalent to num_ctx, max input token
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
inputs = inputs.to(model.device)
|
||||
|
||||
# Generate the output
|
||||
stop_token_ids = [
|
||||
processor.tokenizer.eos_token_id,
|
||||
processor.tokenizer.convert_tokens_to_ids("<end_of_turn>"),
|
||||
]
|
||||
|
||||
generation_config = {
|
||||
"temperature": 0.7, # Same as Ollama
|
||||
"max_new_tokens": 800, # Equivalent to num_predict
|
||||
"top_p": 0.95, # Same as Ollama
|
||||
"do_sample": True, # Required for temperature/top_p to work
|
||||
}
|
||||
generated_ids = model.generate(
|
||||
**inputs,
|
||||
**generation_config,
|
||||
eos_token_id=stop_token_ids,
|
||||
disable_compile=True,
|
||||
)
|
||||
# Trim the generation and decode the output to text
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :]
|
||||
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
print("Raw model output text:", type(output_text[0]), output_text[0])
|
||||
parsed_resp = parse_mllm_alt_text_response(output_text[0])
|
||||
parsed_resp["model_id"]="gemma3_4b"
|
||||
return parsed_resp
|
||||
|
||||
|
||||
|
||||
def format_data(sample, is_inference=False):
|
||||
|
||||
formatted_data = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": system_message}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": format_user_prompt( original_alt_text=sample["original_alt_text"],
|
||||
html_context=sample["html_context"],
|
||||
page_title=sample["page_title"],
|
||||
page_description=sample["page_description"],
|
||||
page_keywords=sample["page_keywords"],
|
||||
|
||||
|
||||
),
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"image": sample["image"].convert("RGB"), # nb: unico diverso che in inferenza che usa request. qua uso direttamente l'immagine PIL
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
}
|
||||
if is_inference:
|
||||
# print(
|
||||
# "formatted_data for inference:", formatted_data
|
||||
# ) # the assistant part (the expected response) is not passed here, unlike in the HF example
|
||||
pass
|
||||
|
||||
else:
|
||||
formatted_data["messages"].append( #aggiungo la parte di risposta attesa per il training
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
|
||||
#{"type": "text", "text": sample["llm_alt_text"]} , #only alt-text atteso
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"""```json
|
||||
{{
|
||||
"Original alt-text assessment": {json.dumps(sample["llm_assessment"])},
|
||||
"Assessment": {json.dumps(sample["llm_judgment"])},
|
||||
"EvaluationResult": {json.dumps(sample["llm_evaluation_result"])},
|
||||
"New alt-text": {json.dumps(sample["llm_alt_text"])}
|
||||
}}
|
||||
```"""
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
)
|
||||
return formatted_data
|
||||
|
||||
def process_vision_info_old(messages: list[dict]) -> list[Image.Image]:
|
||||
print("Processing vision info...")
|
||||
image_inputs = []
|
||||
# Iterate through each conversation
|
||||
for msg in messages:
|
||||
# Get content (ensure it's a list)
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
content = [content]
|
||||
|
||||
# Check each content element for images
|
||||
for element in content:
|
||||
if isinstance(element, dict) and (
|
||||
"image" in element or element.get("type") == "image"
|
||||
):
|
||||
# Get the image and convert to RGB
|
||||
if "image" in element:
|
||||
image = element["image"]
|
||||
else:
|
||||
image = element
|
||||
image_inputs.append(image.convert("RGB"))  # convert to RGB!
|
||||
return image_inputs
|
||||
|
||||
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
|
||||
print("Processing vision info...")
|
||||
image_inputs = []
|
||||
image_index = 0
|
||||
|
||||
# Iterate through each conversation
|
||||
for msg_index, msg in enumerate(messages):
|
||||
# Get content (ensure it's a list)
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
content = [content]
|
||||
|
||||
# Check each content element for images
|
||||
for element_index, element in enumerate(content):
|
||||
if isinstance(element, dict) and (
|
||||
"image" in element or element.get("type") == "image"
|
||||
):
|
||||
try:
|
||||
# Get the image and convert to RGB
|
||||
if "image" in element:
|
||||
image = element["image"]
|
||||
else:
|
||||
image = element.get("image")
|
||||
|
||||
# Convert to RGB if it's a PIL Image
|
||||
if isinstance(image, Image.Image):
|
||||
print(f"Image {image_index} - Original shape: {image.size}, mode: {image.mode}")
|
||||
|
||||
# Check for problematic dimensions
|
||||
if image.size[0] <= 1 or image.size[1] <= 1:
|
||||
print(f"⚠️ WARNING: Image {image_index} has very small dimensions: {image.size}. Skipping or resizing...")
|
||||
# Option 1: Skip the image
|
||||
#continue
|
||||
# Option 2: Resize to minimum viable size
|
||||
image = image.resize((224, 224))
|
||||
print(f"Resized image {image_index} to: {image.size}")
|
||||
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert('RGB')
|
||||
print(f"Converted image {image_index} to RGB mode")
|
||||
|
||||
image_inputs.append(image)
|
||||
image_index += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ ERROR processing image {image_index} at message[{msg_index}].content[{element_index}]")
|
||||
print(f" Error details: {type(e).__name__}: {e}")
|
||||
if isinstance(image, Image.Image):
|
||||
print(f" Image properties - Size: {image.size}, Mode: {image.mode}")
|
||||
continue
|
||||
|
||||
print(f"Successfully processed {len(image_inputs)} images")
|
||||
return image_inputs
|
||||
|
||||
print("Loading dataset...")
|
||||
# Load dataset from the hub
|
||||
|
||||
dataset = load_dataset("nicolaleo/LLM-alt-text-assessment-full-features", split="train",cache_dir="./dataset_cache")
|
||||
|
||||
dataset_validation = load_dataset("nicolaleo/LLM-alt-text-assessment-full-features_validation", split="validation",cache_dir="./dataset_cache")
|
||||
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
dataset_copy=deepcopy(dataset)
|
||||
|
||||
|
||||
|
||||
# Convert dataset to OAI messages
|
||||
# need to use a list comprehension to keep the PIL.Image type; .map converts the image to bytes
|
||||
dataset = [format_data(sample, is_inference=False) for sample in dataset]
|
||||
|
||||
dataset_validation = [format_data(sample, is_inference=False) for sample in dataset_validation]
|
||||
|
||||
|
||||
print(dataset[0]["messages"])
|
||||
print("parse_mllm_alt_text_response:",parse_mllm_alt_text_response(dataset[0]["messages"][2]["content"][0]["text"]))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
import torch
|
||||
torch.cuda.get_device_capability()
|
||||
|
||||
print("Freeing up memory...")
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
# Get free memory in bytes
|
||||
free_memory = torch.cuda.mem_get_info()[0]
|
||||
total_memory = torch.cuda.mem_get_info()[1]
|
||||
|
||||
# Convert to GB for readability
|
||||
free_gb = free_memory / (1024**3)
|
||||
total_gb = total_memory / (1024**3)
|
||||
|
||||
print(f"Free: {free_gb:.2f} GB / Total: {total_gb:.2f} GB")
|
||||
|
||||
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
|
||||
|
||||
# Hugging Face model id
|
||||
model_id = "google/gemma-3-4b-it"#"google/gemma-3-4b-pt"#"google/gemma-3-4b-pt" # or `google/gemma-3-12b-pt`, `google/gemma-3-27-pt`
|
||||
|
||||
# Check if GPU benefits from bfloat16
|
||||
#if torch.cuda.get_device_capability()[0] < 8:
|
||||
# raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")
|
||||
|
||||
# Define model init arguments
|
||||
model_kwargs = dict(
|
||||
attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU
|
||||
torch_dtype=torch.bfloat16,#torch.float16,#torch.bfloat16, # What torch dtype to use, defaults to auto
|
||||
device_map="auto", # Let torch decide how to load the model
|
||||
|
||||
)
|
||||
|
||||
# BitsAndBytesConfig int-4 config
|
||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
|
||||
bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
|
||||
)
|
||||
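# Note (added): this follows the usual QLoRA recipe as far as I can tell: "nf4" is the
# 4-bit NormalFloat data type from the QLoRA paper, double quantization also compresses
# the quantization constants, and the compute/storage dtype matches the bf16 chosen above.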
|
||||
# Load model and tokenizer
|
||||
#model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
|
||||
#processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
|
||||
|
||||
|
||||
|
||||
|
||||
# Set the cache directory to current folder
|
||||
cache_dir = "./model_cache" # or just "." for directly in current folder
|
||||
|
||||
print("Loading model... This may take a while.")
|
||||
model=AutoModelForImageTextToText.from_pretrained( # 4-bit quantized version with bf16 compute
|
||||
model_id,
|
||||
cache_dir=cache_dir,
|
||||
**model_kwargs
|
||||
)
|
||||
print("Model loaded.")
|
||||
print(f"Original Model dtype: {model.dtype}") # ritorna torch.bfloat16
|
||||
|
||||
proc_cache_dir = "./proc_cache"
|
||||
print("Loading processor...")
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_id,#"google/gemma-3-4b-it",#model_id, # nel file originale prende -it e non -pt (cambia poco comunque)
|
||||
cache_dir=proc_cache_dir
|
||||
)
|
||||
print("Processor loaded.")
|
||||
|
||||
|
||||
print("testing the loaded model...")
|
||||
# generate the description
|
||||
description = generate_description(dataset_copy[0], model, processor)
|
||||
print("text generated:",description)
|
||||
|
||||
|
||||
# Download and save to current folder
|
||||
print("Saving model and processor locally...")
|
||||
save_path = "./original_local_model_"+model_id.replace("/", "_")
|
||||
model.save_pretrained(save_path)
|
||||
processor.save_pretrained(save_path)
|
||||
print("Model and processor saved.")
|
||||
|
||||
|
||||
""" # la convesrione in ollama funziona solo se fatta su modello non quantizzato (da capire se si può fare su modello 4bit)
|
||||
print("Converting and importing model to Ollama...")
|
||||
# Step 1: Download from Hugging Face
|
||||
model_path= "./original_local_model_ollama"
|
||||
model_path = download_hf_model(model_id,output_dir=model_path)
|
||||
|
||||
# Step 2: Convert to GGUF (requires llama.cpp)
|
||||
gguf_path = convert_to_gguf(model_path, "./gemma.gguf")
|
||||
|
||||
if gguf_path:
|
||||
# Step 3: Create Modelfile
|
||||
OLLAMA_MODEL_NAME = "gemma3-wcag"
|
||||
modelfile = create_modelfile(OLLAMA_MODEL_NAME, gguf_path)
|
||||
|
||||
"""
|
||||
|
||||
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
peft_config = LoraConfig(
|
||||
lora_alpha=8,#16,
|
||||
lora_dropout=0.05,
|
||||
r=8,#16,
|
||||
bias="none",
|
||||
target_modules="all-linear",
|
||||
task_type="CAUSAL_LM",
|
||||
#modules_to_save=[ # this was what used the extra memory
|
||||
# "lm_head",
|
||||
# "embed_tokens",
|
||||
#],
|
||||
)
|
||||
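# Note (added): r is the rank of the LoRA update matrices and lora_alpha scales their
# contribution (roughly alpha / r); target_modules="all-linear" attaches adapters to every
# linear layer, which adds more trainable parameters than attention-only LoRA.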
|
||||
from trl import SFTConfig
|
||||
|
||||
args = SFTConfig(
|
||||
output_dir="./gemma-finetuned-wcag_"+model_id.replace("/", "_"), # directory to save and repository id
|
||||
num_train_epochs=1, # number of training epochs
|
||||
per_device_train_batch_size=1, # batch size per device during training
|
||||
per_device_eval_batch_size=1, # batch size for evaluation
|
||||
gradient_accumulation_steps=2,#4, # number of steps before performing a backward/update pass
|
||||
gradient_checkpointing=True, # use gradient checkpointing to save memory
|
||||
optim="adamw_8bit",#"adamw_torch_fused", # use fused adamw optimizer
|
||||
logging_steps=5, # log every 5 steps
|
||||
save_strategy="epoch", # save checkpoint every epoch
|
||||
#eval_strategy="epoch", # evaluate checkpoint every epoch
|
||||
eval_strategy="steps",
|
||||
eval_steps=5, # Evaluate every 5 steps
|
||||
learning_rate=2e-4, # learning rate, based on QLoRA paper
|
||||
bf16=True,#False,#True, # use bfloat16 precision
|
||||
max_grad_norm=0.3, # max gradient norm based on QLoRA paper
|
||||
warmup_ratio=0.03, # warmup ratio based on QLoRA paper
|
||||
lr_scheduler_type="constant", # use constant learning rate scheduler
|
||||
push_to_hub=True, # push model to hub
|
||||
report_to="tensorboard", # report metrics to tensorboard
|
||||
gradient_checkpointing_kwargs={
|
||||
"use_reentrant": False
|
||||
}, # use non-reentrant checkpointing
|
||||
dataset_text_field="", # need a dummy field for collator
|
||||
dataset_kwargs={"skip_prepare_dataset": True}, # important for collator
|
||||
)
|
||||
args.remove_unused_columns = False # important for collator
|
||||
|
||||
# Create a data collator to encode text and image pairs
|
||||
def collate_fn(examples):
|
||||
texts = []
|
||||
images = []
|
||||
for example in examples:
|
||||
image_inputs = process_vision_info(example["messages"])
|
||||
text = processor.apply_chat_template(
|
||||
example["messages"], add_generation_prompt=False, tokenize=False
|
||||
)
|
||||
texts.append(text.strip())
|
||||
images.append(image_inputs)
|
||||
|
||||
# Tokenize the texts and process the images
|
||||
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)#max_length=8192)
|
||||
|
||||
# The labels are the input_ids, and we mask the padding tokens and image tokens in the loss computation
|
||||
labels = batch["input_ids"].clone()
|
||||
|
||||
# Mask image tokens
|
||||
image_token_id = [
|
||||
processor.tokenizer.convert_tokens_to_ids(
|
||||
processor.tokenizer.special_tokens_map["boi_token"]
|
||||
)
|
||||
]
|
||||
# Mask tokens so they are not used in the loss computation
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100
|
||||
labels[labels == image_token_id] = -100
|
||||
labels[labels == 262144] = -100
|
||||
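# Note (added): 262144 is assumed to be the Gemma 3 <image_soft_token> id used by the
# reference fine-tuning example; masking it (together with the pad and boi tokens above)
# keeps image placeholder positions out of the loss.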
|
||||
batch["labels"] = labels
|
||||
|
||||
|
||||
# Free CUDA memory between batches to reduce chance of OOM
|
||||
try:
|
||||
import torch as _torch
|
||||
if _torch.cuda.is_available():
|
||||
_torch.cuda.empty_cache()
|
||||
except Exception:
|
||||
print("Error clearing CUDA cache, continuing without clearing.")
|
||||
#pass
|
||||
return batch
|
||||
|
||||
from trl import SFTTrainer
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=args,
|
||||
train_dataset=dataset,
|
||||
eval_dataset=dataset_validation,
|
||||
peft_config=peft_config,
|
||||
processing_class=processor,
|
||||
data_collator=collate_fn,
|
||||
)
|
||||
|
||||
print("Starting training...")
|
||||
# Start training, the model will be automatically saved to the Hub and the output directory
|
||||
trainer.train()
|
||||
|
||||
print("Training completed.")
|
||||
# Save the final model again to the Hugging Face Hub
|
||||
try:
|
||||
trainer.save_model() # saves the LoRA adapter to the HF Hub under the name specified in args.output_dir
|
||||
print("Final model saved to Hugging Face Hub.")
|
||||
except Exception as e:
|
||||
print(f"Error saving model to Hugging Face Hub: {e}")
|
||||
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
# Access the log history
|
||||
log_history = trainer.state.log_history
|
||||
|
||||
# Extract training / validation loss
|
||||
train_losses = [log["loss"] for log in log_history if "loss" in log]
|
||||
epoch_train = [log["epoch"] for log in log_history if "loss" in log]
|
||||
eval_losses = [log["eval_loss"] for log in log_history if "eval_loss" in log]
|
||||
epoch_eval = [log["epoch"] for log in log_history if "eval_loss" in log]
|
||||
|
||||
# Plot the training loss
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(epoch_train, train_losses, label="Training Loss", marker='o')
|
||||
plt.plot(epoch_eval, eval_losses, label="Validation Loss", marker='s')
|
||||
plt.xlabel("Epoch")
|
||||
plt.ylabel("Loss")
|
||||
plt.title("Training and Validation Loss per Epoch")
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.savefig("training_validation_loss.png", dpi=300, bbox_inches='tight')
|
||||
print("Plot saved successfully as 'training_validation_loss.png'")
|
||||
except Exception as e:
|
||||
print(f"Error plotting loss curves: {e}")
|
||||
|
||||
|
||||
# free the memory again
|
||||
del model
|
||||
del trainer
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
from peft import PeftModel
|
||||
|
||||
# Load the base model
|
||||
model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, cache_dir=cache_dir) # bf16 version (not 4-bit quantized); NB: quantization happens at loading time
|
||||
|
||||
print(f"Original Model dtype pre merge: {model.dtype}")
|
||||
|
||||
|
||||
|
||||
# Merge LoRA and base model and save
|
||||
peft_model = PeftModel.from_pretrained(model, args.output_dir)
|
||||
merged_model = peft_model.merge_and_unload()
|
||||
merged_model.save_pretrained("merged_model_"+model_id.replace("/", "_"), safe_serialization=True, max_shard_size="2GB")
|
||||
|
||||
processor = AutoProcessor.from_pretrained(args.output_dir)
|
||||
processor.save_pretrained("merged_model_"+model_id.replace("/", "_"))
|
||||
|
||||
|
||||
print("Loading merged model for inference...")
|
||||
|
||||
merged_model_path = "merged_model_"+model_id.replace("/", "_")
|
||||
|
||||
# Load the merged model
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
merged_model_path,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
#attn_implementation="eager",
|
||||
)
|
||||
|
||||
print(f"Merged Model dtype: {model.dtype}")
|
||||
processor = AutoProcessor.from_pretrained(merged_model_path)
|
||||
|
||||
|
||||
print("testing the merged model...")
|
||||
|
||||
|
||||
# generate the description
|
||||
description = generate_description(dataset_copy[0], model, processor)
|
||||
print("text generated:",description)
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,745 @@
|
|||
# reference link: https://ai.google.dev/gemma/docs/core/ (https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora) # uses LoRA
|
||||
# run on the CNR GPU machine
|
||||
|
||||
# other references:
|
||||
#https://ai.google.dev/gemma/docs/core/huggingface_text_full_finetune . Still supervised tuning, no GRPO (but full fine-tuning, not LoRA). Interesting that it passes eval_dataset to SFTTrainer to get the training curves (train and eval loss)
|
||||
# https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl GRPO fine-tuning (but not specific to gemma3 and without image input)
|
||||
# see the Unsloth notebook on Colab that applies GRPO
|
||||
|
||||
from huggingface_hub import login
|
||||
import os
|
||||
import gc
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from huggingface_hub import snapshot_download
|
||||
import sys
|
||||
import json
|
||||
import re
|
||||
|
||||
os.environ['HF_HOME'] = './cache_huggingface' # or just "." for directly in current folder
|
||||
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
||||
|
||||
# Login into Hugging Face Hub
|
||||
hf_token = "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"#userdata.get('gemma3') # If you are running inside a Google Colab
|
||||
print("Logging into Hugging Face Hub...")
|
||||
login(hf_token)
|
||||
print("Logged in.")
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
|
||||
# System message for the assistant
|
||||
|
||||
system_message = """You are a web accessibility evaluation tool. Your task is to evaluate if alterative text for
|
||||
images on webpages is appropriate according to WCAG guidelines. The alt-text should serve the same purpose and present
|
||||
the same information as the original image content. As a result, it is possible to remove the image content and replace it with the text alternative and no functionality or information would be lost. This text alternative should not necessarily describe the image content.
|
||||
It should serve the same purpose and convey the same information. This may sometimes result in a text alternative that looks like a description of the image content. But this would only be true if that was the best way to serve the same purpose.
|
||||
If possible, the short text alternative should completely convey the purpose and information. If it is not possible to do this in a short phrase or sentence, then the short text alternative should provide a brief overview of the information.
|
||||
The text alternative should be able to substitute for the image content. If the image content were removed from the page and substituted with the text, the page would still provide the same function and information. The text alternative would be brief but as informative as possible.
|
||||
In deciding what text to include in the alternative, it is often a good idea to consider the following questions:
|
||||
Why is this image content here?
|
||||
What information is it presenting?
|
||||
What purpose does it fulfill?
|
||||
If I could not use the image content, what words would I use to convey the same function and/or information?
|
||||
|
||||
When image content contains words that are important to understanding the content, the alt text should include those words.
|
||||
Decorative images don’t add information to the content of a page. For example, the information provided by the image might already be given using adjacent text, or the image might be included to make the website more visually attractive.
|
||||
In these cases, a null (empty) alt text should be provided (alt="") so that they can be ignored by assistive technologies, such as screen readers.
|
||||
|
||||
Follow these instructions carefully:
|
||||
1. You will be provided as input with the following:
|
||||
- The image found on the webpage.
|
||||
- The associated alternative text. When the alt-text is empty or absent, you will be explicitly informed.
|
||||
- The surrounding context of the image.
|
||||
- The page title, headings and the content of the “keywords” and “description” <meta> tag, if found.
|
||||
|
||||
2. Determine the function and purpose of the image by analyzing these elements. Take into account the purpose and function
|
||||
of the associated image by considering the page context. Check also if the image is, or is associated with, a link or a button,
|
||||
and consider this in your judgement. If the image contains text use that as part of the context.
|
||||
|
||||
3. Provide a final assessment judgment based on the following:
|
||||
- 'success' if you can assess with 'sufficient certainty' the alt-text is appropriate in relation to the image purpose,
|
||||
- 'failure' if you can assess with 'sufficient certainty' that the alt-text is NOT appropriate,
|
||||
- 'warning' if you cannot determine with 'sufficient certainty'.
|
||||
where the level of certainty goes from 1 to 100 and 'sufficient certainty' means > 80
|
||||
|
||||
4. The original alt-text assessment on a scale from 1 to 5, where 5 is the best score. Use an integer number only.
|
||||
|
||||
5. Provide a brief reasoning for your judgment. If the image contains text, write it verbatim.
|
||||
|
||||
6. Keep your response within 150 words.
|
||||
|
||||
7. Generate the new most appropriate alt-text given the context and the steps before. Keep this within 30 words. Use the same natural language (e.g., English, Spanish, Italian) as the original alt-text.
|
||||
|
||||
8. Here is the JSON format the results must have:
|
||||
{"Original alt-text assessment" : "*your original alt-text assessment*", "Assessment" : "*your assessment judgment*", "EvaluationResult": "*your response*", "New alt-text":"*new alt-text*"}"""
|
||||
|
||||
|
||||
def parse_mllm_alt_text_response(mllm_response): # same function as the one in utils_API
|
||||
"""
|
||||
Parse an MLLM response string and extract key attributes into a JSON object.
|
||||
|
||||
from mllm response like:
|
||||
```json\n{\n\"Original alt-text assessment\"... etc
|
||||
to a structured dictionary.
|
||||
|
||||
Args:
|
||||
mllm_response (str): The raw MLLM response text containing JSON data
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the extracted attributes, or None if parsing fails
|
||||
"""
|
||||
try:
|
||||
# Handle NaN or None values
|
||||
if mllm_response is None or mllm_response == "":
|
||||
return {
|
||||
"original_alt_text_assessment": None,
|
||||
"assessment": None,
|
||||
"evaluation_result": None,
|
||||
"new_alt_text": None
|
||||
}
|
||||
|
||||
# Extract JSON content between ```json and ``` markers
|
||||
json_match = re.search(r'```json\s*(.*?)\s*```', mllm_response, re.DOTALL)
|
||||
|
||||
if not json_match:
|
||||
# Try to find JSON without markdown code blocks
|
||||
json_match = re.search(r'\{.*\}', mllm_response, re.DOTALL)
|
||||
|
||||
if not json_match:
|
||||
print("No JSON match found in MLLM response.")
|
||||
return {
|
||||
"original_alt_text_assessment": None,
|
||||
"assessment": None,
|
||||
"evaluation_result": None,
|
||||
"new_alt_text": None
|
||||
}
|
||||
|
||||
json_str = json_match.group(1) if '```json' in mllm_response else json_match.group(0)
|
||||
|
||||
# Parse the JSON string
|
||||
parsed_data = json.loads(json_str)
|
||||
|
||||
# Create a structured output with the key attributes
|
||||
result = {
|
||||
"original_alt_text_assessment": parsed_data.get("Original alt-text assessment", ""),
|
||||
"assessment": parsed_data.get("Assessment", ""),
|
||||
"evaluation_result": parsed_data.get("EvaluationResult", ""),
|
||||
"new_alt_text": parsed_data.get("New alt-text", "")
|
||||
}
|
||||
print("Parsed MLLM response:", result)
|
||||
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON parsing error: {e}")
|
||||
return {
|
||||
"original_alt_text_assessment": None,
|
||||
"assessment": None,
|
||||
"evaluation_result": None,
|
||||
"new_alt_text": None
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Error parsing MLLM response: {e}")
|
||||
return {
|
||||
"original_alt_text_assessment": None,
|
||||
"assessment": None,
|
||||
"evaluation_result": None,
|
||||
"new_alt_text": None
|
||||
}
|
||||
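# Added, illustrative only: a quick sanity check of the parser on a hand-written response
# (the values are invented and the lines are safe to delete); it exercises the fenced ```json branch above.
_example_response = '```json\n{"Original alt-text assessment": "4", "Assessment": "success", "EvaluationResult": "The alt text matches the logo purpose.", "New alt-text": "ACME logo"}\n```'
assert parse_mllm_alt_text_response(_example_response)["new_alt_text"] == "ACME logo"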
|
||||
def format_user_prompt(
|
||||
original_alt_text,
|
||||
html_context,
|
||||
page_title,
|
||||
page_description,
|
||||
page_keywords,
|
||||
):
|
||||
|
||||
alt_text = "Here is the alt-text of the image: " + str(original_alt_text)
|
||||
|
||||
HTML_context = "Here is the surrounding HTML context of the element: " + str(html_context)
|
||||
|
||||
page_text = "Here is the content of the page: Title of the page: " + str(page_title)
|
||||
|
||||
page_text = page_text + ", content of the <meta name='description'> tag: "+ str(page_description)
|
||||
|
||||
page_text = page_text+ ", content of the <meta name='keywords'> tag: "+ str(page_keywords)
|
||||
|
||||
user_prompt_to_use=alt_text + " " + HTML_context + " " + page_text
|
||||
return user_prompt_to_use
|
||||
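# Illustrative example (added): format_user_prompt("ACME logo", "<img src='logo.png'>", "Home", "Company site", "acme, tools")
# returns one string of the form "Here is the alt-text of the image: ACME logo Here is the surrounding HTML context of the element: <img src='logo.png'> Here is the content of the page: Title of the page: Home, content of the <meta name='description'> tag: Company site, content of the <meta name='keywords'> tag: acme, tools"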
|
||||
def download_hf_model(model_id, output_dir="./hf_model"):
|
||||
"""Download model from Hugging Face"""
|
||||
print(f"Downloading {model_id} from Hugging Face...")
|
||||
model_path = snapshot_download(
|
||||
repo_id=model_id,
|
||||
local_dir=output_dir,
|
||||
local_dir_use_symlinks=False
|
||||
)
|
||||
print(f"Model downloaded to: {model_path}")
|
||||
return model_path
|
||||
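# Note (added): recent huggingface_hub releases report local_dir_use_symlinks as deprecated
# (files are simply copied into local_dir), so this call may emit a deprecation warning;
# the download itself still works.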
|
||||
def convert_to_gguf(model_path, output_path="./model.gguf"):
|
||||
"""
|
||||
Convert model to GGUF format using llama.cpp
|
||||
|
||||
Note: You need llama.cpp installed and its convert_hf_to_gguf.py script
|
||||
Clone from: https://github.com/ggerganov/llama.cpp
|
||||
"""
|
||||
print("Converting to GGUF format...")
|
||||
|
||||
# This assumes you have llama.cpp cloned and convert.py available
|
||||
# Adjust the path to your llama.cpp installation
|
||||
convert_script = "./llama.cpp/convert_hf_to_gguf.py" # Path to llama.cpp convert.py
|
||||
|
||||
cmd = [
|
||||
"python", convert_script,
|
||||
model_path,
|
||||
"--outfile", output_path,
|
||||
"--outtype", "f16" # Use f16 for better quality, q4_0 for smaller size
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.run(cmd, check=True)
|
||||
print(f"GGUF model created: {output_path}")
|
||||
except FileNotFoundError:
|
||||
print("Error: llama.cpp convert.py not found.")
|
||||
print("Please clone llama.cpp: git clone https://github.com/ggerganov/llama.cpp")
|
||||
return None
|
||||
|
||||
return output_path
|
||||
|
||||
def create_modelfile(model_name, gguf_path, template=None):
|
||||
"""Create Ollama Modelfile"""
|
||||
modelfile_content = f"""FROM {gguf_path}
|
||||
|
||||
# Set parameters
|
||||
PARAMETER temperature 0.7
|
||||
PARAMETER top_p 0.9
|
||||
PARAMETER top_k 40
|
||||
|
||||
# Set the prompt template (adjust based on your model)
|
||||
TEMPLATE """
|
||||
|
||||
if template:
|
||||
modelfile_content += f'"""{template}"""'
|
||||
else:
|
||||
# Default template for chat models
|
||||
modelfile_content += '''"""{{ if .System }}System: {{ .System }}
|
||||
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
|
||||
{{ end }}Assistant: """'''
|
||||
|
||||
modelfile_path = model_name + "Modelfile"
|
||||
with open(modelfile_path, "w") as f:
|
||||
f.write(modelfile_content)
|
||||
|
||||
print(f"Modelfile created: {modelfile_path}")
|
||||
return modelfile_path
|
||||
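# Note (added): writing the Modelfile is only half of the import; to register the model with
# Ollama you would then run something like
#   ollama create <model_name> -f <modelfile_path>
# (standard Ollama CLI usage; adjust the names to your setup).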
|
||||
|
||||
|
||||
|
||||
|
||||
def prepare_model_inputs(sample):
|
||||
# print("Preparing model inputs...")
|
||||
formatted_data = format_data(sample, is_inference=True) # at inference time the assistant part (the expected answer) is not passed
|
||||
messages = formatted_data["messages"]
|
||||
image_inputs = process_vision_info(messages)
|
||||
return messages, image_inputs
|
||||
|
||||
def generate_description(sample, model, tokenizer):
|
||||
print("Generating description...")
|
||||
# Prepare the model inputs
|
||||
messages, image_inputs = prepare_model_inputs(
|
||||
sample
|
||||
) # can be avoided if already prepared
|
||||
tokenized = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True) # manca max_length=8192, # Equivalent to num_ctx, max input token
|
||||
|
||||
### to apply max_length, the following steps seem to be required
|
||||
# First get the formatted string
|
||||
#formatted = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
# Then tokenize with max_length
|
||||
#tokenized = tokenizer(formatted, return_tensors="pt", max_length=8192, truncation=True, return_dict=True)
|
||||
|
||||
tokenized["input_ids"] = tokenized["input_ids"].to(device="cuda") #vedi che non gli passa immagine PIL come fa per gemma3
|
||||
tokenized["pixel_values"] = tokenized["pixel_values"].to(dtype=torch.bfloat16, device="cuda")
|
||||
image_sizes = [tokenized["pixel_values"].shape[-2:]]
|
||||
print("Image sizes:", image_sizes)
|
||||
|
||||
|
||||
print(f"model device: {model.device}")
|
||||
tokenized = tokenized.to(model.device) # move everything to the model's device
|
||||
generation_config = {
|
||||
"temperature": 0.7, # Same as Ollama
|
||||
"max_new_tokens": 800, # Equivalent to num_predict
|
||||
"top_p": 0.95, # Same as Ollama
|
||||
"do_sample": True, # Required for temperature/top_p to work
|
||||
}
|
||||
output = model.generate(
|
||||
**tokenized,
|
||||
image_sizes=image_sizes,
|
||||
#max_new_tokens=512,
|
||||
**generation_config,
|
||||
)[0]
|
||||
|
||||
decoded_output = tokenizer.decode(output[len(tokenized["input_ids"][0]):])
|
||||
print("Raw model output text:", type(decoded_output), decoded_output)
|
||||
parsed_resp = parse_mllm_alt_text_response(decoded_output)
|
||||
|
||||
parsed_resp["model_id"]="ministral3_3b"
|
||||
return parsed_resp
|
||||
|
||||
|
||||
|
||||
def format_data(sample, is_inference=False):
|
||||
|
||||
formatted_data = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": system_message}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": format_user_prompt( original_alt_text=sample["original_alt_text"],
|
||||
html_context=sample["html_context"],
|
||||
page_title=sample["page_title"],
|
||||
page_description=sample["page_description"],
|
||||
page_keywords=sample["page_keywords"],
|
||||
|
||||
|
||||
),
|
||||
},
|
||||
|
||||
{"type": "image_url", "image_url": {"url": sample["image_url"]}},# passato come image url e non pil come gemma3
|
||||
],
|
||||
},
|
||||
]
|
||||
}
|
||||
if is_inference:
|
||||
# print(
|
||||
# "formatted_data for inference:", formatted_data
|
||||
# ) # the assistant part (the expected answer) is not passed here, unlike in the HF example
|
||||
pass
|
||||
|
||||
else:
|
||||
formatted_data["messages"].append( #aggiungo la parte di risposta attesa per il training
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
|
||||
#{"type": "text", "text": sample["llm_alt_text"]} , #only alt-text atteso
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"""```json
|
||||
{{
|
||||
"Original alt-text assessment": {json.dumps(sample["llm_assessment"])},
|
||||
"Assessment": {json.dumps(sample["llm_judgment"])},
|
||||
"EvaluationResult": {json.dumps(sample["llm_evaluation_result"])},
|
||||
"New alt-text": {json.dumps(sample["llm_alt_text"])}
|
||||
}}
|
||||
```"""
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
)
|
||||
return formatted_data
|
||||
|
||||
|
||||
|
||||
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
|
||||
print("Processing vision info...")
|
||||
image_inputs = []
|
||||
image_index = 0
|
||||
|
||||
# Iterate through each conversation
|
||||
for msg_index, msg in enumerate(messages):
|
||||
# Get content (ensure it's a list)
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
content = [content]
|
||||
|
||||
# Check each content element for images
|
||||
for element_index, element in enumerate(content):
|
||||
if isinstance(element, dict) and (
|
||||
"image" in element or element.get("type") == "image"
|
||||
):
|
||||
try:
|
||||
# Get the image and convert to RGB
|
||||
if "image" in element:
|
||||
image = element["image"]
|
||||
else:
|
||||
image = element.get("image")
|
||||
|
||||
# Convert to RGB if it's a PIL Image
|
||||
if isinstance(image, Image.Image):
|
||||
print(f"Image {image_index} - Original shape: {image.size}, mode: {image.mode}")
|
||||
|
||||
# Check for problematic dimensions
|
||||
if image.size[0] <= 1 or image.size[1] <= 1:
|
||||
print(f"⚠️ WARNING: Image {image_index} has very small dimensions: {image.size}. Skipping or resizing...")
|
||||
# Option 1: Skip the image
|
||||
#continue
|
||||
# Option 2: Resize to minimum viable size
|
||||
image = image.resize((224, 224))
|
||||
print(f"Resized image {image_index} to: {image.size}")
|
||||
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert('RGB')
|
||||
print(f"Converted image {image_index} to RGB mode")
|
||||
|
||||
image_inputs.append(image)
|
||||
image_index += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ ERROR processing image {image_index} at message[{msg_index}].content[{element_index}]")
|
||||
print(f" Error details: {type(e).__name__}: {e}")
|
||||
if isinstance(image, Image.Image):
|
||||
print(f" Image properties - Size: {image.size}, Mode: {image.mode}")
|
||||
continue
|
||||
|
||||
print(f"Successfully processed {len(image_inputs)} images")
|
||||
return image_inputs
|
||||
|
||||
print("Loading dataset...")
|
||||
# Load dataset from the hub
|
||||
|
||||
dataset = load_dataset("nicolaleo/LLM-alt-text-assessment-full-features", split="train",cache_dir="./dataset_cache")
|
||||
|
||||
dataset_validation = load_dataset("nicolaleo/LLM-alt-text-assessment-full-features_validation", split="validation",cache_dir="./dataset_cache")
|
||||
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
dataset_copy=deepcopy(dataset)
|
||||
|
||||
|
||||
|
||||
# Convert dataset to OAI messages
|
||||
# need to use a list comprehension to keep the PIL.Image type; .map() converts images to bytes
|
||||
dataset = [format_data(sample, is_inference=False) for sample in dataset]
|
||||
|
||||
dataset_validation = [format_data(sample, is_inference=False) for sample in dataset_validation]
|
||||
|
||||
|
||||
print(dataset[0]["messages"])
|
||||
print("parse_mllm_alt_text_response:",parse_mllm_alt_text_response(dataset[0]["messages"][2]["content"][0]["text"]))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
import torch
|
||||
torch.cuda.get_device_capability()
|
||||
|
||||
print("Freeing up memory...")
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
# Get free memory in bytes
|
||||
free_memory = torch.cuda.mem_get_info()[0]
|
||||
total_memory = torch.cuda.mem_get_info()[1]
|
||||
|
||||
# Convert to GB for readability
|
||||
free_gb = free_memory / (1024**3)
|
||||
total_gb = total_memory / (1024**3)
|
||||
|
||||
print(f"Free: {free_gb:.2f} GB / Total: {total_gb:.2f} GB")
|
||||
|
||||
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
|
||||
from transformers import Mistral3ForConditionalGeneration, MistralCommonBackend
|
||||
|
||||
# Hugging Face model id
|
||||
model_id = "mistralai/Ministral-3-3B-Instruct-2512"#"google/gemma-3-4b-pt"#"google/gemma-3-4b-pt" # or `google/gemma-3-12b-pt`, `google/gemma-3-27-pt`
|
||||
|
||||
# Check if GPU benefits from bfloat16
|
||||
#if torch.cuda.get_device_capability()[0] < 8:
|
||||
# raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")
|
||||
|
||||
# Define model init arguments
|
||||
model_kwargs = dict(
|
||||
#attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU
|
||||
dtype=torch.bfloat16,#torch.float16,#torch.bfloat16, # What torch dtype to use, defaults to auto
|
||||
device_map="auto", # Let torch decide how to load the model
|
||||
|
||||
)
|
||||
|
||||
# BitsAndBytesConfig int-4 config
|
||||
"""model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=model_kwargs["dtype"],
|
||||
bnb_4bit_quant_storage=model_kwargs["dtype"],
|
||||
)"""
|
||||
|
||||
# Load model and tokenizer
|
||||
#model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
|
||||
#processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
|
||||
|
||||
|
||||
|
||||
|
||||
# Set the cache directory to current folder
|
||||
cache_dir = "./model_cache" # or just "." for directly in current folder
|
||||
|
||||
"""
|
||||
print("Loading model... This may take a while.")
|
||||
model=AutoModelForImageTextToText.from_pretrained( # 4-bit quantized version with bf16 compute
|
||||
model_id,
|
||||
cache_dir=cache_dir,
|
||||
**model_kwargs
|
||||
)
|
||||
print("Model loaded.")
|
||||
print(f"Original Model dtype: {model.dtype}") # ritorna torch.bfloat16
|
||||
|
||||
proc_cache_dir = "./proc_cache"
|
||||
print("Loading processor...")
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_id,#"google/gemma-3-4b-it",#model_id, # nel file originale prende -it e non -pt (cambia poco comunque)
|
||||
cache_dir=proc_cache_dir
|
||||
)
|
||||
print("Processor loaded.")
|
||||
"""
|
||||
|
||||
proc_cache_dir = "./proc_cache"
|
||||
tokenizer = MistralCommonBackend.from_pretrained(model_id,cache_dir=proc_cache_dir)
|
||||
model = Mistral3ForConditionalGeneration.from_pretrained(model_id,cache_dir=cache_dir,**model_kwargs) # bitsandbytes quantization does not work here because the model already ships with FP quantization parameters
|
||||
|
||||
|
||||
print("testing the loaded model...")
|
||||
# generate the description
|
||||
description = generate_description(dataset_copy[0], model, tokenizer)
|
||||
print("text generated:",description)
|
||||
|
||||
|
||||
# Download and save to current folder
|
||||
print("Saving model and processor locally...")
|
||||
save_path = "./original_local_model_"+model_id.replace("/", "_")
|
||||
#model.save_pretrained(save_path)
|
||||
#processor.save_pretrained(save_path)
|
||||
print("Model and processor saved.")
|
||||
|
||||
|
||||
""" # la convesrione in ollama funziona solo se fatta su modello non quantizzato (da capire se si può fare su modello 4bit)
|
||||
print("Converting and importing model to Ollama...")
|
||||
# Step 1: Download from Hugging Face
|
||||
model_path= "./original_local_model_ollama"
|
||||
model_path = download_hf_model(model_id,output_dir=model_path)
|
||||
|
||||
# Step 2: Convert to GGUF (requires llama.cpp)
|
||||
gguf_path = convert_to_gguf(model_path, "./gemma.gguf")
|
||||
|
||||
if gguf_path:
|
||||
# Step 3: Create Modelfile
|
||||
OLLAMA_MODEL_NAME = "gemma3-wcag"
|
||||
modelfile = create_modelfile(OLLAMA_MODEL_NAME, gguf_path)
|
||||
|
||||
"""
|
||||
|
||||
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
peft_config = LoraConfig(
|
||||
lora_alpha=8,#16,
|
||||
lora_dropout=0.05,
|
||||
r=8,#16,
|
||||
bias="none",
|
||||
target_modules="all-linear",
|
||||
task_type="CAUSAL_LM",
|
||||
#modules_to_save=[ # this was what used the extra memory
|
||||
# "lm_head",
|
||||
# "embed_tokens",
|
||||
#],
|
||||
)
|
||||
|
||||
from trl import SFTConfig
|
||||
|
||||
args = SFTConfig(
|
||||
output_dir="./ministral-finetuned-wcag_"+model_id.replace("/", "_"), # directory to save and repository id
|
||||
num_train_epochs=1, # number of training epochs
|
||||
per_device_train_batch_size=1, # batch size per device during training
|
||||
per_device_eval_batch_size=1, # batch size for evaluation
|
||||
gradient_accumulation_steps=2,#4, # number of steps before performing a backward/update pass
|
||||
gradient_checkpointing=True, # use gradient checkpointing to save memory
|
||||
optim="adamw_8bit",#"adamw_torch_fused", # use fused adamw optimizer
|
||||
logging_steps=5, # log every 5 steps
|
||||
save_strategy="epoch", # save checkpoint every epoch
|
||||
#eval_strategy="epoch", # evaluate checkpoint every epoch
|
||||
eval_strategy="steps",
|
||||
eval_steps=5, # Evaluate every 5 steps
|
||||
learning_rate=2e-4, # learning rate, based on QLoRA paper
|
||||
bf16=True,#False,#True, # use bfloat16 precision
|
||||
max_grad_norm=0.3, # max gradient norm based on QLoRA paper
|
||||
warmup_ratio=0.03, # warmup ratio based on QLoRA paper
|
||||
lr_scheduler_type="constant", # use constant learning rate scheduler
|
||||
push_to_hub=True, # push model to hub
|
||||
report_to="tensorboard", # report metrics to tensorboard
|
||||
gradient_checkpointing_kwargs={
|
||||
"use_reentrant": False
|
||||
}, # use non-reentrant checkpointing
|
||||
dataset_text_field="", # need a dummy field for collator
|
||||
dataset_kwargs={"skip_prepare_dataset": True}, # important for collator
|
||||
)
|
||||
args.remove_unused_columns = False # important for collator
|
||||
|
||||
# Create a data collator to encode text and image pairs
|
||||
def collate_fn(examples): # NB: to be fixed; there is no processor here, only the tokenizer, and it is still to be verified whether that is a problem for collate_fn
|
||||
# examples: list of formatted_data dicts with a "messages" field
|
||||
messages_list = [example["messages"] for example in examples]
|
||||
|
||||
# Use the tokenizer's chat template to produce tensors suitable for Mistral3
|
||||
# It should return a dict with input_ids, attention_mask and (for multimodal) pixel_values
|
||||
tokenized = tokenizer.apply_chat_template(
|
||||
messages_list,
|
||||
return_tensors="pt",
|
||||
return_dict=True,
|
||||
padding=True,
|
||||
add_generation_prompt=False,
|
||||
)
|
||||
|
||||
# Prepare labels from input_ids and mask padding tokens
|
||||
labels = tokenized["input_ids"].clone()
|
||||
|
||||
# Determine pad token id
|
||||
pad_id = None
|
||||
if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
|
||||
pad_id = tokenizer.pad_token_id
|
||||
elif hasattr(tokenizer, "tokenizer") and getattr(tokenizer.tokenizer, "pad_token_id", None) is not None:
|
||||
pad_id = tokenizer.tokenizer.pad_token_id
|
||||
|
||||
if pad_id is not None:
|
||||
labels[labels == pad_id] = -100
|
||||
|
||||
# Mask special image tokens if tokenizer exposes them
|
||||
try:
|
||||
special_map = getattr(tokenizer, "tokenizer", None)
|
||||
if special_map and getattr(special_map, "special_tokens_map", None):
|
||||
boi_tok = special_map.special_tokens_map.get("boi_token")
|
||||
if boi_tok:
|
||||
boi_id = special_map.convert_tokens_to_ids(boi_tok)
|
||||
labels[labels == boi_id] = -100
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
tokenized["labels"] = labels
|
||||
|
||||
# Free CUDA memory between batches to reduce chance of OOM
|
||||
try:
|
||||
import torch as _torch
|
||||
if _torch.cuda.is_available():
|
||||
_torch.cuda.empty_cache()
|
||||
except Exception:
|
||||
print("Error clearing CUDA cache, continuing without clearing.")
|
||||
|
||||
return tokenized
|
||||
|
||||
from trl import SFTTrainer
|
||||
|
||||
trainer = SFTTrainer( # to be fixed: currently raises errors
|
||||
model=model,
|
||||
args=args,
|
||||
train_dataset=dataset,
|
||||
eval_dataset=dataset_validation,
|
||||
peft_config=peft_config,
|
||||
processing_class=tokenizer, # use tokenizer as processing class for Mistral3
|
||||
data_collator=collate_fn,
|
||||
)
|
||||
|
||||
print("Starting training...")
|
||||
# Start training, the model will be automatically saved to the Hub and the output directory
|
||||
trainer.train()
|
||||
|
||||
print("Training completed.")
|
||||
# Save the final model again to the Hugging Face Hub
|
||||
try:
|
||||
trainer.save_model() # saves the LoRA adapter to the HF Hub under the name specified in args.output_dir
|
||||
print("Final model saved to Hugging Face Hub.")
|
||||
except Exception as e:
|
||||
print(f"Error saving model to Hugging Face Hub: {e}")
|
||||
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
# Access the log history
|
||||
log_history = trainer.state.log_history
|
||||
|
||||
# Extract training / validation loss
|
||||
train_losses = [log["loss"] for log in log_history if "loss" in log]
|
||||
epoch_train = [log["epoch"] for log in log_history if "loss" in log]
|
||||
eval_losses = [log["eval_loss"] for log in log_history if "eval_loss" in log]
|
||||
epoch_eval = [log["epoch"] for log in log_history if "eval_loss" in log]
|
||||
|
||||
# Plot the training loss
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(epoch_train, train_losses, label="Training Loss", marker='o')
|
||||
plt.plot(epoch_eval, eval_losses, label="Validation Loss", marker='s')
|
||||
plt.xlabel("Epoch")
|
||||
plt.ylabel("Loss")
|
||||
plt.title("Training and Validation Loss per Epoch")
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.savefig("training_validation_loss.png", dpi=300, bbox_inches='tight')
|
||||
print("Plot saved successfully as 'training_validation_loss.png'")
|
||||
except Exception as e:
|
||||
print(f"Error plotting loss curves: {e}")
|
||||
|
||||
|
||||
# free the memory again
|
||||
del model
|
||||
del trainer
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
from peft import PeftModel
|
||||
|
||||
# Load the base model
|
||||
#model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, cache_dir=cache_dir) # bf16 version (not 4-bit quantized); NB: quantization happens at loading time
|
||||
model = Mistral3ForConditionalGeneration.from_pretrained(model_id, device_map="auto",cache_dir=cache_dir,dtype=torch.bfloat16)
|
||||
|
||||
print(f"Original Model dtype pre merge: {model.dtype}")
|
||||
|
||||
|
||||
|
||||
# Merge LoRA and base model and save
|
||||
peft_model = PeftModel.from_pretrained(model, args.output_dir)
|
||||
merged_model = peft_model.merge_and_unload()
|
||||
merged_model.save_pretrained("merged_model_"+model_id.replace("/", "_"), safe_serialization=True, max_shard_size="2GB")
|
||||
|
||||
|
||||
tokenizer = MistralCommonBackend.from_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained("merged_model_"+model_id.replace("/", "_"))
|
||||
|
||||
|
||||
print("Loading merged model for inference...")
|
||||
|
||||
merged_model_path = "merged_model_"+model_id.replace("/", "_")
|
||||
|
||||
"""
|
||||
# Load the merged model
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
merged_model_path,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
#attn_implementation="eager",
|
||||
)
|
||||
"""
|
||||
#print(f"Merged Model dtype: {model.dtype}")
|
||||
#processor = AutoProcessor.from_pretrained(merged_model_path)
|
||||
|
||||
tokenizer = MistralCommonBackend.from_pretrained(merged_model_path)
|
||||
model = Mistral3ForConditionalGeneration.from_pretrained(merged_model_path, device_map="auto",dtype=torch.bfloat16)
|
||||
|
||||
|
||||
print("testing the merged model...")
|
||||
|
||||
|
||||
# generate the description
|
||||
description = generate_description(dataset_copy[0], model, tokenizer)
|
||||
print("text generated:",description)
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,437 @@
|
|||
import os
|
||||
from PIL import Image
|
||||
import torch
|
||||
from transformers import AutoProcessor, AutoModelForImageTextToText,BitsAndBytesConfig
|
||||
import gc
|
||||
import requests
|
||||
import pandas as pd
|
||||
import json
|
||||
import re
|
||||
|
||||
|
||||
|
||||
|
||||
system_message = """You are a web accessibility evaluation tool. Your task is to evaluate if alterative text for
|
||||
images on webpages is appropriate according to WCAG guidelines. The alt-text should serve the same purpose and present
|
||||
the same information as the original image content. As a result, it is possible to remove the image content and replace it with the text alternative and no functionality or information would be lost. This text alternative should not necessarily describe the image content.
|
||||
It should serve the same purpose and convey the same information. This may sometimes result in a text alternative that looks like a description of the image content. But this would only be true if that was the best way to serve the same purpose.
|
||||
If possible, the short text alternative should completely convey the purpose and information. If it is not possible to do this in a short phrase or sentence, then the short text alternative should provide a brief overview of the information.
|
||||
The text alternative should be able to substitute for the image content. If the image content were removed from the page and substituted with the text, the page would still provide the same function and information. The text alternative would be brief but as informative as possible.
|
||||
In deciding what text to include in the alternative, it is often a good idea to consider the following questions:
|
||||
Why is this image content here?
|
||||
What information is it presenting?
|
||||
What purpose does it fulfill?
|
||||
If I could not use the image content, what words would I use to convey the same function and/or information?
|
||||
|
||||
When image content contains words that are important to understanding the content, the alt text should include those words.
|
||||
Decorative images don’t add information to the content of a page. For example, the information provided by the image might already be given using adjacent text, or the image might be included to make the website more visually attractive.
|
||||
In these cases, a null (empty) alt text should be provided (alt="") so that they can be ignored by assistive technologies, such as screen readers.
|
||||
|
||||
Follow these instructions carefully:
|
||||
1. You will be provided as input with the following:
|
||||
- The image found on the webpage.
|
||||
- The associated alternative text. When the alt-text is empty or absent, you will be explicitly informed.
|
||||
- The surrounding context of the image.
|
||||
- The page title, headings and the content of the “keywords” and “description” <meta> tag, if found.
|
||||
|
||||
2. Determine the function and purpose of the image by analyzing these elements. Take into account the purpose and function
|
||||
of the associated image by considering the page context. Check also if the image is, or is associated with, a link or a button,
|
||||
and consider this in your judgement. If the image contains text use that as part of the context.
|
||||
|
||||
3. Provide a final assessment judgment based on the following:
|
||||
- 'success' if you can assess with 'sufficient certainty' the alt-text is appropriate in relation to the image purpose,
|
||||
- 'failure' if you can assess with 'sufficient certainty' that the alt-text is NOT appropriate,
|
||||
- 'warning' if you cannot determine with 'sufficient certainty'.
|
||||
where the level of certainty goes from 1 to 100 and 'sufficient certainty' means > 80
|
||||
|
||||
4. The original alt-text assessment on a scale from 1 to 5, where 5 is the best score. Use an integer number only.
|
||||
|
||||
5. Provide a brief reasoning for your judgment. If the image contains text, write it verbatim.
|
||||
|
||||
6. Keep your response within 150 words.
|
||||
|
||||
7. Generate the new most appropriate alt-text given the context and the steps before. Keep this within 30 words. Use the same natural language (e.g., English, Spanish, Italian) as the original alt-text.
|
||||
|
||||
8. Here is the JSON format the results must have:
|
||||
{"Original alt-text assessment" : "*your original alt-text assessment*", "Assessment" : "*your assessment judgment*", "EvaluationResult": "*your response*", "New alt-text":"*new alt-text*"}"""
|
||||
|
||||
|
||||
|
||||
def parse_mllm_alt_text_response(mllm_response): # same function as the one in utils_API
|
||||
"""
|
||||
Parse an MLLM response string and extract key attributes into a JSON object.
|
||||
|
||||
from mllm response like:
|
||||
```json\n{\n\"Original alt-text assessment\"... etc
|
||||
to a structured dictionary.
|
||||
|
||||
Args:
|
||||
mllm_response (str): The raw MLLM response text containing JSON data
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the extracted attributes, or None if parsing fails
|
||||
"""
|
||||
try:
|
||||
# Handle NaN or None values
|
||||
if mllm_response is None or mllm_response == "":
|
||||
return {
|
||||
"original_alt_text_assessment": None,
|
||||
"assessment": None,
|
||||
"evaluation_result": None,
|
||||
"new_alt_text": None
|
||||
}
|
||||
|
||||
# Extract JSON content between ```json and ``` markers
|
||||
json_match = re.search(r'```json\s*(.*?)\s*```', mllm_response, re.DOTALL)
|
||||
|
||||
if not json_match:
|
||||
# Try to find JSON without markdown code blocks
|
||||
json_match = re.search(r'\{.*\}', mllm_response, re.DOTALL)
|
||||
|
||||
if not json_match:
|
||||
print("No JSON match found in MLLM response.")
|
||||
return {
|
||||
"original_alt_text_assessment": None,
|
||||
"assessment": None,
|
||||
"evaluation_result": None,
|
||||
"new_alt_text": None
|
||||
}
|
||||
|
||||
json_str = json_match.group(1) if '```json' in mllm_response else json_match.group(0)
|
||||
|
||||
# Parse the JSON string
|
||||
parsed_data = json.loads(json_str)
|
||||
|
||||
# Create a structured output with the key attributes
|
||||
result = {
|
||||
"original_alt_text_assessment": parsed_data.get("Original alt-text assessment", ""),
|
||||
"assessment": parsed_data.get("Assessment", ""),
|
||||
"evaluation_result": parsed_data.get("EvaluationResult", ""),
|
||||
"new_alt_text": parsed_data.get("New alt-text", "")
|
||||
}
|
||||
print("Parsed MLLM response:", result)
|
||||
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON parsing error: {e}")
|
||||
return {
|
||||
"original_alt_text_assessment": None,
|
||||
"assessment": None,
|
||||
"evaluation_result": None,
|
||||
"new_alt_text": None
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Error parsing MLLM response: {e}")
|
||||
return {
|
||||
"original_alt_text_assessment": None,
|
||||
"assessment": None,
|
||||
"evaluation_result": None,
|
||||
"new_alt_text": None
|
||||
}
|
||||
|
||||
|
||||
def format_user_prompt(
|
||||
original_alt_text,
|
||||
html_context,
|
||||
page_title,
|
||||
page_description,
|
||||
page_keywords,
|
||||
):
|
||||
|
||||
alt_text = "Here is the alt-text of the image: " + str(original_alt_text)
|
||||
|
||||
HTML_context = "Here is the surrounding HTML context of the element: " + str(html_context)
|
||||
|
||||
page_text = "Here is the content of the page: Title of the page: " + str(page_title)
|
||||
|
||||
page_text = page_text + ", content of the <meta name='description'> tag: "+ str(page_description)
|
||||
|
||||
page_text = page_text+ ", content of the <meta name='keywords'> tag: "+ str(page_keywords)
|
||||
|
||||
user_prompt_to_use=alt_text + " " + HTML_context + " " + page_text
|
||||
return user_prompt_to_use
|
||||
|
||||
|
||||
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
|
||||
# print("Processing vision info...")
|
||||
image_inputs = []
|
||||
# Iterate through each conversation
|
||||
for msg in messages:
|
||||
# Get content (ensure it's a list)
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
content = [content]
|
||||
|
||||
# Check each content element for images
|
||||
for element in content:
|
||||
if isinstance(element, dict) and (
|
||||
"image" in element or element.get("type") == "image"
|
||||
):
|
||||
# Get the image and convert to RGB
|
||||
if "image" in element:
|
||||
image = element["image"]
|
||||
else:
|
||||
image = element
|
||||
image_inputs.append(image.convert("RGB")) # converte in rgb !
|
||||
return image_inputs
|
||||
|
||||
|
||||
|
||||
def format_data(sample, is_inference=False):
|
||||
|
||||
formatted_data = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": system_message}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": format_user_prompt( original_alt_text=sample["original_alt_text"],
|
||||
html_context=sample["html_context"],
|
||||
page_title=sample["page_title"],
|
||||
page_description=sample["page_description"],
|
||||
page_keywords=sample["page_keywords"],
|
||||
|
||||
|
||||
),
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"image": Image.open(
|
||||
requests.get(sample["image_url"], stream=True).raw
|
||||
).convert(
|
||||
"RGB"
|
||||
), # .convert("RGB") necessario??
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
}
|
||||
if is_inference:
|
||||
# print(
|
||||
# "formatted_data for inference:", formatted_data
|
||||
# ) # the assistant part (the expected answer) is not passed here, unlike in the HF example
|
||||
pass
|
||||
|
||||
else:
|
||||
formatted_data["messages"].append( #aggiungo la parte di risposta attesa per il training
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
|
||||
#{"type": "text", "text": sample["llm_alt_text"]} , #only alt-text atteso
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"""```json
|
||||
{{
|
||||
"Original alt-text assessment": {json.dumps(sample["llm_assessment"])},
|
||||
"Assessment": {json.dumps(sample["llm_judgment"])},
|
||||
"EvaluationResult": {json.dumps(sample["llm_evaluation_result"])},
|
||||
"New alt-text": {json.dumps(sample["llm_alt_text"])}
|
||||
}}
|
||||
```"""
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
)
|
||||
return formatted_data
|
||||
|
||||
def prepare_model_inputs(sample):
|
||||
# print("Preparing model inputs...")
|
||||
formatted_data = format_data(sample, is_inference=True)
|
||||
messages = formatted_data["messages"]
|
||||
image_inputs = process_vision_info(messages)
|
||||
return messages, image_inputs
|
||||
|
||||
|
||||
def generate_description(sample, model, processor):
|
||||
print("Generating description...")
|
||||
# Prepare the model inputs
|
||||
messages, image_inputs = prepare_model_inputs(
|
||||
sample
|
||||
) # can be avoided if already prepared
|
||||
# Apply the chat template to get the final text input
|
||||
text = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=image_inputs, # PIL Image or list of images
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
max_length=8192, # Equivalent to num_ctx, max input token
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
inputs = inputs.to(model.device)
|
||||
|
||||
# Generate the output
|
||||
stop_token_ids = [
|
||||
processor.tokenizer.eos_token_id,
|
||||
processor.tokenizer.convert_tokens_to_ids("<end_of_turn>"),
|
||||
]
|
||||
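# Note (added): "<end_of_turn>" is the turn delimiter used by the Gemma chat template, so
# passing its id alongside eos_token_id stops generation at the end of the model's turn.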
|
||||
generation_config = {
|
||||
"temperature": 0.7, # Same as Ollama
|
||||
"max_new_tokens": 800, # Equivalent to num_predict
|
||||
"top_p": 0.95, # Same as Ollama
|
||||
"do_sample": True, # Required for temperature/top_p to work
|
||||
}
|
||||
generated_ids = model.generate(
|
||||
**inputs,
|
||||
**generation_config,
|
||||
eos_token_id=stop_token_ids,
|
||||
disable_compile=True,
|
||||
)
|
||||
# Trim the generation and decode the output to text
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :]
|
||||
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
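# Note (added): slicing each output at len(in_ids) drops the prompt tokens, so only the newly
# generated continuation is decoded below.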
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
print("Raw model output text:", type(output_text[0]), output_text[0])
|
||||
parsed_resp = parse_mllm_alt_text_response(output_text[0])
|
||||
parsed_resp["model_id"]="gemma3_4b"
|
||||
return parsed_resp
|
||||
|
||||
|
||||
def process_to_prepare_model_inputs(row):
|
||||
try:
|
||||
result = prepare_model_inputs(
|
||||
sample=row,
|
||||
)
|
||||
return pd.Series(result)
|
||||
except Exception as e:
|
||||
print(f"Error process_to_prepare_model_inputs processing row {row.name}: {e}")
|
||||
return pd.Series({"llm_textual_input": None, "llm_image_input": None})
|
||||
|
||||
|
||||
def process_to_generate_model_outputs(row):
|
||||
try:
|
||||
result = generate_description(
|
||||
sample=row,
|
||||
model=model,
|
||||
processor=processor,
|
||||
)
|
||||
return pd.Series(result)
|
||||
except Exception as e:
|
||||
print(f"Error processing row {row.name}: {e}")
|
||||
return pd.Series({
|
||||
'original_alt_text_assessment': None,
|
||||
'assessment': None,
|
||||
'evaluation_result': None,
|
||||
'new_alt_text': None,
|
||||
'model_id':None
|
||||
})
|
||||
|
||||
|
||||
df_esercitazione = pd.read_csv(
|
||||
"esercitazione_12_2025/dataset_esercitazione.csv", sep=";"
|
||||
)
|
||||
|
||||
# df_esercitazione[['llm_textual_input',"llm_image_input"]] = df_esercitazione.head(2).apply(process_to_prepare_model_inputs, axis=1)
|
||||
|
||||
df_esercitazione[["llm_textual_input", "llm_image_input"]] = df_esercitazione.apply(
|
||||
process_to_prepare_model_inputs, axis=1
|
||||
)
|
||||
|
||||
df_esercitazione[["llm_textual_input", "llm_image_input"]].to_csv(
|
||||
"llm_model_inputs.csv", sep=";", index=False
|
||||
)
|
||||
|
||||
|
||||
# for a model downloaded from HF
|
||||
model_id = "google/gemma-3-4b-it" #se voglio scaricarlo da HF
|
||||
cache_dir = "./model_cache"
|
||||
proc_cache_dir = "./proc_cache"
|
||||
#######
|
||||
|
||||
output_dir = "merged_model_google_gemma-3-4b-it" #per il modello finetunato locale fp16
|
||||
#output_dir="original_local_model_google_gemma-3-4b-it" # versione originale locale quantizzata 4b-bf16 del modello
|
||||
print("Freeing up memory...")
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
model_from_hf = False # False to use the local model
|
||||
use_4bit_quantization = False # Set to True to enable 4-bit quantization
|
||||
|
||||
model_kwargs={}
|
||||
if use_4bit_quantization:
|
||||
# BitsAndBytesConfig int-4 config
|
||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
bnb_4bit_quant_storage=torch.bfloat16,
|
||||
)
|
||||
else:
|
||||
print("Not using 4-bit quantization. Model will be loaded in full precision (bfloat16).")
|
||||
|
||||
|
||||
|
||||
if model_from_hf:
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
#attn_implementation="eager",
|
||||
cache_dir=cache_dir,
|
||||
**model_kwargs
|
||||
)
|
||||
print("\n Model loaded from HF")
|
||||
processor = AutoProcessor.from_pretrained(model_id,cache_dir=proc_cache_dir)
|
||||
print("Processor loaded from HF")
|
||||
else:
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
output_dir, # output_dir local model
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
#attn_implementation="eager",
|
||||
**model_kwargs
|
||||
)
|
||||
print("\n Model loaded from local directory")
|
||||
processor = AutoProcessor.from_pretrained(output_dir)
|
||||
|
||||
print("Processor loaded from local directory")
|
||||
|
||||
|
||||
print(f"Model dtype: {model.dtype}")
|
||||
memory_bytes = model.get_memory_footprint()
|
||||
memory_gb = memory_bytes / (1024**3)
|
||||
memory_mb = memory_bytes / (1024**2)
|
||||
|
||||
print(f"Model memory footprint: {memory_bytes:,} bytes")
|
||||
print(f"Model memory footprint: {memory_gb:.2f} GB")
|
||||
print(f"Model memory footprint: {memory_mb:.2f} MB")
|
||||
|
||||
|
||||
# The processor might also contain an image processor
|
||||
if hasattr(processor, "image_processor"):
|
||||
print("the processor Has image processor")
|
||||
|
||||
# df_esercitazione[["llm_alt_text_1"]] = df_esercitazione.head(2).apply(
|
||||
# process_to_generate_model_outputs, axis=1
|
||||
# )
|
||||
|
||||
#df_esercitazione[["llm_alt_text_1"]] = df_esercitazione.apply(
|
||||
# process_to_generate_model_outputs, axis=1
|
||||
#)
|
||||
|
||||
#df_esercitazione[['llm_assessment_1', 'llm_judgment_1', 'llm_evaluation_result_1', 'llm_alt_text_1','llm_model_1']] = df_esercitazione.head(2).apply(process_to_generate_model_outputs, axis=1)
|
||||
|
||||
df_esercitazione[['llm_assessment_1', 'llm_judgment_1', 'llm_evaluation_result_1', 'llm_alt_text_1','llm_model_1']] = df_esercitazione.apply(process_to_generate_model_outputs, axis=1)
|
||||
|
||||
#df_esercitazione.to_csv("hf_llm_generated_output_"+"original_local_model_google_gemma-3-4b-it"+".csv", sep=";", index=False)
|
||||
df_esercitazione.to_csv("hf_llm_generated_output_"+"merged_model_google_gemma-3-4b-it-4bit"+".csv", sep=";", index=False)
|
||||
|
|
@ -0,0 +1,363 @@
|
|||
import os
|
||||
from PIL import Image
|
||||
import torch
|
||||
from transformers import Mistral3ForConditionalGeneration, MistralCommonBackend
|
||||
import gc
|
||||
import requests
|
||||
import pandas as pd
|
||||
import json
|
||||
import re
|
||||
|
||||
|
||||
|
||||
|
||||
system_message = """You are a web accessibility evaluation tool. Your task is to evaluate if alterative text for
|
||||
images on webpages is appropriate according to WCAG guidelines. The alt-text should serve the same purpose and present
|
||||
the same information as the original image content. As a result, it is possible to remove the image content and replace it with the text alternative and no functionality or information would be lost. This text alternative should not necessarily describe the image content.
|
||||
It should serve the same purpose and convey the same information. This may sometimes result in a text alternative that looks like a description of the image content. But this would only be true if that was the best way to serve the same purpose.
|
||||
If possible, the short text alternative should completely convey the purpose and information. If it is not possible to do this in a short phrase or sentence, then the short text alternative should provide a brief overview of the information.
|
||||
The text alternative should be able to substitute for the image content. If the image content were removed from the page and substituted with the text, the page would still provide the same function and information. The text alternative would be brief but as informative as possible.
|
||||
In deciding what text to include in the alternative, it is often a good idea to consider the following questions:
|
||||
Why is this image content here?
|
||||
What information is it presenting?
|
||||
What purpose does it fulfill?
|
||||
If I could not use the image content, what words would I use to convey the same function and/or information?
|
||||
|
||||
When image content contains words that are important to understanding the content, the alt text should include those words.
|
||||
Decorative images don’t add information to the content of a page. For example, the information provided by the image might already be given using adjacent text, or the image might be included to make the website more visually attractive.
|
||||
In these cases, a null (empty) alt text should be provided (alt="") so that they can be ignored by assistive technologies, such as screen readers.
|
||||
|
||||
Follow these instructions carefully:
|
||||
1. You will be provided as input with the following:
|
||||
- The image found on the webpage.
|
||||
- The associated alternative text. When the alt-text is empty or absent, you will be explicitly informed.
|
||||
- The surrounding context of the image.
|
||||
- The page title, headings and the content of the “keywords” and “description” <meta> tag, if found.
|
||||
|
||||
2. Determine the function and purpose of the image by analyzing these elements. Take into account the purpose and function
|
||||
of the associated image by considering the page context. Check also if the image is, or is associated with, a link or a button,
|
||||
and consider this in your judgement. If the image contains text use that as part of the context.
|
||||
|
||||
3. Provide a final assessment judgment based on the following:
|
||||
- 'success' if you can assess with 'sufficient certainty' the alt-text is appropriate in relation to the image purpose,
|
||||
- 'failure' if you can assess with 'sufficient certainty' that the alt-text is NOT appropriate,
|
||||
- 'warning' if you cannot determine with 'sufficient certainty'.
|
||||
where the level of certainty goes from 1 to 100 and 'sufficient certainty' means > 80
|
||||
|
||||
4. The original alt-text assessment on a scale from 1 to 5, where 5 is the best score. Use an integer number only.
|
||||
|
||||
5. Provide a brief reasoning for your judgment. If the image contains text, write it verbatim.
|
||||
|
||||
6. Keep your response within 150 words.
|
||||
|
||||
7. Generate the new most appropriate alt-text given the context and the steps before. Keep this within 30 words. Use the same natural language (e.g., English, Spanish, Italian) as the original alt-text.
|
||||
|
||||
8. Here is the JSON format the results must have:
|
||||
{"Original alt-text assessment" : "*your original alt-text assessment*", "Assessment" : "*your assessment judgment*", "EvaluationResult": "*your response*", "New alt-text":"*new alt-text*"}"""
|
||||
|
||||
|
||||
|
||||
def parse_mllm_alt_text_response(mllm_response):  # same as the one inside utils_API
|
||||
"""
|
||||
Parse an MLLM response string and extract key attributes into a JSON object.
|
||||
|
||||
from mllm response like:
|
||||
```json\n{\n\"Original alt-text assessment\"... etc
|
||||
to a structured dictionary.
|
||||
|
||||
Args:
|
||||
mllm_response (str): The raw MLLM response text containing JSON data
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the extracted attributes, or None if parsing fails
|
||||
"""
|
||||
try:
|
||||
# Handle NaN or None values
|
||||
if mllm_response is None or mllm_response == "":
|
||||
return {
|
||||
"original_alt_text_assessment": None,
|
||||
"assessment": None,
|
||||
"evaluation_result": None,
|
||||
"new_alt_text": None
|
||||
}
|
||||
|
||||
# Extract JSON content between ```json and ``` markers
|
||||
json_match = re.search(r'```json\s*(.*?)\s*```', mllm_response, re.DOTALL)
|
||||
|
||||
if not json_match:
|
||||
# Try to find JSON without markdown code blocks
|
||||
json_match = re.search(r'\{.*\}', mllm_response, re.DOTALL)
|
||||
|
||||
if not json_match:
|
||||
print("No JSON match found in MLLM response.")
|
||||
return {
|
||||
"original_alt_text_assessment": None,
|
||||
"assessment": None,
|
||||
"evaluation_result": None,
|
||||
"new_alt_text": None
|
||||
}
|
||||
|
||||
json_str = json_match.group(1) if '```json' in mllm_response else json_match.group(0)
|
||||
|
||||
# Parse the JSON string
|
||||
parsed_data = json.loads(json_str)
|
||||
|
||||
# Create a structured output with the key attributes
|
||||
result = {
|
||||
"original_alt_text_assessment": parsed_data.get("Original alt-text assessment", ""),
|
||||
"assessment": parsed_data.get("Assessment", ""),
|
||||
"evaluation_result": parsed_data.get("EvaluationResult", ""),
|
||||
"new_alt_text": parsed_data.get("New alt-text", "")
|
||||
}
|
||||
print("Parsed MLLM response:", result)
|
||||
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON parsing error: {e}")
|
||||
return {
|
||||
"original_alt_text_assessment": None,
|
||||
"assessment": None,
|
||||
"evaluation_result": None,
|
||||
"new_alt_text": None
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Error parsing MLLM response: {e}")
|
||||
return {
|
||||
"original_alt_text_assessment": None,
|
||||
"assessment": None,
|
||||
"evaluation_result": None,
|
||||
"new_alt_text": None
|
||||
}
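# Hedged usage sketch: the response string below is hypothetical (not taken from a real
# model run) and only illustrates how a fenced-JSON reply is reduced to the flat dict above.
_demo_mllm_response = (
    '```json\n'
    '{"Original alt-text assessment": "2", "Assessment": "failure", '
    '"EvaluationResult": "The alt-text lists categories instead of the image purpose.", '
    '"New alt-text": "Iron Man action figure, 30.5 cm"}\n'
    '```'
)
# parse_mllm_alt_text_response(_demo_mllm_response)
# -> {"original_alt_text_assessment": "2", "assessment": "failure",
#     "evaluation_result": "The alt-text lists categories instead of the image purpose.",
#     "new_alt_text": "Iron Man action figure, 30.5 cm"}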
|
||||
|
||||
|
||||
def format_user_prompt(
|
||||
original_alt_text,
|
||||
html_context,
|
||||
page_title,
|
||||
page_description,
|
||||
page_keywords,
|
||||
):
|
||||
|
||||
alt_text = "Here is the alt-text of the image: " + str(original_alt_text)
|
||||
|
||||
HTML_context = "Here is the surrounding HTML context of the element: " + str(html_context)
|
||||
|
||||
page_text = "Here is the content of the page: Title of the page: " + str(page_title)
|
||||
|
||||
page_text = page_text + ", content of the <meta name='description'> tag: "+ str(page_description)
|
||||
|
||||
page_text = page_text+ ", content of the <meta name='keywords'> tag: "+ str(page_keywords)
|
||||
|
||||
user_prompt_to_use=alt_text + " " + HTML_context + " " + page_text
|
||||
return user_prompt_to_use
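# Hedged usage sketch with made-up field values, only to show the shape of the concatenated prompt:
# format_user_prompt(
#     original_alt_text="Company logo",
#     html_context="<a href='/'><img src='logo.png' alt='Company logo'></a>",
#     page_title="ACME Home",
#     page_description="ACME products and services",
#     page_keywords="acme, tools",
# )
# -> "Here is the alt-text of the image: Company logo Here is the surrounding HTML context of the element: ...
#     Here is the content of the page: Title of the page: ACME Home, content of the <meta name='description'> tag: ..."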
|
||||
|
||||
|
||||
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
|
||||
# print("Processing vision info...")
|
||||
image_inputs = []
|
||||
# Iterate through each conversation
|
||||
for msg in messages:
|
||||
# Get content (ensure it's a list)
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
content = [content]
|
||||
|
||||
# Check each content element for images
|
||||
for element in content:
|
||||
if isinstance(element, dict) and (
|
||||
"image" in element or element.get("type") == "image"
|
||||
):
|
||||
# Get the image and convert to RGB
|
||||
if "image" in element:
|
||||
image = element["image"]
|
||||
else:
|
||||
image = element
|
||||
image_inputs.append(image.convert("RGB")) # converte in rgb !
|
||||
return image_inputs
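# Hedged sketch of the message shape this helper scans (illustrative only): it collects PIL
# images stored under an "image" key or in a {"type": "image", ...} element. In this script
# the image is passed as "image_url" and resolved by the chat template, so the returned list
# may simply be empty here.
# _msgs = [{"role": "user",
#           "content": [{"type": "text", "text": "Describe this image."},
#                       {"type": "image", "image": Image.open("local_example.png")}]}]
# process_vision_info(_msgs)  # -> [<PIL image converted to RGB>]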
|
||||
|
||||
|
||||
|
||||
def format_data(sample, is_inference=False):
|
||||
|
||||
formatted_data = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": system_message}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": format_user_prompt( original_alt_text=sample["original_alt_text"],
|
||||
html_context=sample["html_context"],
|
||||
page_title=sample["page_title"],
|
||||
page_description=sample["page_description"],
|
||||
page_keywords=sample["page_keywords"],
|
||||
|
||||
|
||||
),
|
||||
},
|
||||
|
||||
{"type": "image_url", "image_url": {"url": sample["image_url"]}},# passato come image url e non pil come gemma3
|
||||
],
|
||||
},
|
||||
]
|
||||
}
|
||||
if is_inference:
|
||||
# print(
|
||||
# "formatted_data for inference:", formatted_data
|
||||
# )  # the assistant part (the expected answer) is not passed, unlike the HF example
|
||||
pass
|
||||
|
||||
else:
|
||||
formatted_data["messages"].append( #aggiungo la parte di risposta attesa per il training
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
|
||||
#{"type": "text", "text": sample["llm_alt_text"]} , #only alt-text atteso
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"""```json
|
||||
{{
|
||||
"Original alt-text assessment": {json.dumps(sample["llm_assessment"])},
|
||||
"Assessment": {json.dumps(sample["llm_judgment"])},
|
||||
"EvaluationResult": {json.dumps(sample["llm_evaluation_result"])},
|
||||
"New alt-text": {json.dumps(sample["llm_alt_text"])}
|
||||
}}
|
||||
```"""
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
)
|
||||
return formatted_data
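# Hedged sketch of the structure produced by format_data (values abridged):
# {
#   "messages": [
#     {"role": "system", "content": [{"type": "text", "text": system_message}]},
#     {"role": "user", "content": [{"type": "text", "text": "<formatted user prompt>"},
#                                  {"type": "image_url", "image_url": {"url": "<image URL>"}}]},
#     # and, when is_inference=False, an extra "assistant" message carrying the expected
#     # fenced-JSON answer used as the training target
#   ]
# }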
|
||||
|
||||
def prepare_model_inputs(sample):
|
||||
# print("Preparing model inputs...")
|
||||
formatted_data = format_data(sample, is_inference=True)
|
||||
messages = formatted_data["messages"]
|
||||
image_inputs = process_vision_info(messages)
|
||||
return messages, image_inputs
|
||||
|
||||
|
||||
def generate_description(sample, model, tokenizer):
|
||||
print("Generating description...")
|
||||
# Prepare the model inputs
|
||||
messages, image_inputs = prepare_model_inputs(
|
||||
sample
|
||||
)
|
||||
|
||||
tokenized = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True)  # missing max_length=8192 (equivalent to num_ctx, the max number of input tokens)
|
||||
|
||||
### for max_length, the following steps seem to be required
|
||||
# First get the formatted string
|
||||
#formatted = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
# Then tokenize with max_length
|
||||
#tokenized = tokenizer(formatted, return_tensors="pt", max_length=8192, truncation=True, return_dict=True)
|
||||
|
||||
tokenized["input_ids"] = tokenized["input_ids"].to(device="cuda") #vedi che non gli passa immagine PIL come fa per gemma3
|
||||
tokenized["pixel_values"] = tokenized["pixel_values"].to(dtype=torch.bfloat16, device="cuda")
|
||||
image_sizes = [tokenized["pixel_values"].shape[-2:]]
|
||||
print("Image sizes:", image_sizes)
|
||||
|
||||
|
||||
print(f"model device: {model.device}")
|
||||
tokenized = tokenized.to(model.device)  # move everything to the model's device
|
||||
generation_config = {
|
||||
"temperature": 0.7, # Same as Ollama
|
||||
"max_new_tokens": 800, # Equivalent to num_predict
|
||||
"top_p": 0.95, # Same as Ollama
|
||||
"do_sample": True, # Required for temperature/top_p to work
|
||||
}
|
||||
output = model.generate(
|
||||
**tokenized,
|
||||
image_sizes=image_sizes,
|
||||
#max_new_tokens=512,
|
||||
**generation_config,
|
||||
)[0]
|
||||
|
||||
decoded_output = tokenizer.decode(output[len(tokenized["input_ids"][0]):])
|
||||
print("Raw model output text:", type(decoded_output), decoded_output)
|
||||
parsed_resp = parse_mllm_alt_text_response(decoded_output)
|
||||
parsed_resp["model_id"]="ministral3_3b"
|
||||
return parsed_resp
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def process_to_prepare_model_inputs(row):
|
||||
try:
|
||||
result = prepare_model_inputs(
|
||||
sample=row,
|
||||
)
|
||||
return pd.Series(result)
|
||||
except Exception as e:
|
||||
print(f"Error process_to_prepare_model_inputs processing row {row.name}: {e}")
|
||||
return pd.Series({"llm_textual_input": None, "llm_image_input": None})
|
||||
|
||||
|
||||
def process_to_generate_model_outputs(row):
|
||||
try:
|
||||
result = generate_description(
|
||||
sample=row,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
return pd.Series(result)
|
||||
except Exception as e:
|
||||
print(f"Error processing row {row.name}: {e}")
|
||||
return pd.Series({
|
||||
'original_alt_text_assessment': None,
|
||||
'assessment': None,
|
||||
'evaluation_result': None,
|
||||
'new_alt_text': None,
|
||||
'model_id':None
|
||||
})
|
||||
|
||||
|
||||
df_esercitazione = pd.read_csv(
|
||||
"esercitazione_12_2025/dataset_esercitazione.csv", sep=";"
|
||||
)
|
||||
|
||||
# df_esercitazione[['llm_textual_input',"llm_image_input"]] = df_esercitazione.head(2).apply(process_to_prepare_model_inputs, axis=1)
|
||||
|
||||
df_esercitazione[["llm_textual_input", "llm_image_input"]] = df_esercitazione.apply(
|
||||
process_to_prepare_model_inputs, axis=1
|
||||
)
|
||||
|
||||
df_esercitazione[["llm_textual_input", "llm_image_input"]].to_csv(
|
||||
"llm_model_inputs.csv", sep=";", index=False
|
||||
)
|
||||
|
||||
|
||||
model_id = "mistralai/Ministral-3-3B-Instruct-2512"
|
||||
cache_dir = "./model_cache"
|
||||
tokenizer = MistralCommonBackend.from_pretrained(model_id,cache_dir=cache_dir)
|
||||
model = Mistral3ForConditionalGeneration.from_pretrained(model_id, device_map="auto",cache_dir=cache_dir,dtype=torch.bfloat16)
|
||||
|
||||
print(f"Model device: {next(model.parameters()).device}")
|
||||
print(f"Model dtype: {model.dtype}")
|
||||
memory_bytes = model.get_memory_footprint()
|
||||
memory_gb = memory_bytes / (1024**3)
|
||||
memory_mb = memory_bytes / (1024**2)
|
||||
|
||||
print(f"Model memory footprint: {memory_bytes:,} bytes")
|
||||
print(f"Model memory footprint: {memory_gb:.2f} GB")
|
||||
print(f"Model memory footprint: {memory_mb:.2f} MB")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
df_esercitazione[['llm_assessment_1', 'llm_judgment_1', 'llm_evaluation_result_1', 'llm_alt_text_1','llm_model_1']] = df_esercitazione.head(2).apply(process_to_generate_model_outputs, axis=1)
|
||||
|
||||
#df_esercitazione[['llm_assessment_1', 'llm_judgment_1', 'llm_evaluation_result_1', 'llm_alt_text_1','llm_model_1']] = df_esercitazione.apply(process_to_generate_model_outputs, axis=1)
|
||||
|
||||
#df_esercitazione.to_csv("hf_llm_generated_output_"+"original_local_model_ministral-3-3b-it"+".csv", sep=";", index=False)
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
import torch
|
||||
from transformers import Mistral3ForConditionalGeneration, MistralCommonBackend
|
||||
|
||||
import gc
|
||||
|
||||
cache_dir = "./model_cache"
|
||||
|
||||
print("Freeing up memory...")
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
model_id = "mistralai/Ministral-3-3B-Instruct-2512"
|
||||
|
||||
tokenizer = MistralCommonBackend.from_pretrained(model_id,cache_dir=cache_dir)
|
||||
model = Mistral3ForConditionalGeneration.from_pretrained(model_id, device_map="auto",cache_dir=cache_dir)
|
||||
|
||||
print(f"Model device: {next(model.parameters()).device}")
|
||||
print(f"Model dtype: {model.dtype}")
|
||||
|
||||
|
||||
image_url = "https://static.wikia.nocookie.net/essentialsdocs/images/7/70/Battle.png/revision/latest?cb=20220523172438"
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What action do you think I should take in this situation? List all the possible actions and explain why you think they are good or bad.",
|
||||
},
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
tokenized = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True)
|
||||
|
||||
tokenized["input_ids"] = tokenized["input_ids"].to(device="cuda") #vedi che non gli passa immagine PIL come fa per gemma3
|
||||
tokenized["pixel_values"] = tokenized["pixel_values"].to(dtype=torch.bfloat16, device="cuda")
|
||||
image_sizes = [tokenized["pixel_values"].shape[-2:]]
|
||||
print("Image sizes:", image_sizes)
|
||||
|
||||
|
||||
print(f"model device: {model.device}")
|
||||
tokenized = tokenized.to(model.device)  # move everything to the model's device
|
||||
|
||||
output = model.generate(
|
||||
**tokenized,
|
||||
image_sizes=image_sizes,
|
||||
max_new_tokens=512,
|
||||
)[0]
|
||||
|
||||
decoded_output = tokenizer.decode(output[len(tokenized["input_ids"][0]):])
|
||||
print(decoded_output)
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
|
||||
import torch
|
||||
import gc
|
||||
|
||||
cache_dir = "./model_cache"
|
||||
|
||||
print("Freeing up memory...")
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
model_kwargs = dict(
|
||||
#attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU
|
||||
torch_dtype=torch.bfloat16,  # or torch.float16; which torch dtype to use, defaults to auto
|
||||
device_map="auto", # Let torch decide how to load the model
|
||||
|
||||
)
|
||||
|
||||
# BitsAndBytesConfig int-4 config
|
||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
|
||||
bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
|
||||
)
|
||||
|
||||
# default: Load the model on the available device(s)
|
||||
model = Qwen3VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen3-VL-2B-Instruct",cache_dir=cache_dir, **model_kwargs
|
||||
)
|
||||
|
||||
# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
|
||||
# model = Qwen3VLForConditionalGeneration.from_pretrained(
|
||||
# "Qwen/Qwen3-VL-4B-Instruct",
|
||||
# dtype=torch.bfloat16,
|
||||
# attn_implementation="flash_attention_2",
|
||||
# device_map="auto",
|
||||
# )
|
||||
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-2B-Instruct",cache_dir=cache_dir)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
|
||||
},
|
||||
{"type": "text", "text": "Describe this image."},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# Preparation for inference
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
inputs = inputs.to(model.device)  # note that no PIL image is passed here, unlike for Gemma3
|
||||
|
||||
# Inference: Generation of the output
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
print(output_text)
|
||||
|
|
@ -0,0 +1,290 @@
|
|||
|
||||
import os
|
||||
from PIL import Image
|
||||
import torch
|
||||
from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||
import gc
|
||||
import requests
|
||||
|
||||
# System message for the assistant
|
||||
system_message = "You are a web accessibility evaluation tool. Your task is to evaluate if alterative text for images on webpages are appropriate according to WCAG guidelines."
|
||||
|
||||
# User prompt that combines the user query and the schema
|
||||
user_prompt_old = """Create the most appropriate new alt-text given the image, the <HTML context>, and the current <alt-text>. Keep this within 30 words. Use the same language as the original alt-text.
|
||||
Only return the new alt-text.
|
||||
|
||||
<alt-text>
|
||||
{alttext}
|
||||
</alt-text>
|
||||
|
||||
<HTML context>
|
||||
{HTML_context}
|
||||
</HTML context>
|
||||
|
||||
"""
|
||||
|
||||
### even with this user prompt, reduced compared to the one used in training, the output type and values are still respected while shortening the prompt itself. The fine-tuning process "keeps memory" of it.
|
||||
### related to what is reported here: https://docs.mistral.ai/capabilities/finetuning
|
||||
### "Enhancing the model’s performance by mimicking the behavior of a model with a complex prompt, but without the need for the actual prompt, thereby saving tokens, and reducing associated costs"
|
||||
|
||||
user_prompt= """
|
||||
<alt-text>
|
||||
{alttext}
|
||||
</alt-text>
|
||||
|
||||
<HTML context>
|
||||
{HTML_context}
|
||||
</HTML context>
|
||||
|
||||
Follow these instructions carefully:
|
||||
1. You will be provided as input with the following:
|
||||
- The image found on the webpage.
|
||||
- The associated alternative text. When the alt-text is empty or absent, you will be explicitly informed.
|
||||
- The surrounding context of the image.
|
||||
- The page title, headings and the content of the “keywords” and “description” <meta> tag, if found.
|
||||
|
||||
2. Determine the function and purpose of the image by analyzing these elements. Take into account the purpose and function
|
||||
of the associated image by considering the page context. Check also if the image is, or is associated with, a link or a button,
|
||||
and consider this in your judgement. If the image contains text use that as part of the context.
|
||||
|
||||
3. Provide a final assessment judgment based on the following:
|
||||
- 'success' if you can assess with 'sufficient certainty' the alt-text is appropriate in relation to the image purpose,
|
||||
- 'failure' if you can assess with 'sufficient certainty' that the alt-text is NOT appropriate,
|
||||
- 'warning' if you cannot determine with 'sufficient certainty'.
|
||||
where the level of certainty goes from 1 to 100 and 'sufficient certainty' means > 80
|
||||
|
||||
4. The original alt-text assessment on a scale from 1 to 5, where 5 is the best score. Use an integer number only.
|
||||
|
||||
5. Provide a brief reasoning for your judgment. If the image contains text, write it verbatim.
|
||||
|
||||
6. Keep your response within 150 words.
|
||||
|
||||
7. Generate the new most appropriate alt-text given the context and the steps before. Keep this within 30 words. Use the same natural language (e.g., English, Spanish, Italian) as the original alt-text.
|
||||
|
||||
8. Here is the JSON format the results must have:
|
||||
"Original alt-text assessment" : "*your original alt-text assessment*", "Assessment" : "*your assessment judgment*", "EvaluationResult": "*your response*", "New alt-text":"*new alt-text*"
|
||||
"""
|
||||
|
||||
|
||||
|
||||
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
|
||||
# print("Processing vision info...")
|
||||
image_inputs = []
|
||||
# Iterate through each conversation
|
||||
for msg in messages:
|
||||
# Get content (ensure it's a list)
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
content = [content]
|
||||
|
||||
# Check each content element for images
|
||||
for element in content:
|
||||
if isinstance(element, dict) and (
|
||||
"image" in element or element.get("type") == "image"
|
||||
):
|
||||
# Get the image and convert to RGB
|
||||
if "image" in element:
|
||||
image = element["image"]
|
||||
else:
|
||||
image = element
|
||||
image_inputs.append(image.convert("RGB")) # converte in rgb !
|
||||
return image_inputs
|
||||
|
||||
|
||||
def format_data(sample, is_inference=False):
|
||||
|
||||
formatted_data = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": system_message}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": user_prompt.format(
|
||||
HTML_context=sample["html_context"],
|
||||
alttext=sample["alt_text"],
|
||||
# accessibility_expert_alt_text_assessment=sample["original_alt_text_assessment"],
|
||||
# accessibility_expert_alt_text_comments=sample["evaluation_result"]
|
||||
),
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"image": sample["image"].convert(
|
||||
"RGB"
|
||||
), # .convert("RGB") necessario??
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
}
|
||||
if is_inference:
|
||||
print("formatted_data for inference:", formatted_data ) # non gli passo la parte assistant (la risposta attesa) come fa nell'esempio HF
|
||||
|
||||
else:
|
||||
formatted_data["messages"].append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": sample["new_alt_text"]}
|
||||
],  # see the assistant role for the expected answer
|
||||
}
|
||||
)
|
||||
return formatted_data
|
||||
|
||||
|
||||
def generate_description(sample, model, processor):
|
||||
print("Generating description...")
|
||||
# Convert sample into messages and then apply the chat template
|
||||
"""messages = [
|
||||
{"role": "system", "content": [{"type": "text", "text": system_message}]},
|
||||
{"role": "user", "content": [
|
||||
{"type": "image","image": sample["image"]},
|
||||
{"type": "text", "text": user_prompt.format(product=sample["product_name"], category=sample["category"])},
|
||||
]},
|
||||
]"""
|
||||
|
||||
### take the first element as a test
|
||||
# image_inputs=dataset[0]["image"]#non è una lista ma per il resto è uguale a sotto
|
||||
# print("image_inputs_pre:", image_inputs)
|
||||
format_data_example = format_data(sample, is_inference=True)
|
||||
messages = format_data_example["messages"]
|
||||
# print("User message:", messages)
|
||||
text = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
# Process the image and text
|
||||
image_inputs = process_vision_info(
|
||||
messages
|
||||
)  # converts the image to RGB, even though the sample above already seems to apply .convert("RGB")
|
||||
# print("image_inputs:", image_inputs)
|
||||
|
||||
# Tokenize the text and process the images
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
max_length=8192, # Equivalent to num_ctx, max input token
|
||||
truncation=True,
|
||||
)
|
||||
# Move the inputs to the device
|
||||
inputs = inputs.to(model.device)
|
||||
|
||||
# Generate the output
|
||||
stop_token_ids = [
|
||||
processor.tokenizer.eos_token_id,
|
||||
processor.tokenizer.convert_tokens_to_ids("<end_of_turn>"),
|
||||
]
|
||||
|
||||
generation_config = {
|
||||
"temperature": 0.7, # Same as Ollama
|
||||
"max_new_tokens": 800, # Equivalent to num_predict
|
||||
"top_p": 0.95, # Same as Ollama
|
||||
"do_sample": True, # Required for temperature/top_p to work
|
||||
}
|
||||
generated_ids = model.generate(
|
||||
**inputs,
|
||||
**generation_config,
|
||||
eos_token_id=stop_token_ids,
|
||||
disable_compile=True
|
||||
)
|
||||
# Trim the generation and decode the output to text
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :]
|
||||
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
return output_text[0]
|
||||
|
||||
|
||||
|
||||
sample = {
|
||||
"html_context": "Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur",
|
||||
"alt_text": "Toys & Games | Toy Figures & Playsets | Action Figures",
|
||||
"image": Image.open(requests.get("https://m.media-amazon.com/images/I/81+7Up7IWyL._AC_SY300_SX300_.jpg", stream=True).raw)#.convert("RGB"),
|
||||
|
||||
}
|
||||
|
||||
#output_dir = "./merged_model"
|
||||
output_dir_it = "./merged_model_google_gemma-3-4b-it"
|
||||
|
||||
|
||||
print("Freeing up memory...")
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
"""
|
||||
# Load Model with PEFT adapter
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
output_dir,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
print("\n Model loaded #3")
|
||||
processor = AutoProcessor.from_pretrained(output_dir)
|
||||
print("Processor loaded #3")
|
||||
# print(model)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
print("testing the Merged model #3 ...")
|
||||
|
||||
|
||||
|
||||
|
||||
# generate the description
|
||||
description = generate_description(sample, model, processor)
|
||||
print("-text generated 1:", description)
|
||||
|
||||
description = generate_description(sample, model, processor)
|
||||
print("-text generated 2:", description)
|
||||
|
||||
description = generate_description(sample, model, processor)
|
||||
print("-text generated 3:", description)
|
||||
|
||||
|
||||
print("Freeing up memory...")
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
del model
|
||||
"""
|
||||
|
||||
|
||||
# Load Model with PEFT adapter
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
output_dir_it,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
print("\n Model loaded #3 from it")
|
||||
processor = AutoProcessor.from_pretrained(output_dir_it)
|
||||
print("Processor loaded #3 from it")
|
||||
# print(model)
|
||||
|
||||
|
||||
print("testing the Merged model #3 from it...")
|
||||
|
||||
|
||||
# dataset = [format_data(sample) for sample in dataset]
|
||||
|
||||
# generate the description
|
||||
description = generate_description(sample, model, processor)
|
||||
print("-text generated 1:", description)
|
||||
|
||||
description = generate_description(sample, model, processor)
|
||||
print("-text generated 2:", description)
|
||||
|
||||
description = generate_description(sample, model, processor)
|
||||
print("-text generated 3:", description)
|
||||
|
|
@ -0,0 +1,180 @@
|
|||
import json
|
||||
import time
|
||||
import urllib.request
|
||||
import urllib.parse
|
||||
import logging
|
||||
import os
|
||||
import requests
|
||||
import base64
|
||||
import re
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
|
||||
|
||||
def call_API_urlibrequest(
|
||||
data={},
|
||||
verbose=False,
|
||||
url="",
|
||||
headers=[],
|
||||
method="post",
|
||||
base=2, # number of seconds to wait
|
||||
max_tries=3,
|
||||
):
|
||||
|
||||
if verbose:
|
||||
logging.info("input_data:%s", data)
|
||||
|
||||
# Allow multiple attempts to call the API in case of downtime.
|
||||
# Return provided response to user after 3 failed attempts.
|
||||
wait_seconds = [base**i for i in range(max_tries)]
|
||||
|
||||
for num_tries in range(max_tries):
|
||||
try:
|
||||
|
||||
if method == "get":
|
||||
|
||||
# Encode the parameters and append them to the URL
|
||||
query_string = urllib.parse.urlencode(data)
|
||||
|
||||
url_with_params = f"{url}?{query_string}"
|
||||
request = urllib.request.Request(url_with_params, method="GET")
|
||||
for ele in headers:
|
||||
|
||||
request.add_header(ele[0], ele[1])
|
||||
|
||||
elif method == "post":
|
||||
# Convert the dictionary to a JSON formatted string and encode it to bytes
|
||||
data_to_send = json.dumps(data).encode("utf-8")
|
||||
|
||||
request = urllib.request.Request(url, data=data_to_send, method="POST")
|
||||
for ele in headers:
|
||||
|
||||
request.add_header(ele[0], ele[1])
|
||||
else:
|
||||
return {"error_message": "method_not_allowed"}
|
||||
|
||||
# Send the request and capture the response
|
||||
|
||||
with urllib.request.urlopen(request, timeout=300) as response:
|
||||
# Read and decode the response
|
||||
|
||||
response_json = json.loads(response.read().decode("utf-8"))
|
||||
logging.info("response_json:%s", response_json)
|
||||
|
||||
logging.info("response.status_code:%s", response.getcode())
|
||||
return response_json
|
||||
|
||||
except Exception as e:
|
||||
|
||||
logging.error("error message:%s", e)
|
||||
response_json = {"error": e}
|
||||
|
||||
logging.info("num_tries:%s", num_tries)
|
||||
logging.info(
|
||||
"Waiting %s seconds before automatically trying again.",
|
||||
str(wait_seconds[num_tries]),
|
||||
)
|
||||
time.sleep(wait_seconds[num_tries])
|
||||
|
||||
logging.info(
|
||||
"Tried %s times to make API call to get a valid response object", max_tries
|
||||
)
|
||||
logging.info("Returning provided response")
|
||||
return response_json
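# Hedged usage sketch; the endpoint, header values, and payload below are placeholders,
# not a real API:
# call_API_urlibrequest(
#     data={"model": "example-model", "prompt": "Describe the image."},
#     url="https://api.example.com/v1/generate",
#     headers=[("Content-Type", "application/json"), ("Authorization", "Bearer <token>")],
#     method="post",
#     max_tries=3,  # with the default base=2, retries wait 1, 2 and 4 seconds
# )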
|
||||
|
||||
|
||||
def parse_mllm_alt_text_response(mllm_response):
|
||||
"""
|
||||
Parse an MLLM response string and extract key attributes into a JSON object.
|
||||
|
||||
from mllm response like:
|
||||
```json\n{\n\"Original alt-text assessment\"... etc
|
||||
to a structured dictionary.
|
||||
|
||||
Args:
|
||||
mllm_response (str): The raw MLLM response text containing JSON data
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the extracted attributes, or None if parsing fails
|
||||
"""
|
||||
try:
|
||||
# Handle NaN or None values
|
||||
if mllm_response is None or mllm_response == "":
|
||||
return {
|
||||
"original_alt_text_assessment": None,
|
||||
"assessment": None,
|
||||
"evaluation_result": None,
|
||||
"new_alt_text": None
|
||||
}
|
||||
|
||||
# Extract JSON content between ```json and ``` markers
|
||||
json_match = re.search(r'```json\s*(.*?)\s*```', mllm_response, re.DOTALL)
|
||||
|
||||
if not json_match:
|
||||
# Try to find JSON without markdown code blocks
|
||||
json_match = re.search(r'\{.*\}', mllm_response, re.DOTALL)
|
||||
|
||||
if not json_match:
|
||||
return {
|
||||
"original_alt_text_assessment": None,
|
||||
"assessment": None,
|
||||
"evaluation_result": None,
|
||||
"new_alt_text": None
|
||||
}
|
||||
|
||||
json_str = json_match.group(1) if '```json' in mllm_response else json_match.group(0)
|
||||
|
||||
# Parse the JSON string
|
||||
parsed_data = json.loads(json_str)
|
||||
|
||||
# Create a structured output with the key attributes
|
||||
result = {
|
||||
"original_alt_text_assessment": parsed_data.get("Original alt-text assessment", ""),
|
||||
"assessment": parsed_data.get("Assessment", ""),
|
||||
"evaluation_result": parsed_data.get("EvaluationResult", ""),
|
||||
"new_alt_text": parsed_data.get("New alt-text", "")
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON parsing error: {e}")
|
||||
return {
|
||||
"original_alt_text_assessment": None,
|
||||
"assessment": None,
|
||||
"evaluation_result": None,
|
||||
"new_alt_text": None
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Error parsing MLLM response: {e}")
|
||||
return {
|
||||
"original_alt_text_assessment": None,
|
||||
"assessment": None,
|
||||
"evaluation_result": None,
|
||||
"new_alt_text": None
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
def encode_image_from_url(image_url):
|
||||
response = requests.get(image_url)
|
||||
|
||||
# Open image and convert to RGB
|
||||
image = Image.open(io.BytesIO(response.content))
|
||||
|
||||
# Convert to RGB (handles RGBA, grayscale, etc.)
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert('RGB')
|
||||
|
||||
# Save to bytes buffer
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format='PNG') # or 'JPEG'
|
||||
buffer.seek(0)
|
||||
|
||||
# Encode to base64
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
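# Hedged usage sketch (placeholder URL): the base64 string is typically embedded in an API
# payload, e.g. as a data URI; the exact field name depends on the target API.
# b64 = encode_image_from_url("https://example.com/picture.jpg")
# image_payload = {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}}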
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,290 @@
|
|||
import re
|
||||
from collections import Counter
|
||||
|
||||
"""
|
||||
For English texts:
|
||||
|
||||
Flesch Reading Ease score
|
||||
Flesch-Kincaid Grade Level
|
||||
Gunning Fog Index
|
||||
|
||||
For Italian texts:
|
||||
|
||||
Flesch Reading Ease (adapted for Italian with Flesch-Vacca formula)
|
||||
Gulpease Index (specifically designed for Italian)
|
||||
Gunning Fog Index
|
||||
|
||||
Basic statistics for both:
|
||||
|
||||
Sentence count
|
||||
Word count
|
||||
Syllable count
|
||||
Complex words (3+ syllables)
|
||||
Average words per sentence
|
||||
Average syllables per word
|
||||
"""
|
||||
|
||||
class ReadabilityAnalyzer:
|
||||
"""Analyze text readability for English and Italian"""
|
||||
|
||||
def __init__(self, text, language='en'):
|
||||
self.text = text
|
||||
self.language = language.lower()
|
||||
self.sentences = self._count_sentences()
|
||||
self.words = self._count_words()
|
||||
self.syllables = self._count_syllables()
|
||||
self.complex_words = self._count_complex_words()
|
||||
self.characters = len(re.sub(r'\s', '', text))
|
||||
|
||||
def _count_sentences(self):
|
||||
"""Count sentences in text"""
|
||||
sentences = re.split(r'[.!?]+', self.text)
|
||||
return len([s for s in sentences if s.strip()])
|
||||
|
||||
def _count_words(self):
|
||||
"""Count words in text"""
|
||||
words = re.findall(r'\b[a-zA-ZàèéìòùÀÈÉÌÒÙáíóúýÁÍÓÚÝâêîôûÂÊÎÔÛäëïöüÄËÏÖÜ]+\b', self.text)
|
||||
return len(words)
|
||||
|
||||
def _count_syllables(self):
|
||||
"""Count syllables in text (approximation for both languages)"""
|
||||
words = re.findall(r'\b[a-zA-ZàèéìòùÀÈÉÌÒÙáíóúýÁÍÓÚÝâêîôûÂÊÎÔÛäëïöüÄËÏÖÜ]+\b', self.text.lower())
|
||||
total_syllables = 0
|
||||
|
||||
for word in words:
|
||||
if self.language == 'it':
|
||||
syllables = self._count_syllables_italian(word)
|
||||
else:
|
||||
syllables = self._count_syllables_english(word)
|
||||
total_syllables += syllables
|
||||
|
||||
return total_syllables
|
||||
|
||||
def _count_syllables_english(self, word):
|
||||
"""Count syllables in English word"""
|
||||
word = word.lower()
|
||||
vowels = 'aeiouy'
|
||||
syllables = 0
|
||||
previous_was_vowel = False
|
||||
|
||||
for char in word:
|
||||
is_vowel = char in vowels
|
||||
if is_vowel and not previous_was_vowel:
|
||||
syllables += 1
|
||||
previous_was_vowel = is_vowel
|
||||
|
||||
# Adjust for silent e
|
||||
if word.endswith('e'):
|
||||
syllables -= 1
|
||||
|
||||
# Ensure at least 1 syllable
|
||||
if syllables == 0:
|
||||
syllables = 1
|
||||
|
||||
return syllables
|
||||
|
||||
def _count_syllables_italian(self, word):
|
||||
"""Count syllables in Italian word"""
|
||||
word = word.lower()
|
||||
vowels = 'aeiouàèéìòùáíóúý'
|
||||
syllables = 0
|
||||
previous_was_vowel = False
|
||||
|
||||
for char in word:
|
||||
is_vowel = char in vowels
|
||||
if is_vowel and not previous_was_vowel:
|
||||
syllables += 1
|
||||
previous_was_vowel = is_vowel
|
||||
|
||||
# Ensure at least 1 syllable
|
||||
if syllables == 0:
|
||||
syllables = 1
|
||||
|
||||
return syllables
|
||||
|
||||
def _count_complex_words(self):
|
||||
"""Count words with 3+ syllables"""
|
||||
words = re.findall(r'\b[a-zA-ZàèéìòùÀÈÉÌÒÙáíóúýÁÍÓÚÝâêîôûÂÊÎÔÛäëïöüÄËÏÖÜ]+\b', self.text.lower())
|
||||
complex_count = 0
|
||||
|
||||
for word in words:
|
||||
if self.language == 'it':
|
||||
syllables = self._count_syllables_italian(word)
|
||||
else:
|
||||
syllables = self._count_syllables_english(word)
|
||||
|
||||
if syllables >= 3:
|
||||
complex_count += 1
|
||||
|
||||
return complex_count
|
||||
|
||||
def flesch_reading_ease(self):
|
||||
"""Calculate Flesch Reading Ease score"""
|
||||
if self.words == 0 or self.sentences == 0:
|
||||
return 0
|
||||
|
||||
if self.language == 'it':
|
||||
# Flesch-Vacca formula for Italian
|
||||
score = 206.835 - 1.3 * (self.words / self.sentences) - 60.1 * (self.syllables / self.words)
|
||||
else:
|
||||
# Standard Flesch formula for English
|
||||
score = 206.835 - 1.015 * (self.words / self.sentences) - 84.6 * (self.syllables / self.words)
|
||||
|
||||
return round(score, 2)
|
||||
|
||||
def flesch_kincaid_grade(self):
|
||||
"""Calculate Flesch-Kincaid Grade Level (primarily for English)"""
|
||||
if self.words == 0 or self.sentences == 0:
|
||||
return 0
|
||||
|
||||
grade = 0.39 * (self.words / self.sentences) + 11.8 * (self.syllables / self.words) - 15.59
|
||||
return round(grade, 2)
|
||||
|
||||
def gunning_fog_index(self):
|
||||
"""Calculate Gunning Fog Index"""
|
||||
if self.words == 0 or self.sentences == 0:
|
||||
return 0
|
||||
|
||||
fog = 0.4 * ((self.words / self.sentences) + 100 * (self.complex_words / self.words))
|
||||
return round(fog, 2)
|
||||
|
||||
def gulpease_index(self):
|
||||
"""Calculate Gulpease Index (for Italian)"""
|
||||
if self.words == 0:
|
||||
return 0
|
||||
|
||||
gulpease = 89 - (self.characters / self.words * 10) + (self.sentences / self.words * 300)
|
||||
return round(gulpease, 2)
|
||||
|
||||
def get_all_scores(self):
|
||||
"""Get all readability scores"""
|
||||
scores = {
|
||||
'basic_stats': {
|
||||
'sentences': self.sentences,
|
||||
'words': self.words,
|
||||
'syllables': self.syllables,
|
||||
'complex_words': self.complex_words,
|
||||
'characters': self.characters,
|
||||
'avg_words_per_sentence': round(self.words / self.sentences, 2) if self.sentences > 0 else 0,
|
||||
'avg_syllables_per_word': round(self.syllables / self.words, 2) if self.words > 0 else 0
|
||||
},
|
||||
'readability_scores': {}
|
||||
}
|
||||
|
||||
# Add appropriate scores based on language
|
||||
if self.language == 'it':
|
||||
scores['readability_scores']['flesch_reading_ease_it'] = self.flesch_reading_ease()
|
||||
scores['readability_scores']['gulpease_index'] = self.gulpease_index()
|
||||
scores['readability_scores']['gunning_fog_index'] = self.gunning_fog_index()
|
||||
else:
|
||||
scores['readability_scores']['flesch_reading_ease'] = self.flesch_reading_ease()
|
||||
scores['readability_scores']['flesch_kincaid_grade'] = self.flesch_kincaid_grade()
|
||||
scores['readability_scores']['gunning_fog_index'] = self.gunning_fog_index()
|
||||
|
||||
return scores
|
||||
|
||||
def interpret_scores(self):
|
||||
"""Provide interpretation of readability scores"""
|
||||
scores = self.get_all_scores()
|
||||
interpretation = []
|
||||
|
||||
if self.language == 'it':
|
||||
# Flesch Reading Ease (Italian)
|
||||
fre = scores['readability_scores']['flesch_reading_ease_it']
|
||||
if fre >= 80:
|
||||
interpretation.append(f"Flesch Reading Ease (IT): {fre} - Molto facile (Very easy)")
|
||||
elif fre >= 60:
|
||||
interpretation.append(f"Flesch Reading Ease (IT): {fre} - Facile (Easy)")
|
||||
elif fre >= 50:
|
||||
interpretation.append(f"Flesch Reading Ease (IT): {fre} - Abbastanza facile (Fairly easy)")
|
||||
elif fre >= 40:
|
||||
interpretation.append(f"Flesch Reading Ease (IT): {fre} - Normale (Normal)")
|
||||
elif fre >= 30:
|
||||
interpretation.append(f"Flesch Reading Ease (IT): {fre} - Abbastanza difficile (Fairly difficult)")
|
||||
else:
|
||||
interpretation.append(f"Flesch Reading Ease (IT): {fre} - Difficile (Difficult)")
|
||||
|
||||
# Gulpease Index
|
||||
gulpease = scores['readability_scores']['gulpease_index']
|
||||
if gulpease >= 80:
|
||||
interpretation.append(f"Gulpease Index: {gulpease} - Elementare (Elementary school)")
|
||||
elif gulpease >= 60:
|
||||
interpretation.append(f"Gulpease Index: {gulpease} - Media inferiore (Middle school)")
|
||||
elif gulpease >= 40:
|
||||
interpretation.append(f"Gulpease Index: {gulpease} - Media superiore (High school)")
|
||||
else:
|
||||
interpretation.append(f"Gulpease Index: {gulpease} - Universitario (University)")
|
||||
else:
|
||||
# Flesch Reading Ease (English)
|
||||
fre = scores['readability_scores']['flesch_reading_ease']
|
||||
if fre >= 90:
|
||||
interpretation.append(f"Flesch Reading Ease: {fre} - Very easy (5th grade)")
|
||||
elif fre >= 80:
|
||||
interpretation.append(f"Flesch Reading Ease: {fre} - Easy (6th grade)")
|
||||
elif fre >= 70:
|
||||
interpretation.append(f"Flesch Reading Ease: {fre} - Fairly easy (7th grade)")
|
||||
elif fre >= 60:
|
||||
interpretation.append(f"Flesch Reading Ease: {fre} - Standard (8th-9th grade)")
|
||||
elif fre >= 50:
|
||||
interpretation.append(f"Flesch Reading Ease: {fre} - Fairly difficult (10th-12th grade)")
|
||||
elif fre >= 30:
|
||||
interpretation.append(f"Flesch Reading Ease: {fre} - Difficult (College)")
|
||||
else:
|
||||
interpretation.append(f"Flesch Reading Ease: {fre} - Very difficult (College graduate)")
|
||||
|
||||
# Flesch-Kincaid Grade
|
||||
fkg = scores['readability_scores']['flesch_kincaid_grade']
|
||||
interpretation.append(f"Flesch-Kincaid Grade: {fkg} (US grade level)")
|
||||
|
||||
# Gunning Fog Index (both languages)
|
||||
fog = scores['readability_scores']['gunning_fog_index']
|
||||
interpretation.append(f"Gunning Fog Index: {fog} (years of education needed)")
|
||||
|
||||
return '\n'.join(interpretation)
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
# English example
|
||||
english_text = """
|
||||
The quick brown fox jumps over the lazy dog. This is a simple sentence.
|
||||
However, more complicated sentences with multisyllabic words can significantly
|
||||
increase the complexity of the text and make it harder to read.
|
||||
"""
|
||||
|
||||
print("=== ENGLISH TEXT ANALYSIS ===")
|
||||
analyzer_en = ReadabilityAnalyzer(english_text, language='en')
|
||||
scores_en = analyzer_en.get_all_scores()
|
||||
|
||||
print("\nBasic Statistics:")
|
||||
for key, value in scores_en['basic_stats'].items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
print("\nReadability Scores:")
|
||||
for key, value in scores_en['readability_scores'].items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
print("\nInterpretation:")
|
||||
print(analyzer_en.interpret_scores())
|
||||
|
||||
# Italian example
|
||||
italian_text = """
|
||||
Il veloce cane marrone salta sopra il cane pigro. Questa è una frase semplice.
|
||||
Tuttavia, frasi più complicate con parole polisillabiche possono aumentare
|
||||
significativamente la complessità del testo e renderlo più difficile da leggere.
|
||||
"""
|
||||
|
||||
print("\n\n=== ITALIAN TEXT ANALYSIS ===")
|
||||
analyzer_it = ReadabilityAnalyzer(italian_text, language='it')
|
||||
scores_it = analyzer_it.get_all_scores()
|
||||
|
||||
print("\nBasic Statistics:")
|
||||
for key, value in scores_it['basic_stats'].items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
print("\nReadability Scores:")
|
||||
for key, value in scores_it['readability_scores'].items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
print("\nInterpretation:")
|
||||
print(analyzer_it.interpret_scores())
|
||||
|
|
@ -1,55 +0,0 @@
|
|||
import numpy as np
|
||||
from transformers import BertTokenizer, BertModel
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import torch
|
||||
from bert_score import score
|
||||
|
||||
|
||||
def cosine_similarity(a, b):
|
||||
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
|
||||
|
||||
|
||||
def semantic_similarity(text1, text2):
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||
model = BertModel.from_pretrained("bert-base-uncased")
|
||||
|
||||
inputs1 = tokenizer(text1, return_tensors="pt")
|
||||
inputs2 = tokenizer(text2, return_tensors="pt")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs1 = model(**inputs1)
|
||||
outputs2 = model(**inputs2)
|
||||
|
||||
embedding1 = outputs1.last_hidden_state.mean(dim=1).squeeze().numpy()
|
||||
embedding2 = outputs2.last_hidden_state.mean(dim=1).squeeze().numpy()
|
||||
|
||||
return cosine_similarity(embedding1, embedding2)
|
||||
|
||||
|
||||
def semantic_similarity_sentence_transformer(text1, text2):
|
||||
# Purpose-built for sentence embeddings
|
||||
model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
embeddings = model.encode([text1, text2], output_value="sentence_embedding")
|
||||
return cosine_similarity(embeddings[0], embeddings[1])
|
||||
|
||||
|
||||
def lexical_similarity(text1, text2):
|
||||
vectorizer = TfidfVectorizer(stop_words=None, analyzer="char", ngram_range=(1, 3))
|
||||
tfidf_matrix = vectorizer.fit_transform([text1, text2])
|
||||
vec1 = tfidf_matrix.toarray()[0]
|
||||
vec2 = tfidf_matrix.toarray()[1]
|
||||
return cosine_similarity(vec1, vec2)
|
||||
|
||||
|
||||
def bert_score_similarity(texts1, texts2, batch=False):
|
||||
P, R, F1 = score(
|
||||
texts1,
|
||||
texts2,
|
||||
lang="en",
|
||||
verbose=False,
|
||||
model_type="bert-base-uncased",
|
||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||
batch_size=32,
|
||||
)
|
||||
return F1.tolist() if batch else F1.item()
|
||||