225 lines
7.1 KiB
Python
225 lines
7.1 KiB
Python
import os
|
|
from collections import defaultdict
|
|
|
|
import numpy as np
|
|
import torch
|
|
from sklearn.decomposition import TruncatedSVD
|
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
from sklearn.preprocessing import normalize
|
|
from torch.optim import AdamW
|
|
|
|
from evaluation.evaluate import evaluate, log_eval
|
|
|
|
|
|
def _normalize(lX, l2=True):
|
|
return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX
|
|
|
|
|
|
def XdotM(X, M, sif):
    """Multiply ``X`` by ``M``, optionally removing the first principal component.

    :param X: left operand; anything exposing ``.dot`` (rows are data points)
    :param M: right operand of the product
    :param sif: when truthy, the product is passed through
        ``remove_pc(..., npc=1)`` (presumably SIF-style common-component
        removal, as the flag name suggests)
    :return: ``X.dot(M)``, post-processed when ``sif`` is truthy
    """
    projected = X.dot(M)
    return remove_pc(projected, npc=1) if sif else projected
|
|
|
|
|
|
def remove_pc(X, npc=1):
    """Subtract from every row of ``X`` its projection onto the leading components.

    :param X: data matrix; ``X[i, :]`` is the i-th data point
    :param npc: number of components (as returned by ``compute_pc``) to project out
    :return: matrix whose i-th row is ``X[i, :]`` minus its projection onto
        the ``npc`` components of ``X``
    """
    components = compute_pc(X, npc)
    # (n, npc) projection coefficients of every row onto every component.
    coefficients = X.dot(components.transpose())
    if npc != 1:
        return X - coefficients.dot(components)
    # Single component: (n, 1) * (1, d) broadcasts to the outer product.
    return X - coefficients * components
|
|
|
|
|
|
def compute_pc(X, npc=1):
    """Compute the leading ``npc`` components of ``X`` with a truncated SVD.

    :param X: data matrix; ``X[i, :]`` is the i-th data point. An
        ``np.matrix`` is converted to a plain ndarray before fitting.
    :param npc: number of components to compute
    :return: the fitted ``components_`` array; row i is the i-th component
    """
    # NOTE: TruncatedSVD does not mean-center its input, so these are the top
    # right singular vectors of X rather than PCA components of centred data.
    data = np.asarray(X) if isinstance(X, np.matrix) else X
    # Fixed iteration count and seed keep the result deterministic across runs.
    svd = TruncatedSVD(n_components=npc, n_iter=7, random_state=0)
    svd.fit(data)
    return svd.components_
|
|
|
|
|
|
def predict(logits, classification_type="multilabel"):
    """Convert soft predictions (logits) into hard predictions.

    :param logits: tensor of raw scores; for ``"singlelabel"`` the class axis
        is dimension 1
    :param classification_type: ``"multilabel"`` thresholds
        ``sigmoid(logits)`` at 0.5, giving one boolean per label.
        ``"singlelabel"`` takes the arg-max over dimension 1 and returns it
        as a column of class indices with shape ``(N, 1)``.
    :return: numpy array of hard predictions (detached and moved to CPU)
    :raises ValueError: if ``classification_type`` is not recognised
    """
    if classification_type == "multilabel":
        prediction = torch.sigmoid(logits) > 0.5
    elif classification_type == "singlelabel":
        prediction = torch.argmax(logits, dim=1).view(-1, 1)
    else:
        # Fail loudly: otherwise ``prediction`` would be unbound below and the
        # caller would get an opaque UnboundLocalError.
        raise ValueError(f"unknown classification type: {classification_type!r}")
    return prediction.detach().cpu().numpy()
|
|
|
|
|
|
class TfidfVectorizerMultilingual:
    """Fit and apply one independent ``TfidfVectorizer`` per language.

    All inputs are dicts keyed by language; each value is that language's
    collection of raw documents.
    """

    def __init__(self, **kwargs):
        # Forwarded unchanged to every per-language TfidfVectorizer.
        self.kwargs = kwargs

    def fit(self, lX, ly=None):
        """Fit one vectorizer per language key of ``lX`` (``ly`` is ignored).

        :return: ``self``, to allow chaining
        """
        self.langs = sorted(lX.keys())
        fitted = {}
        for lang in self.langs:
            fitted[lang] = TfidfVectorizer(**self.kwargs).fit(lX[lang])
        self.vectorizer = fitted
        return self

    def transform(self, lX):
        """Vectorize ``lX[lang]`` for every fitted language.

        Raises KeyError if ``lX`` lacks one of the fitted languages.
        """
        transformed = {}
        for lang in self.langs:
            transformed[lang] = self.vectorizer[lang].transform(lX[lang])
        return transformed

    def fit_transform(self, lX, ly=None):
        """Fit on ``lX`` and return its per-language transformation."""
        self.fit(lX, ly)
        return self.transform(lX)

    def vocabulary(self, l=None):
        """Return the fitted vocabulary of language ``l``, or of all languages if ``l`` is None."""
        if l is not None:
            return self.vectorizer[l].vocabulary_
        return {lang: self.vectorizer[lang].vocabulary_ for lang in self.langs}

    def get_analyzer(self, l=None):
        """Return the analyzer callable of language ``l``, or of all languages if ``l`` is None."""
        if l is not None:
            return self.vectorizer[l].build_analyzer()
        return {lang: self.vectorizer[lang].build_analyzer() for lang in self.langs}
|
|
|
|
|
|
class Trainer:
    """Epoch-based training loop with periodic validation and early stopping.

    The model is called as ``model(x)`` and its output must expose a
    ``.logits`` attribute. Dataloaders yield ``(x, y, lang)`` triples.
    """

    def __init__(
        self,
        model,
        optimizer_name,
        device,
        loss_fn,
        lr,
        print_steps,
        evaluate_step,
        patience,
        experiment_name,
        checkpoint_path="models/vgfs/transformers/",
    ):
        """
        :param model: torch module; moved to ``device``
        :param optimizer_name: optimizer identifier; only ``"adamw"``
            (case-insensitive) is supported
        :param device: device the model, the loss and every batch are moved to
        :param loss_fn: loss module called as ``loss_fn(logits, targets)``;
            moved to ``device``
        :param lr: optimizer learning rate
        :param print_steps: print the training loss every ``print_steps`` batches
        :param evaluate_step: run validation every ``evaluate_step`` epochs
        :param patience: early-stopping patience, counted in validation rounds
        :param experiment_name: forwarded to EarlyStopping
        :param checkpoint_path: base checkpoint directory forwarded to
            EarlyStopping (defaults to the previously hard-coded path)
        """
        self.device = device
        self.model = model.to(device)
        self.optimizer = self.init_optimizer(optimizer_name, lr)
        self.evaluate_steps = evaluate_step
        self.loss_fn = loss_fn.to(device)
        self.print_steps = print_steps
        self.earlystopping = EarlyStopping(
            patience=patience,
            checkpoint_path=checkpoint_path,
            verbose=True,
            experiment_name=experiment_name,
        )

    def init_optimizer(self, optimizer_name, lr):
        """Build the optimizer over all model parameters.

        :raises ValueError: if ``optimizer_name`` is not ``"adamw"`` (case-insensitive)
        """
        if optimizer_name.lower() == "adamw":
            return AdamW(self.model.parameters(), lr=lr)
        raise ValueError(f"Optimizer {optimizer_name} not supported")

    def train(self, train_dataloader, eval_dataloader, epochs=10):
        """Train for up to ``epochs`` epochs, validating every ``evaluate_steps`` epochs.

        Training stops early as soon as the early-stopping callback returns a
        truthy value after a validation round.

        :return: the trained model
        """
        # NOTE(review): assumes the training dataset exposes its inputs as a
        # tensor ``X`` whose last dimension is the sequence length.
        print(
            "- Training params:\n"
            f"- epochs: {epochs}\n"
            f"- learning rate: {self.optimizer.defaults['lr']}\n"
            f"- train batch size: {train_dataloader.batch_size}\n"
            f"- eval batch size: {eval_dataloader.batch_size}\n"
            f"- max len: {train_dataloader.dataset.X.shape[-1]}\n",
        )
        for epoch in range(epochs):
            self.train_epoch(train_dataloader, epoch)
            if (epoch + 1) % self.evaluate_steps == 0:
                metric_watcher = self.evaluate(eval_dataloader)
                stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
                if stop:
                    break
        return self.model

    def train_epoch(self, dataloader, epoch):
        """Run one optimisation pass over ``dataloader`` (``epoch`` is only used for logging).

        :return: ``self``
        """
        self.model.train()
        for b_idx, (x, y, lang) in enumerate(dataloader):
            self.optimizer.zero_grad()
            y_hat = self.model(x.to(self.device))
            loss = self.loss_fn(y_hat.logits, y.to(self.device))
            loss.backward()
            self.optimizer.step()
            if b_idx % self.print_steps == 0:
                # .item() turns the scalar loss tensor into a plain float for printing.
                print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss.item():.4f}")
        return self

    def evaluate(self, dataloader):
        """Predict on ``dataloader`` and return the first averaged validation metric.

        Predictions are multilabel (sigmoid > 0.5) and are grouped by
        language before being passed to ``evaluate``/``log_eval``. The model
        is left in eval mode; ``train_epoch`` switches it back.

        :return: ``log_eval(...)[0]`` (macro-F1, per the original note)
        """
        self.model.eval()

        lY = defaultdict(list)
        lY_hat = defaultdict(list)

        # Inference only: disable autograd so no computation graph is built
        # and kept alive for every validation batch.
        with torch.no_grad():
            for x, y, langs in dataloader:
                y_hat = self.model(x.to(self.device))
                predictions = predict(y_hat.logits, classification_type="multilabel")
                for l, _true, _pred in zip(langs, y, predictions):
                    lY[l].append(_true.detach().cpu().numpy())
                    lY_hat[l].append(_pred)

        # Stack the per-sample rows of each language into a 2-D array.
        for lang in lY:
            lY[lang] = np.vstack(lY[lang])
            lY_hat[lang] = np.vstack(lY_hat[lang])

        l_eval = evaluate(lY, lY_hat)
        average_metrics = log_eval(l_eval, phase="validation")
        return average_metrics[0]  # macro-F1
|
|
|
|
|
|
class EarlyStopping:
    """Signal when a validation score has stopped improving.

    A call counts as an improvement only if the score exceeds the best score
    seen so far by more than ``min_delta``. Every other call, including an
    exact tie, consumes one unit of patience. The call that brings the
    non-improvement counter to ``patience`` returns True; all others return
    False.
    """

    def __init__(
        self,
        patience=5,
        min_delta=0,
        verbose=True,
        checkpoint_path="checkpoint.pt",
        experiment_name="experiment",
    ):
        """
        :param patience: number of consecutive non-improving calls tolerated
        :param min_delta: minimum increase over the best score that counts as
            an improvement
        :param verbose: print progress and stopping messages
        :param checkpoint_path: base directory used by ``save_model``
        :param experiment_name: subdirectory of ``checkpoint_path`` used by
            ``save_model``
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        # The first score must beat 0 (+ min_delta) to register as an improvement.
        self.best_score = 0
        self.best_epoch = None
        self.verbose = verbose
        # NOTE(review): save_model treats this as a directory, despite the
        # file-like default "checkpoint.pt".
        self.checkpoint_path = checkpoint_path
        self.experiment_name = experiment_name

    def __call__(self, validation, model, epoch):
        """Record a validation score and report whether training should stop.

        :param validation: score to maximise (e.g. macro-F1)
        :param model: model being trained (only needed if checkpointing is re-enabled)
        :param epoch: epoch number, used for bookkeeping and logging
        :return: True if patience is exhausted, otherwise False
        """
        if validation > self.best_score + self.min_delta:
            if self.verbose:
                print(
                    f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
                )
            self.best_score = validation
            self.best_epoch = epoch
            self.counter = 0
            # Checkpointing is currently disabled; call self.save_model(model)
            # here to re-enable it.
            return False

        self.counter += 1
        if self.verbose:
            print(
                f"- earlystopping: Validation score did not improve from {self.best_score:.3f} (got {validation:.3f}), current patience: {self.patience - self.counter}"
            )
        if self.counter >= self.patience:
            if self.verbose:
                print(f"- earlystopping: Early stopping at epoch {epoch}")
            return True
        return False

    def save_model(self, model):
        """Save ``model`` under ``checkpoint_path/experiment_name`` via ``save_pretrained``."""
        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
        print(f"- saving model to {_checkpoint_dir}")
        os.makedirs(_checkpoint_dir, exist_ok=True)
        model.save_pretrained(_checkpoint_dir)
|