import os
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize
from torch.optim import AdamW
from transformers.modeling_outputs import SequenceClassifierOutput

from evaluation.evaluate import evaluate, log_eval

PRINT_ON_EPOCH = 10


def _normalize(lX, l2=True):
    """L2-normalize each language-specific matrix in the {lang: matrix} dict."""
    return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX


def XdotM(X, M, sif):
    """Project X onto the embedding matrix M, optionally removing the first
    principal component (SIF-style) from the resulting embeddings."""
    E = X.dot(M)
    if sif:
        E = remove_pc(E, npc=1)
    return E


def remove_pc(X, npc=1):
    """
    Remove the projection on the principal components.
    :param X: X[i,:] is a data point
    :param npc: number of principal components to remove
    :return: XX[i,:] is the data point after removing its projection
    """
    pc = compute_pc(X, npc)
    if npc == 1:
        XX = X - X.dot(pc.transpose()) * pc
    else:
        XX = X - X.dot(pc.transpose()).dot(pc)
    return XX


def compute_pc(X, npc=1):
    """
    Compute the principal components.
    :param X: X[i,:] is a data point
    :param npc: number of principal components to compute
    :return: component_[i,:] is the i-th pc
    """
    if isinstance(X, np.matrix):
        X = np.asarray(X)
    svd = TruncatedSVD(n_components=npc, n_iter=7, random_state=0)
    svd.fit(X)
    return svd.components_


def predict(logits, classification_type="multilabel"):
    """
    Converts soft predictions to hard predictions [0, 1].
    """
    if classification_type == "multilabel":
        prediction = torch.sigmoid(logits) > 0.5
    elif classification_type == "singlelabel":
        prediction = torch.argmax(logits, dim=1).view(-1, 1)
    else:
        raise ValueError(f"Unknown classification type: {classification_type}")
    return prediction.detach().cpu().numpy()


class TfidfVectorizerMultilingual:
    """Fits one TfidfVectorizer per language on a {lang: documents} dict."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def fit(self, lX, ly=None):
        self.langs = sorted(lX.keys())
        self.vectorizer = {
            l: TfidfVectorizer(**self.kwargs).fit(lX[l]) for l in self.langs
        }
        return self

    def transform(self, lX):
        return {l: self.vectorizer[l].transform(lX[l]) for l in self.langs}

    def fit_transform(self, lX, ly=None):
        return self.fit(lX, ly).transform(lX)

    def vocabulary(self, l=None):
        if l is None:
            return {l: self.vectorizer[l].vocabulary_ for l in self.langs}
        else:
            return self.vectorizer[l].vocabulary_

    def get_analyzer(self, l=None):
        if l is None:
            return {l: self.vectorizer[l].build_analyzer() for l in self.langs}
        else:
            return self.vectorizer[l].build_analyzer()


class Trainer:
    def __init__(
        self,
        model,
        optimizer_name,
        device,
        loss_fn,
        lr,
        print_steps,
        evaluate_step,
        patience,
        experiment_name,
        checkpoint_path,
    ):
        self.device = device
        self.model = model.to(device)
        self.optimizer = self.init_optimizer(optimizer_name, lr)
        self.evaluate_steps = evaluate_step
        self.loss_fn = loss_fn.to(device)
        self.print_steps = print_steps
        self.experiment_name = experiment_name
        self.patience = patience
        self.earlystopping = EarlyStopping(
            patience=patience,
            checkpoint_path=checkpoint_path,
            verbose=True,
            experiment_name=experiment_name,
        )

    def init_optimizer(self, optimizer_name, lr):
        if optimizer_name.lower() == "adamw":
            return AdamW(self.model.parameters(), lr=lr)
        else:
            raise ValueError(f"Optimizer {optimizer_name} not supported")

    def train(self, train_dataloader, eval_dataloader, epochs=10):
        print(
            f"""- Training params for {self.experiment_name}:
            - epochs: {epochs}
            - learning rate: {self.optimizer.defaults['lr']}
            - train batch size: {train_dataloader.batch_size}
            - eval batch size: {eval_dataloader.batch_size}
            - max len: {train_dataloader.dataset.X.shape[-1]}
            - patience: {self.earlystopping.patience}\n"""
        )

        for epoch in range(epochs):
            self.train_epoch(train_dataloader, epoch)
            if (epoch + 1) % self.evaluate_steps == 0:
                metric_watcher = self.evaluate(eval_dataloader)
                stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
                if stop:
                    print(
                        f"- restoring best model from epoch {self.earlystopping.best_epoch}"
                    )
                    self.model = self.earlystopping.load_model(self.model).to(
                        self.device
                    )
                    break

        # final pass on the evaluation set before saving
        # TODO: maybe a lower lr?
        self.train_epoch(eval_dataloader, epoch=epoch)
        print("\n- last swipe on eval set")
        self.earlystopping.save_model(self.model)
        return self.model

    def train_epoch(self, dataloader, epoch):
        self.model.train()
        for b_idx, (x, y, lang) in enumerate(dataloader):
            self.optimizer.zero_grad()
            y_hat = self.model(x.to(self.device))
            if isinstance(y_hat, SequenceClassifierOutput):
                loss = self.loss_fn(y_hat.logits, y.to(self.device))
            else:
                loss = self.loss_fn(y_hat, y.to(self.device))
            loss.backward()
            self.optimizer.step()
            if (epoch + 1) % PRINT_ON_EPOCH == 0:
                if b_idx % self.print_steps == 0:
                    print(f"Epoch: {epoch + 1} Step: {b_idx + 1} Loss: {loss:.4f}")
        return self

    def evaluate(self, dataloader):
        self.model.eval()
        lY = defaultdict(list)
        lY_hat = defaultdict(list)

        for b_idx, (x, y, lang) in enumerate(dataloader):
            y_hat = self.model(x.to(self.device))
            if isinstance(y_hat, SequenceClassifierOutput):
                loss = self.loss_fn(y_hat.logits, y.to(self.device))
                predictions = predict(y_hat.logits, classification_type="multilabel")
            else:
                loss = self.loss_fn(y_hat, y.to(self.device))
                predictions = predict(y_hat, classification_type="multilabel")

            # group ground truth and predictions by language
            for l, _true, _pred in zip(lang, y, predictions):
                lY[l].append(_true.detach().cpu().numpy())
                lY_hat[l].append(_pred)

        for lang in lY:
            lY[lang] = np.vstack(lY[lang])
            lY_hat[lang] = np.vstack(lY_hat[lang])

        l_eval = evaluate(lY, lY_hat)
        average_metrics = log_eval(l_eval, phase="validation")
        return average_metrics[0]  # macro-F1


class EarlyStopping:
    def __init__(
        self,
        patience,
        checkpoint_path,
        experiment_name,
        min_delta=0,
        verbose=True,
    ):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = 0
        self.best_epoch = None
        self.verbose = verbose
        self.checkpoint_path = checkpoint_path
        self.experiment_name = experiment_name

    def __call__(self, validation, model, epoch):
        if validation > self.best_score:
            print(
                f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
            )
            self.best_score = validation
            self.counter = 0
            self.best_epoch = epoch
            self.save_model(model)
        elif validation < (self.best_score + self.min_delta):
            self.counter += 1
            print(
                f"- earlystopping: Validation score did not improve from {self.best_score:.3f} "
                f"(got {validation:.3f}), current patience: {self.patience - self.counter}"
            )
            if self.counter >= self.patience:
                if self.verbose:
                    print(f"- earlystopping: Early stopping at epoch {epoch}")
                return True
        return False

    def save_model(self, model):
        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
        os.makedirs(_checkpoint_dir, exist_ok=True)
        model.save_pretrained(_checkpoint_dir)

    def load_model(self, model):
        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
        return model.from_pretrained(_checkpoint_dir)


class AttentionModule(nn.Module):
    def __init__(self, embed_dim, num_heads, out_dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.linear = nn.Linear(embed_dim, out_dim)

    def forward(self, X):
        attn_out, attn_weights = self.attn(query=X, key=X, value=X)
        out = self.linear(attn_out)
        return out

    def transform(self, X):
        attn_out, attn_weights = self.attn(query=X, key=X, value=X)
        return attn_out

    def save_pretrained(self, path):
        torch.save(self.state_dict(), f"{path}.pt")

    def from_pretrained(self, path):
        # counterpart to save_pretrained, so EarlyStopping.load_model can
        # restore this module from the "<path>.pt" file written above
        self.load_state_dict(torch.load(f"{path}.pt"))
        return self

    def _wtf(self):
        print("wtf")


class AttentionAggregator:
    """Learns to aggregate horizontally stacked per-language views with
    self-attention, then exposes the attended embeddings via transform()."""

    def __init__(self, embed_dim, out_dim, epochs, num_heads=1, device="cpu"):
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.device = device
        self.epochs = epochs
        self.attn = AttentionModule(embed_dim, num_heads, out_dim).to(self.device)

    def fit(self, X, Y):
        print("- fitting Attention-based aggregating function")
        hstacked_X = self.stack(X)

        dataset = AggregatorDatasetTorch(hstacked_X, Y)
        tra_dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

        experiment_name = "attention_aggregator"
        trainer = Trainer(
            self.attn,
            optimizer_name="adamW",
            lr=1e-3,
            loss_fn=torch.nn.CrossEntropyLoss(),
            print_steps=100,
            evaluate_step=1000,
            patience=10,
            experiment_name=experiment_name,
            device=self.device,
            checkpoint_path="models/aggregator",
        )
        trainer.train(
            train_dataloader=tra_dataloader,
            eval_dataloader=tra_dataloader,
            epochs=self.epochs,
        )
        return self

    def transform(self, X):
        h_stacked = self.stack(X)
        dataset = AggregatorDatasetTorch(h_stacked, lY=None, split="whole")
        dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

        _embeds = []
        l_embeds = defaultdict(list)

        self.attn.eval()
        with torch.no_grad():
            for input_ids, lang in dataloader:
                input_ids = input_ids.to(self.device)
                out = self.attn.transform(input_ids)
                _embeds.append((out.cpu().numpy(), lang))

        # regroup the attended embeddings by language
        for embed, lang in _embeds:
            for sample_embed, sample_lang in zip(embed, lang):
                l_embeds[sample_lang].append(sample_embed)

        l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}
        return l_embeds

    def stack(self, data):
        hstack = self._hstack(data)
        return hstack

    def _hstack(self, data):
        # horizontally concatenate the different views, language by language
        _langs = data[0].keys()
        l_projections = {}
        for l in _langs:
            l_projections[l] = torch.tensor(
                np.hstack([view[l] for view in data]), dtype=torch.float32
            )
        return l_projections

    def _vstack(self, data):
        return torch.vstack(data)


class AggregatorDatasetTorch(Dataset):
    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        self.X = torch.vstack([data for data in self.lX.values()])
        if self.split != "whole":
            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
        # one language label per stacked sample, in the same order as self.X
        self.langs = [
            lang for lang, data in self.lX.items() for _ in range(len(data))
        ]
        return self

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]
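

# Usage sketch (illustrative only): the toy corpora, seeds, and dimensions below
# are invented for demonstration and are not part of the original module. It
# shows the intended flow: per-language TF-IDF vectorisation, SIF-style removal
# of the first principal component, and (commented, with hypothetical shapes)
# attention-based aggregation of several per-language views.
if __name__ == "__main__":
    lX = {
        "en": ["the cat sat on the mat", "dogs and cats are friendly"],
        "it": ["il gatto dorme sul tappeto", "cani e gatti sono amichevoli"],
    }

    # one TF-IDF vectorizer per language
    vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True)
    l_tfidf = vectorizer.fit_transform(lX)
    print({lang: X.shape for lang, X in l_tfidf.items()})

    # SIF-style removal of the first principal component on a toy dense matrix
    toy_embeddings = np.random.default_rng(0).normal(size=(10, 16))
    cleaned = remove_pc(toy_embeddings, npc=1)
    print(cleaned.shape)

    # An AttentionAggregator would then combine several per-language views:
    #   aggregator = AttentionAggregator(embed_dim=32, out_dim=10, epochs=5)
    #   aggregator.fit(views, lY)            # views: list of {lang: np.ndarray}
    #   l_aggregated = aggregator.transform(views)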