53 KiB
53 KiB
In [1]:
import json
from pathlib import Path

import matplotlib.pyplot as plt
# Location of the Trainer checkpoint state. A raw string keeps the Windows
# backslashes literal: the previous non-raw literal only worked because none of
# its single backslashes happened to form a recognised escape sequence (and it
# triggers "invalid escape sequence" warnings on recent Python versions).
TRAINER_STATE_PATH = Path(
    r"C:\cartella_condivisa\MachineLearning\HIISlab\accessibility\notebook_miei"
    r"\LLM_accessibility_validator\scripts\modello_finetunato"
    r"\gemma-finetuned-wcag_google_gemma-3-4b-it\checkpoint-386\trainer_state.json"
)
# Output file for the rendered loss plot (relative to the notebook's working dir).
PLOT_PATH = Path("training_validation_loss.png")


def extract_loss_curves(log_history):
    """Split a Trainer ``log_history`` list into training and validation loss series.

    Parameters
    ----------
    log_history : list of dict
        The ``log_history`` list read from ``trainer_state.json``. Entries that
        contain a ``"loss"`` key form the training series; entries that contain
        an ``"eval_loss"`` key form the validation series. Every selected entry
        must also contain an ``"epoch"`` key.

    Returns
    -------
    tuple of four lists
        ``(epoch_train, train_losses, epoch_eval, eval_losses)``, each preserving
        the order in which the entries appear in ``log_history``.

    Raises
    ------
    KeyError
        If a selected entry has no ``"epoch"`` key.
    """
    def series(loss_key):
        # Filter once so the x (epoch) and y (loss) lists stay aligned.
        entries = [log for log in log_history if loss_key in log]
        return [log["epoch"] for log in entries], [log[loss_key] for log in entries]

    epoch_train, train_losses = series("loss")
    epoch_eval, eval_losses = series("eval_loss")
    return epoch_train, train_losses, epoch_eval, eval_losses


def plot_loss_curves(epoch_train, train_losses, epoch_eval, eval_losses, output_path):
    """Plot training/validation loss against epoch and save the figure.

    Parameters
    ----------
    epoch_train, train_losses : list
        x/y values of the training-loss curve.
    epoch_eval, eval_losses : list
        x/y values of the validation-loss curve; skipped when empty.
    output_path : str or pathlib.Path
        Destination of the saved PNG (300 dpi, tight bounding box).

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(epoch_train, train_losses, label="Training Loss", marker="o")
    # Only draw the validation curve when evaluation was actually logged, so the
    # legend does not advertise an empty series.
    if eval_losses:
        ax.plot(epoch_eval, eval_losses, label="Validation Loss", marker="s")
    ax.set(xlabel="Epoch", ylabel="Loss", title="Training and Validation Loss per Epoch")
    ax.legend()
    ax.grid(True)
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    return fig


try:
    with open(TRAINER_STATE_PATH, "r", encoding="utf-8") as f:
        trainer_state = json.load(f)
    # `log_history`, `train_losses`, ... stay module-level: later cells reuse them.
    log_history = trainer_state["log_history"]
    epoch_train, train_losses, epoch_eval, eval_losses = extract_loss_curves(log_history)
    plot_loss_curves(epoch_train, train_losses, epoch_eval, eval_losses, PLOT_PATH)
    print(f"Plot saved successfully as '{PLOT_PATH}'")
    plt.show()
except FileNotFoundError:
    print(f"Error: trainer_state.json file not found at {TRAINER_STATE_PATH}")
except json.JSONDecodeError:
    print(f"Error: Invalid JSON format in {TRAINER_STATE_PATH}")
except KeyError as e:
    print(f"Error: Missing key in trainer_state.json: {e}")
except Exception as e:
    print(f"Error plotting loss curves: {e}")
In [3]:
# Collect the training-loss value of every history entry that logged one,
# then display the resulting list.
train_losses = []
for entry in log_history:
    if "loss" in entry:
        train_losses.append(entry["loss"])
train_losses
Out[3]:
In [2]:
# Training-loss values from the Trainer history (entries carrying a "loss" key).
train_losses = [record["loss"] for record in log_history if "loss" in record]
train_losses
Out[2]: