diff --git a/gfun/generalizedFunnelling.py b/gfun/generalizedFunnelling.py
index ad8cff5..1fffc89 100644
--- a/gfun/generalizedFunnelling.py
+++ b/gfun/generalizedFunnelling.py
@@ -6,7 +6,7 @@ sys.path.append(os.path.join(os.getcwd(), "gfun"))
 
 import pickle
 import numpy as np
-from vgfs.commons import TfidfVectorizerMultilingual
+from vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
 from vgfs.learners.svms import MetaClassifier, get_learner
 from vgfs.multilingualGen import MultilingualGen
 from gfun.vgfs.textualTransformerGen import TextualTransformerGen
@@ -44,6 +44,7 @@ class GeneralizedFunnelling:
         self.multilingual_vgf = multilingual
         self.trasformer_vgf = transformer
         self.probabilistic = probabilistic
+        self.num_labels = 73  # TODO: hard-coded
         # ------------------------
         self.langs = langs
         self.embed_dir = embed_dir
@@ -81,7 +82,6 @@ class GeneralizedFunnelling:
                 self.load_trained
             )
             # TODO: config like aggfunc, device, n_jobs, etc
-            return self
 
         if self.posteriors_vgf:
             fun = VanillaFunGen(
@@ -121,6 +121,15 @@ class GeneralizedFunnelling:
             )
             self.first_tier_learners.append(transformer_vgf)
 
+        if self.aggfunc == "attn":
+            self.attn_aggregator = AttentionAggregator(
+                embed_dim=self.get_attn_agg_dim(),
+                out_dim=self.num_labels,
+                num_heads=1,
+                device=self.device,
+                epochs=self.epochs,
+            )
+
         self.metaclassifier = MetaClassifier(
             meta_learner=get_learner(calibrate=True, kernel="rbf"),
             meta_parameters=get_params(self.optimc),
@@ -160,7 +169,7 @@ class GeneralizedFunnelling:
             l_posteriors = vgf.fit_transform(lX, lY)
             projections.append(l_posteriors)
 
-        agg = self.aggregate(projections)
+        agg = self.aggregate(projections, lY)
         self.metaclassifier.fit(agg, lY)
         return self
 
@@ -177,15 +186,27 @@ class GeneralizedFunnelling:
     def fit_transform(self, lX, lY):
         return self.fit(lX, lY).transform(lX)
 
-    def aggregate(self, first_tier_projections):
+    def aggregate(self, first_tier_projections, lY=None):
         if self.aggfunc == "mean":
             aggregated = self._aggregate_mean(first_tier_projections)
         elif self.aggfunc == "concat":
             aggregated = self._aggregate_concat(first_tier_projections)
+        elif self.aggfunc == "attn":
+            aggregated = self._aggregate_attn(first_tier_projections, lY)
         else:
             raise NotImplementedError
         return aggregated
 
+    def _aggregate_attn(self, first_tier_projections, lY=None):
+        if lY is None:
+            # at prediction time
+            aggregated = self.attn_aggregator.transform(first_tier_projections)
+        else:
+            # at training time we must fit the attention layer
+            self.attn_aggregator.fit(first_tier_projections, lY)
+            aggregated = self.attn_aggregator.transform(first_tier_projections)
+        return aggregated
+
     def _aggregate_concat(self, first_tier_projections):
         aggregated = {}
         for lang in self.langs:
@@ -201,7 +222,6 @@ class GeneralizedFunnelling:
             for lang, projection in lang_projections.items():
                 aggregated[lang] += projection
 
-        # Computing mean
         for lang, projection in aggregated.items():
             aggregated[lang] /= len(first_tier_projections)
 
@@ -281,6 +301,11 @@ class GeneralizedFunnelling:
             vectorizer = pickle.load(f)
         return first_tier_learners, metaclassifier, vectorizer
 
+    def get_attn_agg_dim(self):
+        # TODO: hardcoded for now
+        print("\n[NB: ATTN AGGREGATOR DIM HARD-CODED TO 146]\n")
+        return 146
+
 
 def get_params(optimc=False):
     if not optimc:
diff --git a/gfun/vgfs/commons.py b/gfun/vgfs/commons.py
index 4cdf04e..28752fc 100644
--- a/gfun/vgfs/commons.py
+++ b/gfun/vgfs/commons.py
@@ -3,13 +3,18 @@ from collections import defaultdict
 
 import numpy as np
 import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader, Dataset
 from sklearn.decomposition import TruncatedSVD
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.preprocessing import normalize
 from torch.optim import AdamW
+from transformers.modeling_outputs import SequenceClassifierOutput
 
 from evaluation.evaluate import evaluate, log_eval
 
+PRINT_ON_EPOCH = 10
+
 
 def _normalize(lX, l2=True):
     return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX
@@ -107,6 +112,7 @@ class Trainer:
         evaluate_step,
         patience,
         experiment_name,
+        checkpoint_path,
     ):
         self.device = device
         self.model = model.to(device)
@@ -118,7 +124,7 @@ class Trainer:
         self.patience = patience
         self.earlystopping = EarlyStopping(
             patience=patience,
-            checkpoint_path="models/vgfs/transformer/",
+            checkpoint_path=checkpoint_path,
             verbose=True,
             experiment_name=experiment_name,
         )
@@ -163,11 +169,15 @@ class Trainer:
         for b_idx, (x, y, lang) in enumerate(dataloader):
             self.optimizer.zero_grad()
             y_hat = self.model(x.to(self.device))
-            loss = self.loss_fn(y_hat.logits, y.to(self.device))
+            if isinstance(y_hat, SequenceClassifierOutput):
+                loss = self.loss_fn(y_hat.logits, y.to(self.device))
+            else:
+                loss = self.loss_fn(y_hat, y.to(self.device))
             loss.backward()
             self.optimizer.step()
-            if b_idx % self.print_steps == 0:
-                print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
+            if (epoch + 1) % PRINT_ON_EPOCH == 0:
+                if b_idx % self.print_steps == 0:
+                    print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
         return self
 
     def evaluate(self, dataloader):
@@ -178,8 +188,12 @@ class Trainer:
 
         for b_idx, (x, y, lang) in enumerate(dataloader):
             y_hat = self.model(x.to(self.device))
-            loss = self.loss_fn(y_hat.logits, y.to(self.device))
-            predictions = predict(y_hat.logits, classification_type="multilabel")
+            if isinstance(y_hat, SequenceClassifierOutput):
+                loss = self.loss_fn(y_hat.logits, y.to(self.device))
+                predictions = predict(y_hat.logits, classification_type="multilabel")
+            else:
+                loss = self.loss_fn(y_hat, y.to(self.device))
+                predictions = predict(y_hat, classification_type="multilabel")
 
             for l, _true, _pred in zip(lang, y, predictions):
                 lY[l].append(_true.detach().cpu().numpy())
@@ -240,3 +254,135 @@ class EarlyStopping:
     def load_model(self, model):
         _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
         return model.from_pretrained(_checkpoint_dir)
+
+
+class AttentionModule(nn.Module):
+    def __init__(self, embed_dim, num_heads, out_dim):
+        super().__init__()
+        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
+        self.linear = nn.Linear(embed_dim, out_dim)
+
+    def forward(self, X):
+        attn_out, attn_weights = self.attn(query=X, key=X, value=X)
+        out = self.linear(attn_out)
+        return out
+
+    def transform(self, X):
+        attn_out, attn_weights = self.attn(query=X, key=X, value=X)
+        return attn_out
+
+    def save_pretrained(self, path):
+        torch.save(self.state_dict(), f"{path}.pt")
+
+    def _wtf(self):
+        print("wtf")
+
+
+class AttentionAggregator:
+    def __init__(self, embed_dim, out_dim, epochs, num_heads=1, device="cpu"):
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.device = device
+        self.epochs = epochs
+        self.attn = AttentionModule(embed_dim, num_heads, out_dim).to(self.device)
+
+    def fit(self, X, Y):
+        print("- fitting Attention-based aggregating function")
+        hstacked_X = self.stack(X)
+
+        dataset = AggregatorDatasetTorch(hstacked_X, Y)
+        tra_dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
+
+        experiment_name = "attention_aggregator"
+        trainer = Trainer(
+            self.attn,
+            optimizer_name="adamW",
+            lr=1e-3,
+            loss_fn=torch.nn.CrossEntropyLoss(),
+            print_steps=100,
+            evaluate_step=1000,
+            patience=10,
+            experiment_name=experiment_name,
+            device=self.device,
+            checkpoint_path="models/aggregator",
+        )
+
+        trainer.train(
+            train_dataloader=tra_dataloader,
+            eval_dataloader=tra_dataloader,
+            epochs=self.epochs,
+        )
+        return self
+
+    def transform(self, X):
+        # stack the per-language posteriors and project them through the attention layer
+        h_stacked = self.stack(X)
+        dataset = AggregatorDatasetTorch(h_stacked, lY=None, split="whole")
+        dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
+
+        _embeds = []
+        l_embeds = defaultdict(list)
+
+        self.attn.eval()
+        with torch.no_grad():
+            for input_ids, lang in dataloader:
+                input_ids = input_ids.to(self.device)
+                out = self.attn.transform(input_ids)
+                _embeds.append((out.cpu().numpy(), lang))
+
+        for embed, lang in _embeds:
+            for sample_embed, sample_lang in zip(embed, lang):
+                l_embeds[sample_lang].append(sample_embed)
+
+        l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}
+
+        return l_embeds
+
+    def stack(self, data):
+        hstack = self._hstack(data)
+        return hstack
+
+    def _hstack(self, data):
+        _langs = data[0].keys()
+        l_projections = {}
+        for l in _langs:
+            l_projections[l] = torch.tensor(
+                np.hstack([view[l] for view in data]), dtype=torch.float32
+            )
+        return l_projections
+
+    def _vstack(self, data):
+        return torch.vstack(list(data.values()))
+
+
+class AggregatorDatasetTorch(Dataset):
+    def __init__(self, lX, lY, split="train"):
+        self.lX = lX
+        self.lY = lY
+        self.split = split
+        self.langs = []
+        self.init()
+
+    def init(self):
+        self.X = torch.vstack([data for data in self.lX.values()])
+        if self.split != "whole":
+            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
+        self.langs = sum(
+            [
+                v
+                for v in {
+                    lang: [lang] * len(data) for lang, data in self.lX.items()
+                }.values()
+            ],
+            [],
+        )
+
+        return self
+
+    def __len__(self):
+        return len(self.X)
+
+    def __getitem__(self, index):
+        if self.split == "whole":
+            return self.X[index], self.langs[index]
+        return self.X[index], self.Y[index], self.langs[index]
diff --git a/gfun/vgfs/learners/svms.py b/gfun/vgfs/learners/svms.py
index 93a1931..086e3ff 100644
--- a/gfun/vgfs/learners/svms.py
+++ b/gfun/vgfs/learners/svms.py
@@ -241,14 +241,6 @@ class MetaClassifier:
         else:
             return Z
 
-    # def stack(self, lZ, lY=None):
-    #     X_stacked = np.vstack(list(lZ.values()))
-    #     if lY is not None:
-    #         Y_stacked = np.vstack(list(lY.values()))
-    #         return X_stacked, Y_stacked
-    #     else:
-    #         return X_stacked
-
     def predict(self, lZ):
         lZ = _joblib_transform_multiling(
             self.standardizer.transform, lZ, n_jobs=self.n_jobs
diff --git a/gfun/vgfs/textualTransformerGen.py b/gfun/vgfs/textualTransformerGen.py
index 3ee321d..9d86b40 100644
--- a/gfun/vgfs/textualTransformerGen.py
+++ b/gfun/vgfs/textualTransformerGen.py
@@ -140,6 +140,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             evaluate_step=self.evaluate_step,
             patience=self.patience,
             experiment_name=experiment_name,
+            checkpoint_path="models/vgfs/transformer",
         )
         trainer.train(
             train_dataloader=tra_dataloader,
diff --git a/gfun/vgfs/visualTransformerGen.py b/gfun/vgfs/visualTransformerGen.py
index 2447589..8025be9 100644
--- a/gfun/vgfs/visualTransformerGen.py
+++ b/gfun/vgfs/visualTransformerGen.py
@@ -124,6 +124,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
             evaluate_step=self.evaluate_step,
             patience=self.patience,
             experiment_name=experiment_name,
+            checkpoint_path="models/vgfs/transformer",
         )
         trainer.train(