import os
from collections import defaultdict

import numpy as np
import torch
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize
from torch.optim import AdamW

from evaluation.evaluate import evaluate, log_eval


def _normalize(lX, l2=True):
    """L2-normalize each language's document matrix row-wise (no-op if l2=False)."""
    return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX


def XdotM(X, M, sif):
    """Project the document matrix X onto the embedding matrix M, optionally
    removing the first principal component (SIF)."""
    E = X.dot(M)
    if sif:
        E = remove_pc(E, npc=1)
    return E


def remove_pc(X, npc=1):
    """
    Remove the projection on the principal components
    :param X: X[i,:] is a data point
    :param npc: number of principal components to remove
    :return: XX[i, :] is the data point after removing its projection
    """
    pc = compute_pc(X, npc)
    if npc == 1:
        XX = X - X.dot(pc.transpose()) * pc
    else:
        XX = X - X.dot(pc.transpose()).dot(pc)
    return XX


def compute_pc(X, npc=1):
    """
    Compute the principal components.
    :param X: X[i,:] is a data point
    :param npc: number of principal components to remove
    :return: component_[i,:] is the i-th pc
    """
    if isinstance(X, np.matrix):
        X = np.asarray(X)
    svd = TruncatedSVD(n_components=npc, n_iter=7, random_state=0)
    svd.fit(X)
    return svd.components_


def predict(logits, classification_type="multilabel"):
    """
    Converts soft predictions to hard {0, 1} predictions.
    """
    if classification_type == "multilabel":
        prediction = torch.sigmoid(logits) > 0.5
    elif classification_type == "singlelabel":
        prediction = torch.argmax(logits, dim=1).view(-1, 1)
    else:
        raise ValueError(f"Unknown classification type: {classification_type}")
    return prediction.detach().cpu().numpy()


class TfidfVectorizerMultilingual:
    """Fits one TfidfVectorizer per language; inputs and outputs are dicts keyed by language."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def fit(self, lX, ly=None):
        self.langs = sorted(lX.keys())
        self.vectorizer = {
            l: TfidfVectorizer(**self.kwargs).fit(lX[l]) for l in self.langs
        }
        return self

    def transform(self, lX):
        return {l: self.vectorizer[l].transform(lX[l]) for l in self.langs}

    def fit_transform(self, lX, ly=None):
        return self.fit(lX, ly).transform(lX)

    def vocabulary(self, l=None):
        if l is None:
            return {lang: self.vectorizer[lang].vocabulary_ for lang in self.langs}
        else:
            return self.vectorizer[l].vocabulary_

    def get_analyzer(self, l=None):
        if l is None:
            return {lang: self.vectorizer[lang].build_analyzer() for lang in self.langs}
        else:
            return self.vectorizer[l].build_analyzer()


class Trainer:
    """Minimal training loop with periodic per-language evaluation and early stopping."""

    def __init__(
        self,
        model,
        optimizer_name,
        device,
        loss_fn,
        lr,
        print_steps,
        evaluate_step,
        patience,
        experiment_name,
    ):
        self.device = device
        self.model = model.to(device)
        self.optimizer = self.init_optimizer(optimizer_name, lr)
        self.evaluate_steps = evaluate_step
        self.loss_fn = loss_fn.to(device)
        self.print_steps = print_steps
        self.earlystopping = EarlyStopping(
            patience=patience,
            checkpoint_path="models/vgfs/transformers/",
            verbose=True,
            experiment_name=experiment_name,
        )

    def init_optimizer(self, optimizer_name, lr):
        if optimizer_name.lower() == "adamw":
            return AdamW(self.model.parameters(), lr=lr)
        else:
            raise ValueError(f"Optimizer {optimizer_name} not supported")

    def train(self, train_dataloader, eval_dataloader, epochs=10):
        print(
            f"""- Training params:
            - epochs: {epochs}
            - learning rate: {self.optimizer.defaults['lr']}
            - train batch size: {train_dataloader.batch_size}
            - eval batch size: {eval_dataloader.batch_size}
            - max len: {train_dataloader.dataset.X.shape[-1]}\n""",
        )
        for epoch in range(epochs):
            self.train_epoch(train_dataloader, epoch)
            if (epoch + 1) % self.evaluate_steps == 0:
                metric_watcher = self.evaluate(eval_dataloader)
                stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
                if stop:
                    break
        return self.model

    def train_epoch(self, dataloader, epoch):
        self.model.train()
        for b_idx, (x, y, lang) in enumerate(dataloader):
            self.optimizer.zero_grad()
            y_hat = self.model(x.to(self.device))
            loss = self.loss_fn(y_hat.logits, y.to(self.device))
            loss.backward()
            self.optimizer.step()
            if b_idx % self.print_steps == 0:
                print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
        return self

    def evaluate(self, dataloader):
        self.model.eval()
        lY = defaultdict(list)
        lY_hat = defaultdict(list)
        with torch.no_grad():  # no gradients needed during evaluation
            for b_idx, (x, y, lang) in enumerate(dataloader):
                y_hat = self.model(x.to(self.device))
                loss = self.loss_fn(y_hat.logits, y.to(self.device))
                predictions = predict(y_hat.logits, classification_type="multilabel")
                # group ground truth and predictions by language
                for l, _true, _pred in zip(lang, y, predictions):
                    lY[l].append(_true.detach().cpu().numpy())
                    lY_hat[l].append(_pred)
        for lang in lY:
            lY[lang] = np.vstack(lY[lang])
            lY_hat[lang] = np.vstack(lY_hat[lang])
        l_eval = evaluate(lY, lY_hat)
        average_metrics = log_eval(l_eval, phase="validation")
        return average_metrics[0]  # macro-F1


class EarlyStopping:
    """Stops training when the monitored validation score has not improved for `patience` evaluations."""

    def __init__(
        self,
        patience=5,
        min_delta=0,
        verbose=True,
        checkpoint_path="checkpoint.pt",
        experiment_name="experiment",
    ):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = 0
        self.best_epoch = None
        self.verbose = verbose
        self.checkpoint_path = checkpoint_path
        self.experiment_name = experiment_name

    def __call__(self, validation, model, epoch):
        if validation > self.best_score:
            print(
                f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
            )
            self.best_score = validation
            self.counter = 0
            # self.save_model(model)
        elif validation < (self.best_score + self.min_delta):
            self.counter += 1
            print(
                f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
            )
            if self.counter >= self.patience:
                if self.verbose:
                    print(f"- earlystopping: Early stopping at epoch {epoch}")
                return True
        return False

    def save_model(self, model):
        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
        print(f"- saving model to {_checkpoint_dir}")
        os.makedirs(_checkpoint_dir, exist_ok=True)
        model.save_pretrained(_checkpoint_dir)
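

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the training pipeline).
# It shows how the multilingual TF-IDF vectorizer and the SIF-style projection
# utilities above can be composed, and how EarlyStopping counts patience.
# The toy corpus, the random embedding matrices `lM`, and the hard-coded
# validation scores are assumptions made purely for this example; in practice
# `lM` would hold pre-trained word embeddings aligned with each language's
# TF-IDF vocabulary, and the scores would come from Trainer.evaluate().
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    lX_raw = {
        "en": ["the cat sat on the mat", "dogs chase cats", "a quiet evening"],
        "it": ["il gatto dorme", "i cani inseguono i gatti", "una serata tranquilla"],
    }

    vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True)
    lX_tfidf = vectorizer.fit_transform(lX_raw)

    # Hypothetical embedding matrices: one row per vocabulary term, 50 dimensions.
    rng = np.random.default_rng(0)
    lM = {
        lang: rng.standard_normal((len(vectorizer.vocabulary(lang)), 50))
        for lang in lX_tfidf
    }

    # Project documents into the embedding space, remove the first principal
    # component (sif=True), then L2-normalize the document embeddings.
    lE = _normalize(
        {lang: XdotM(lX_tfidf[lang], lM[lang], sif=True) for lang in lX_tfidf}
    )
    for lang, E in lE.items():
        print(f"{lang}: document embeddings with shape {E.shape}")

    # EarlyStopping with patience=2 stops after two evaluations without improvement.
    stopper = EarlyStopping(patience=2, verbose=True, experiment_name="demo")
    for epoch, macro_f1 in enumerate([0.41, 0.47, 0.46, 0.45, 0.44], start=1):
        if stopper(macro_f1, model=None, epoch=epoch):
            break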