import os
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers.modeling_outputs import SequenceClassifierOutput

from evaluation.evaluate import evaluate, log_eval

# Loss is printed only on epochs that are a multiple of this value.
PRINT_ON_EPOCH = 10


def _normalize(lX, l2=True):
    """L2-normalize each language's embedding matrix in a {lang: matrix} dict.

    When l2 is False the input dict is returned untouched.
    """
    return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX


def XdotM(X, M, sif):
    """Project X onto M (X @ M); optionally remove the first principal
    component afterwards (SIF-style post-processing)."""
    E = X.dot(M)
    if sif:
        E = remove_pc(E, npc=1)
    return E


def remove_pc(X, npc=1):
    """
    Remove the projection on the principal components
    :param X: X[i,:] is a data point
    :param npc: number of principal components to remove
    :return: XX[i, :] is the data point after removing its projection
    """
    pc = compute_pc(X, npc)
    if npc == 1:
        # Single component: element-wise scaling is cheaper than a matmul.
        XX = X - X.dot(pc.transpose()) * pc
    else:
        XX = X - X.dot(pc.transpose()).dot(pc)
    return XX


def compute_pc(X, npc=1):
    """
    Compute the principal components.
    :param X: X[i,:] is a data point
    :param npc: number of principal components to remove
    :return: component_[i,:] is the i-th pc
    """
    if isinstance(X, np.matrix):
        # TruncatedSVD expects a plain ndarray, not the deprecated np.matrix.
        X = np.asarray(X)
    svd = TruncatedSVD(n_components=npc, n_iter=7, random_state=0)
    svd.fit(X)
    return svd.components_


def predict(logits, classification_type="multilabel"):
    """
    Converts soft precictions to hard predictions [0,1]

    :param logits: raw model outputs (torch tensor)
    :param classification_type: "multilabel" (sigmoid > 0.5) or
        "singlelabel" (argmax over dim 1)
    :return: numpy array of hard predictions
    :raises ValueError: on an unrecognized classification_type
    """
    if classification_type == "multilabel":
        prediction = torch.sigmoid(logits) > 0.5
    elif classification_type == "singlelabel":
        prediction = torch.argmax(logits, dim=1).view(-1, 1)
    else:
        # Fix: previously this branch only printed a message and then crashed
        # with UnboundLocalError on `prediction`; fail loudly instead.
        raise ValueError(f"unknown classification type: {classification_type}")
    return prediction.detach().cpu().numpy()


class TfidfVectorizerMultilingual:
    """One independent TfidfVectorizer per language, fit on {lang: docs} dicts."""

    def __init__(self, **kwargs):
        # kwargs are forwarded verbatim to every per-language TfidfVectorizer.
        self.kwargs = kwargs

    def fit(self, lX, ly=None):
        self.langs = sorted(lX.keys())
        self.vectorizer = {
            l: TfidfVectorizer(**self.kwargs).fit(lX[l]) for l in self.langs
        }
        return self

    def transform(self, lX):
        return {l: self.vectorizer[l].transform(lX[l]) for l in self.langs}

    def fit_transform(self, lX, ly=None):
        return self.fit(lX, ly).transform(lX)

    def vocabulary(self, l=None):
        """Vocabulary of one language, or of all languages when l is None."""
        if l is None:
            return {l: self.vectorizer[l].vocabulary_ for l in self.langs}
        else:
            return self.vectorizer[l].vocabulary_

    def get_analyzer(self, l=None):
        """Analyzer callable of one language, or of all when l is None."""
        if l is None:
            return {l: self.vectorizer[l].build_analyzer() for l in self.langs}
        else:
            return self.vectorizer[l].build_analyzer()


class Trainer:
    """Generic training loop with periodic evaluation and early stopping."""

    def __init__(
        self,
        model,
        optimizer_name,
        device,
        loss_fn,
        lr,
        print_steps,
        evaluate_step,
        patience,
        experiment_name,
        checkpoint_path,
    ):
        self.device = device
        self.model = model.to(device)
        self.optimizer = self.init_optimizer(optimizer_name, lr)
        self.evaluate_steps = evaluate_step
        self.loss_fn = loss_fn.to(device)
        self.print_steps = print_steps
        self.experiment_name = experiment_name
        self.patience = patience
        self.earlystopping = EarlyStopping(
            patience=patience,
            checkpoint_path=checkpoint_path,
            verbose=True,
            experiment_name=experiment_name,
        )

    def init_optimizer(self, optimizer_name, lr):
        """Build the optimizer; only AdamW is supported (name is case-insensitive)."""
        if optimizer_name.lower() == "adamw":
            return AdamW(self.model.parameters(), lr=lr)
        else:
            raise ValueError(f"Optimizer {optimizer_name} not supported")

    def train(self, train_dataloader, eval_dataloader, epochs=10):
        """Train for up to `epochs` epochs, evaluating every `evaluate_steps`
        epochs and restoring the best checkpoint if early stopping triggers.

        NOTE(review): train_dataloader.dataset is assumed to expose `.X`
        (as AggregatorDatasetTorch does) — confirm for other dataset types.
        """
        print(
            f"""- Training params for {self.experiment_name}:
- epochs: {epochs}
- learning rate: {self.optimizer.defaults['lr']}
- train batch size: {train_dataloader.batch_size}
- eval batch size: {eval_dataloader.batch_size}
- max len: {train_dataloader.dataset.X.shape[-1]}
- patience: {self.earlystopping.patience}\n"""
        )
        for epoch in range(epochs):
            self.train_epoch(train_dataloader, epoch)
            if (epoch + 1) % self.evaluate_steps == 0:
                metric_watcher = self.evaluate(eval_dataloader)
                stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
                if stop:
                    print(
                        f"- restoring best model from epoch {self.earlystopping.best_epoch}"
                    )
                    self.model = self.earlystopping.load_model(self.model).to(
                        self.device
                    )
                    break
        # Deliberate final pass: fine-tune the restored model on the eval set
        # once before saving.
        self.train_epoch(eval_dataloader, epoch=epoch)
        print(f"\n- last swipe on eval set")
        self.earlystopping.save_model(self.model)
        return self.model

    def train_epoch(self, dataloader, epoch):
        """One optimization pass over `dataloader`; logs loss on print epochs."""
        self.model.train()
        for b_idx, (x, y, lang) in enumerate(dataloader):
            self.optimizer.zero_grad()
            y_hat = self.model(x.to(self.device))
            # HF models return a SequenceClassifierOutput wrapper; plain
            # modules return the logits tensor directly.
            if isinstance(y_hat, SequenceClassifierOutput):
                loss = self.loss_fn(y_hat.logits, y.to(self.device))
            else:
                loss = self.loss_fn(y_hat, y.to(self.device))
            loss.backward()
            self.optimizer.step()
            if (epoch + 1) % PRINT_ON_EPOCH == 0:
                if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
                    print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
        return self

    def evaluate(self, dataloader):
        """Evaluate on `dataloader`, grouping predictions per language.

        :return: macro-F1 averaged over languages (first element of
            log_eval's averaged metrics) — used as the early-stopping watch.
        """
        self.model.eval()
        lY = defaultdict(list)
        lY_hat = defaultdict(list)
        for b_idx, (x, y, lang) in enumerate(dataloader):
            y_hat = self.model(x.to(self.device))
            if isinstance(y_hat, SequenceClassifierOutput):
                loss = self.loss_fn(y_hat.logits, y.to(self.device))
                predictions = predict(y_hat.logits, classification_type="multilabel")
            else:
                loss = self.loss_fn(y_hat, y.to(self.device))
                predictions = predict(y_hat, classification_type="multilabel")
            for l, _true, _pred in zip(lang, y, predictions):
                lY[l].append(_true.detach().cpu().numpy())
                lY_hat[l].append(_pred)
        for lang in lY:
            lY[lang] = np.vstack(lY[lang])
            lY_hat[lang] = np.vstack(lY_hat[lang])
        l_eval = evaluate(lY, lY_hat)
        average_metrics = log_eval(l_eval, phase="validation")
        return average_metrics[0]  # macro-F1


class EarlyStopping:
    """Track a validation score; stop after `patience` non-improving checks.

    Saves a checkpoint whenever the score improves and can reload it later.
    """

    def __init__(
        self,
        patience,
        checkpoint_path,
        experiment_name,
        min_delta=0,
        verbose=True,
    ):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0           # consecutive non-improving evaluations
        self.best_score = 0        # higher is better (macro-F1)
        self.best_epoch = None
        self.verbose = verbose
        self.checkpoint_path = checkpoint_path
        self.experiment_name = experiment_name

    def __call__(self, validation, model, epoch):
        """Record `validation`; return True when training should stop."""
        if validation > self.best_score:
            print(
                f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
            )
            self.best_score = validation
            self.counter = 0
            self.best_epoch = epoch
            self.save_model(model)
        elif validation < (self.best_score + self.min_delta):
            self.counter += 1
            print(
                f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
            )
            if self.counter >= self.patience:
                if self.verbose:
                    print(f"- earlystopping: Early stopping at epoch {epoch}")
                return True
        return False

    def save_model(self, model):
        # Model must implement save_pretrained (HF models and AttentionModule do).
        os.makedirs(self.checkpoint_path, exist_ok=True)
        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
        model.save_pretrained(_checkpoint_dir)

    def load_model(self, model):
        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
        return model.from_pretrained(_checkpoint_dir)


class AttentionModule(nn.Module):
    """Self-attention over stacked embeddings followed by a linear projection."""

    def __init__(self, embed_dim, num_heads, h_dim, out_dim):
        super().__init__()
        # h_dim is accepted for interface compatibility but currently unused.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.1)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.linear = nn.Linear(embed_dim, out_dim)

    def forward(self, X):
        # Fix: was defined as __call__, which bypassed nn.Module's hook
        # machinery; nn.Module.__call__ now dispatches here as intended.
        out, attn_weights = self.attn(query=X, key=X, value=X)
        out = self.layer_norm(out)
        out = self.linear(out)
        return out

    def transform(self, X):
        return self.__call__(X)

    def save_pretrained(self, path):
        # Saves the whole module object (not just the state_dict).
        torch.save(self, f"{path}.pt")

    def from_pretrained(self, path):
        # NOTE: torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        return torch.load(f"{path}.pt")


class AttentionAggregator:
    """Learn an attention-based function that aggregates multiple per-language
    embedding views into a single representation."""

    def __init__(
        self,
        embed_dim,
        out_dim,
        epochs,
        lr,
        patience,
        attn_stacking_type,
        h_dim=512,
        num_heads=1,
        device="cpu",
    ):
        self.embed_dim = embed_dim
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.patience = patience
        self.num_heads = num_heads
        self.device = device
        self.epochs = epochs
        self.lr = lr
        # "concat" (hstack views) or "mean" (average views).
        self.stacking_type = attn_stacking_type
        self.tr_batch_size = 512
        self.eval_batch_size = 1024
        self.attn = AttentionModule(
            self.embed_dim, self.num_heads, self.h_dim, self.out_dim
        ).to(self.device)

    def fit(self, X, Y):
        """Train the attention module on stacked views X against labels Y.

        :param X: list of {lang: matrix} views to be stacked
        :param Y: {lang: labels} dict aligned with X
        """
        print("- fitting Attention-based aggregating function")
        hstacked_X = self.stack(X)
        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
            hstacked_X, Y, split=0.2, seed=42
        )
        tra_dataloader = DataLoader(
            AggregatorDatasetTorch(tr_lX, tr_lY, split="train"),
            batch_size=self.tr_batch_size,
            shuffle=True,
        )
        eval_dataloader = DataLoader(
            AggregatorDatasetTorch(val_lX, val_lY, split="eval"),
            batch_size=self.eval_batch_size,
            shuffle=False,
        )
        experiment_name = "attention_aggregator"
        trainer = Trainer(
            self.attn,
            optimizer_name="adamW",
            lr=self.lr,
            loss_fn=torch.nn.CrossEntropyLoss(),
            print_steps=25,
            evaluate_step=50,
            patience=self.patience,
            experiment_name=experiment_name,
            device=self.device,
            checkpoint_path="models/aggregator",
        )
        trainer.train(
            train_dataloader=tra_dataloader,
            eval_dataloader=eval_dataloader,
            epochs=self.epochs,
        )
        return self

    def transform(self, X):
        """Apply the trained attention module; returns {lang: ndarray}."""
        hstacked_X = self.stack(X)
        dataset = AggregatorDatasetTorch(hstacked_X, lY=None, split="whole")
        dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
        _embeds = []
        l_embeds = defaultdict(list)
        self.attn.eval()
        with torch.no_grad():
            for input_ids, lang in dataloader:
                input_ids = input_ids.to(self.device)
                out = self.attn.transform(input_ids)
                _embeds.append((out.cpu().numpy(), lang))
        # Regroup batched outputs back into per-language collections.
        for embed, lang in _embeds:
            for sample_embed, sample_lang in zip(embed, lang):
                l_embeds[sample_lang].append(sample_embed)
        l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}
        return l_embeds

    def stack(self, data):
        """Combine the list of views according to `self.stacking_type`.

        :raises ValueError: on an unrecognized stacking type (fix: previously
            fell through to an UnboundLocalError).
        """
        if self.stacking_type == "concat":
            hstack = self._concat_stack(data)
        elif self.stacking_type == "mean":
            hstack = self._mean_stack(data)
        else:
            raise ValueError(f"unknown stacking type: {self.stacking_type}")
        return hstack

    def _concat_stack(self, data):
        """Horizontally concatenate every view per language -> float32 tensors."""
        _langs = data[0].keys()
        l_projections = {}
        for l in _langs:
            l_projections[l] = torch.tensor(
                np.hstack([view[l] for view in data]), dtype=torch.float32
            )
        return l_projections

    def _mean_stack(self, data):
        # TODO: double check this mess
        # Assumes every view has the same shape per language.
        aggregated = {lang: torch.zeros(d.shape) for lang, d in data[0].items()}
        for lang_projections in data:
            for lang, projection in lang_projections.items():
                aggregated[lang] += projection
        for lang, projection in aggregated.items():
            aggregated[lang] = (aggregated[lang] / len(data)).float()
        return aggregated

    def get_train_val_data(self, lX, lY, split=0.2, seed=42):
        """Per-language train/validation split (no shuffling, fixed seed)."""
        tr_lX, tr_lY, val_lX, val_lY = {}, {}, {}, {}
        for lang in lX.keys():
            tr_X, val_X, tr_Y, val_Y = train_test_split(
                lX[lang], lY[lang], test_size=split, random_state=seed, shuffle=False
            )
            tr_lX[lang] = tr_X
            tr_lY[lang] = tr_Y
            val_lX[lang] = val_X
            val_lY[lang] = val_Y
        return tr_lX, tr_lY, val_lX, val_lY


class AggregatorDatasetTorch(Dataset):
    """Flattens {lang: tensor} dicts into one dataset, remembering each
    sample's language. In split="whole" mode no labels are carried."""

    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        # Stacking relies on dict iteration order being consistent between
        # lX and lY (guaranteed insertion order in Python 3.7+).
        self.X = torch.vstack([data for data in self.lX.values()])
        if self.split != "whole":
            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
        # One language tag per sample, in the same order as the stacked X.
        self.langs = sum(
            [
                v
                for v in {
                    lang: [lang] * len(data) for lang, data in self.lX.items()
                }.values()
            ],
            [],
        )
        return self

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]