import os
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers.modeling_outputs import ModelOutput

from evaluation.evaluate import evaluate, log_eval

PRINT_ON_EPOCH = 1


def _normalize(lX, l2=True):
    """L2-normalize each language-specific matrix (returns lX unchanged if l2 is False)."""
    return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX


def XdotM(X, M, sif):
    """Project X onto the embedding matrix M, optionally removing the first
    principal component (SIF-style post-processing)."""
    E = X.dot(M)
    if sif:
        E = remove_pc(E, npc=1)
    return E


def remove_pc(X, npc=1):
    """
    Remove the projection on the principal components
    :param X: X[i,:] is a data point
    :param npc: number of principal components to remove
    :return: XX[i, :] is the data point after removing its projection
    """
    pc = compute_pc(X, npc)
    if npc == 1:
        XX = X - X.dot(pc.transpose()) * pc
    else:
        XX = X - X.dot(pc.transpose()).dot(pc)
    return XX


def compute_pc(X, npc=1):
    """
    Compute the principal components.
    :param X: X[i,:] is a data point
    :param npc: number of principal components to remove
    :return: component_[i,:] is the i-th pc
    """
    if isinstance(X, np.matrix):
        X = np.asarray(X)
    svd = TruncatedSVD(n_components=npc, n_iter=7, random_state=0)
    svd.fit(X)
    return svd.components_


def predict(logits, classification_type="multilabel"):
    """
    Converts soft predictions to hard predictions [0, 1]
    """
    if classification_type == "multilabel":
        prediction = torch.sigmoid(logits) > 0.5
    elif classification_type == "singlelabel":
        prediction = torch.argmax(logits, dim=1).view(-1, 1)
    else:
        raise ValueError(f"Unknown classification type: {classification_type}")
    return prediction.detach().cpu().numpy()


class TfidfVectorizerMultilingual:
    """Fits one TfidfVectorizer per language on lX[lang]["text"]."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def fit(self, lX, ly=None):
        self.langs = sorted(lX.keys())
        self.vectorizer = {
            l: TfidfVectorizer(**self.kwargs).fit(lX[l]["text"]) for l in self.langs
        }
        return self

    def transform(self, lX):
        return {l: self.vectorizer[l].transform(lX[l]["text"]) for l in self.langs}

    def fit_transform(self, lX, ly=None):
        return self.fit(lX, ly).transform(lX)

    def vocabulary(self, l=None):
        if l is None:
            return {l: self.vectorizer[l].vocabulary_ for l in self.langs}
        else:
            return self.vectorizer[l].vocabulary_

    def get_analyzer(self, l=None):
        if l is None:
            return {l: self.vectorizer[l].build_analyzer() for l in self.langs}
        else:
            return self.vectorizer[l].build_analyzer()
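
# Illustrative sketch of how the multilingual TF-IDF wrapper and the SIF-style
# projection above can be combined. The toy corpus and the embedding matrix `M`
# are hypothetical placeholders, not data from this project.
#
#   vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
#   lX = {
#       "en": {"text": ["a cat sat on the mat", "dogs chase cats"]},
#       "it": {"text": ["un gatto sul tappeto", "i cani inseguono i gatti"]},
#   }
#   l_tfidf = vectorizer.fit_transform(lX)  # {lang: sparse (n_docs, vocab) matrix}
#
#   # M: hypothetical (vocab_size, embed_dim) word-embedding matrix whose rows are
#   # aligned with the language's TF-IDF vocabulary; sif=True removes the first
#   # principal component from the resulting document embeddings.
#   # l_embeds = {lang: XdotM(X, M, sif=True) for lang, X in l_tfidf.items()}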


class Trainer:
    def __init__(
        self,
        model,
        optimizer_name,
        device,
        loss_fn,
        lr,
        print_steps,
        evaluate_step,
        patience,
        experiment_name,
        checkpoint_path,
    ):
        self.device = device
        self.model = model.to(device)
        self.optimizer = self.init_optimizer(optimizer_name, lr)
        self.evaluate_steps = evaluate_step
        self.loss_fn = loss_fn.to(device)
        self.print_steps = print_steps
        self.experiment_name = experiment_name
        self.patience = patience
        self.print_eval = evaluate_step
        self.earlystopping = EarlyStopping(
            patience=patience,
            checkpoint_path=checkpoint_path,
            verbose=False,
            experiment_name=experiment_name,
        )

    def init_optimizer(self, optimizer_name, lr):
        if optimizer_name.lower() == "adamw":
            return AdamW(self.model.parameters(), lr=lr)
        else:
            raise ValueError(f"Optimizer {optimizer_name} not supported")

    def train(self, train_dataloader, eval_dataloader, epochs=10):
        print(
            f"""- Training params for {self.experiment_name}:
            - epochs: {epochs}
            - learning rate: {self.optimizer.defaults['lr']}
            - train batch size: {train_dataloader.batch_size}
            - eval batch size: {eval_dataloader.batch_size}
            - max len: {train_dataloader.dataset.X.shape[-1]}
            - patience: {self.earlystopping.patience}
            - evaluate every: {self.evaluate_steps}
            - print eval every: {self.print_eval}
            - print train steps: {self.print_steps}\n"""
        )
        for epoch in range(epochs):
            self.train_epoch(train_dataloader, epoch)
            if (epoch + 1) % self.evaluate_steps == 0:
                print_eval = (epoch + 1) % self.print_eval == 0
                metric_watcher = self.evaluate(eval_dataloader, print_eval=print_eval)
                stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
                if stop:
                    print(
                        f"- restoring best model from epoch {self.earlystopping.best_epoch} with best metric: {self.earlystopping.best_score:.3f}"
                    )
                    self.model = self.earlystopping.load_model(self.model).to(
                        self.device
                    )
                    break
        # final sweep over the validation data before saving the model
        print("- last sweep over the eval set")
        self.train_epoch(eval_dataloader, epoch=0)
        self.earlystopping.save_model(self.model)
        return self.model

    def train_epoch(self, dataloader, epoch):
        self.model.train()
        for b_idx, (x, y, lang) in enumerate(dataloader):
            self.optimizer.zero_grad()
            y_hat = self.model(x.to(self.device))
            if isinstance(y_hat, ModelOutput):
                loss = self.loss_fn(y_hat.logits, y.to(self.device))
            else:
                loss = self.loss_fn(y_hat, y.to(self.device))
            loss.backward()
            self.optimizer.step()
            if (epoch + 1) % PRINT_ON_EPOCH == 0:
                if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
                    print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
        return self

    def evaluate(self, dataloader, print_eval=True):
        self.model.eval()

        lY = defaultdict(list)
        lY_hat = defaultdict(list)

        for b_idx, (x, y, lang) in enumerate(dataloader):
            y_hat = self.model(x.to(self.device))
            if isinstance(y_hat, ModelOutput):
                loss = self.loss_fn(y_hat.logits, y.to(self.device))
                predictions = predict(y_hat.logits, classification_type="multilabel")
            else:
                loss = self.loss_fn(y_hat, y.to(self.device))
                predictions = predict(y_hat, classification_type="multilabel")

            for l, _true, _pred in zip(lang, y, predictions):
                lY[l].append(_true.detach().cpu().numpy())
                lY_hat[l].append(_pred)

        for lang in lY:
            lY[lang] = np.vstack(lY[lang])
            lY_hat[lang] = np.vstack(lY_hat[lang])

        l_eval = evaluate(lY, lY_hat)
        average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
        return average_metrics[0]  # macro-F1
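
# Wiring sketch for the Trainer above, mirroring how AttentionAggregator.fit uses it
# further down. `model`, `tr_loader`, and `val_loader` are hypothetical placeholders:
# any nn.Module returning logits, plus DataLoaders that yield (x, y, lang) batches and
# whose dataset exposes an `X` tensor (see AggregatorDatasetTorch below). Hyperparameter
# values are illustrative.
#
#   trainer = Trainer(
#       model,
#       optimizer_name="adamW",
#       device="cpu",
#       loss_fn=torch.nn.CrossEntropyLoss(),
#       lr=1e-3,
#       print_steps=25,
#       evaluate_step=10,
#       patience=5,
#       experiment_name="attention_aggregator",
#       checkpoint_path="models/aggregator",
#   )
#   trained_model = trainer.train(tr_loader, val_loader, epochs=100)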


class EarlyStopping:
    def __init__(
        self,
        patience,
        checkpoint_path,
        experiment_name,
        min_delta=0,
        verbose=True,
    ):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = 0
        self.best_epoch = None
        self.verbose = verbose
        self.checkpoint_path = checkpoint_path
        self.experiment_name = experiment_name

    def __call__(self, validation, model, epoch):
        if validation > self.best_score:
            if self.verbose:
                print(
                    f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
                )
            self.best_score = validation
            self.counter = 0
            self.best_epoch = epoch
            # print(f"- earlystopping: Saving best model from epoch {epoch}")
            self.save_model(model)
        elif validation < (self.best_score + self.min_delta):
            self.counter += 1
            if self.verbose:
                print(
                    f"- earlystopping: Validation score did not improve (best: {self.best_score:.3f}, current: {validation:.3f}), current patience: {self.patience - self.counter}"
                )
            if self.counter >= self.patience:
                print(f"- earlystopping: Early stopping at epoch {epoch}")
                return True
        return False

    def save_model(self, model):
        os.makedirs(self.checkpoint_path, exist_ok=True)
        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
        model.save_pretrained(_checkpoint_dir)

    def load_model(self, model):
        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
        return model.from_pretrained(_checkpoint_dir)


class AttentionModule(nn.Module):
    def __init__(self, embed_dim, num_heads, h_dim, out_dim, aggfunc_type):
        """The sigmoid is applied in the evaluation loop (Trainer.evaluate, via
        predict), so it is not applied explicitly here at training time. Outputs
        should, however, be squashed through the sigmoid explicitly at inference
        time (self.transform).
        """
        super().__init__()
        self.aggfunc = aggfunc_type
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.1)
        # self.layer_norm = nn.LayerNorm(embed_dim)
        if self.aggfunc == "concat":
            self.linear = nn.Linear(embed_dim, out_dim)
        self.sigmoid = nn.Sigmoid()

    def init_weights(self, mode="mean"):
        # TODO: add init function of the attention module: either all weights are positive or set to 1/num_classes
        raise NotImplementedError

    def __call__(self, X):
        out, attn_weights = self.attn(query=X, key=X, value=X)
        # out = self.layer_norm(out)
        if self.aggfunc == "concat":
            out = self.linear(out)
        # out = self.sigmoid(out)
        return out

    def transform(self, X):
        """Explicitly apply the sigmoid at inference time."""
        out, attn_weights = self.attn(query=X, key=X, value=X)
        out = self.sigmoid(out)
        return out

    def save_pretrained(self, path):
        torch.save(self, f"{path}.pt")

    def from_pretrained(self, path):
        return torch.load(f"{path}.pt")
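
# Checkpointing sketch: EarlyStopping.save_model() creates `checkpoint_path` and calls
# model.save_pretrained() with <checkpoint_path>/<experiment_name>, so an AttentionModule
# checkpoint is a single torch.save'd file with a ".pt" suffix. Dimensions and the path
# below are illustrative; the directory must already exist when calling save_pretrained
# directly.
#
#   module = AttentionModule(embed_dim=10, num_heads=1, h_dim=512, out_dim=10,
#                            aggfunc_type="mean")
#   module.save_pretrained("models/aggregator/attention_aggregator")
#   # -> writes models/aggregator/attention_aggregator.pt
#   restored = module.from_pretrained("models/aggregator/attention_aggregator")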


class AttentionAggregator:
    def __init__(
        self,
        embed_dim,
        out_dim,
        epochs,
        lr,
        patience,
        attn_stacking_type,
        h_dim=512,
        num_heads=1,
        device="cpu",
    ):
        self.embed_dim = embed_dim
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.patience = patience
        self.num_heads = num_heads
        self.device = device
        self.epochs = epochs
        self.lr = lr
        self.stacking_type = attn_stacking_type
        self.tr_batch_size = 512
        self.eval_batch_size = 1024
        self.attn = AttentionModule(
            self.embed_dim,
            self.num_heads,
            self.h_dim,
            self.out_dim,
            aggfunc_type=self.stacking_type,
        ).to(self.device)

    def fit(self, X, Y):
        print("- fitting Attention-based aggregating function")
        hstacked_X = self.stack(X)

        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
            hstacked_X, Y, split=0.2, seed=42
        )

        tra_dataloader = DataLoader(
            AggregatorDatasetTorch(tr_lX, tr_lY, split="train"),
            batch_size=self.tr_batch_size,
            shuffle=True,
        )
        eval_dataloader = DataLoader(
            AggregatorDatasetTorch(val_lX, val_lY, split="eval"),
            batch_size=self.eval_batch_size,
            shuffle=False,
        )

        experiment_name = "attention_aggregator"
        trainer = Trainer(
            self.attn,
            optimizer_name="adamW",
            lr=self.lr,
            loss_fn=torch.nn.CrossEntropyLoss(),
            print_steps=25,
            evaluate_step=10,
            patience=self.patience,
            experiment_name=experiment_name,
            device=self.device,
            checkpoint_path="models/aggregator",
        )
        trainer.train(
            train_dataloader=tra_dataloader,
            eval_dataloader=eval_dataloader,
            epochs=self.epochs,
        )
        return self

    def transform(self, X):
        hstacked_X = self.stack(X)
        dataset = AggregatorDatasetTorch(hstacked_X, lY=None, split="whole")
        dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

        _embeds = []
        l_embeds = defaultdict(list)

        self.attn.eval()
        with torch.no_grad():
            for input_ids, lang in dataloader:
                input_ids = input_ids.to(self.device)
                out = self.attn.transform(input_ids)
                _embeds.append((out.cpu().numpy(), lang))

        # regroup the batched outputs by language
        for embed, lang in _embeds:
            for sample_embed, sample_lang in zip(embed, lang):
                l_embeds[sample_lang].append(sample_embed)

        l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}
        return l_embeds

    def stack(self, data):
        if self.stacking_type == "concat":
            hstack = self._concat_stack(data)
        elif self.stacking_type == "mean":
            hstack = self._mean_stack(data)
        else:
            raise ValueError(f"Unknown stacking type: {self.stacking_type}")
        return hstack

    def _concat_stack(self, data):
        _langs = data[0].keys()
        l_projections = {}
        for l in _langs:
            l_projections[l] = torch.tensor(
                np.hstack([view[l] for view in data]), dtype=torch.float32
            )
        return l_projections

    def _mean_stack(self, data):
        # TODO: double check this mess
        aggregated = {lang: torch.zeros(d.shape) for lang, d in data[0].items()}
        for lang_projections in data:
            for lang, projection in lang_projections.items():
                aggregated[lang] += projection
        for lang, projection in aggregated.items():
            aggregated[lang] = (aggregated[lang] / len(data)).float()
        return aggregated

    def get_train_val_data(self, lX, lY, split=0.2, seed=42):
        tr_lX, tr_lY, val_lX, val_lY = {}, {}, {}, {}

        for lang in lX.keys():
            tr_X, val_X, tr_Y, val_Y = train_test_split(
                lX[lang], lY[lang], test_size=split, random_state=seed, shuffle=False
            )
            tr_lX[lang] = tr_X
            tr_lY[lang] = tr_Y
            val_lX[lang] = val_X
            val_lY[lang] = val_Y

        return tr_lX, tr_lY, val_lX, val_lY


class AggregatorDatasetTorch(Dataset):
    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        self.X = torch.vstack([data for data in self.lX.values()])
        if self.split != "whole":
            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
        # flat list of language tags, one per stacked sample, in insertion order
        self.langs = sum(
            [
                v
                for v in {
                    lang: [lang] * len(data) for lang, data in self.lX.items()
                }.values()
            ],
            [],
        )
        return self

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]
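
# End-to-end sketch for the aggregator: `views` is a hypothetical list of per-language
# representation matrices produced by different first-tier views, and `lY` holds the
# corresponding multilabel targets. With "mean" stacking no linear layer is applied, so
# the views (and hence embed_dim) must share the target dimensionality; all shapes and
# label counts below are illustrative.
#
#   n_classes = 10
#   views = [
#       {"en": np.random.rand(100, n_classes), "it": np.random.rand(80, n_classes)},  # view 1
#       {"en": np.random.rand(100, n_classes), "it": np.random.rand(80, n_classes)},  # view 2
#   ]
#   lY = {
#       "en": np.random.randint(0, 2, (100, n_classes)),
#       "it": np.random.randint(0, 2, (80, n_classes)),
#   }
#
#   aggregator = AttentionAggregator(
#       embed_dim=n_classes, out_dim=n_classes, epochs=50, lr=1e-3,
#       patience=5, attn_stacking_type="mean",
#   )
#   aggregator.fit(views, lY)
#   l_agg = aggregator.transform(views)  # {lang: aggregated document representations}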