From 7ed98346a5a88db7fbe1bd1e2e5ad06a6a8b7c84 Mon Sep 17 00:00:00 2001 From: andreapdr Date: Mon, 13 Feb 2023 15:01:50 +0100 Subject: [PATCH] fixed loading function for Attention-based aggregating function when triggered by EarlyStopper --- dataManager/multilingualDatset.py | 3 + gfun/generalizedFunnelling.py | 156 +++++++++++++++++++++++------- gfun/vgfs/commons.py | 128 ++++++++++++++++++------ gfun/vgfs/visualTransformerGen.py | 1 - main.py | 41 ++++---- 5 files changed, 243 insertions(+), 86 deletions(-) diff --git a/dataManager/multilingualDatset.py b/dataManager/multilingualDatset.py index 23dc4c7..7fd53f0 100644 --- a/dataManager/multilingualDatset.py +++ b/dataManager/multilingualDatset.py @@ -171,6 +171,9 @@ class MultilingualDataset: else: langs = sorted(self.multiling_dataset.keys()) return langs + + def num_labels(self): + return self.num_categories() def num_categories(self): return self.lYtr()[self.langs()[0]].shape[1] diff --git a/gfun/generalizedFunnelling.py b/gfun/generalizedFunnelling.py index 1fffc89..c4302d0 100644 --- a/gfun/generalizedFunnelling.py +++ b/gfun/generalizedFunnelling.py @@ -22,6 +22,7 @@ class GeneralizedFunnelling: multilingual, transformer, langs, + num_labels, embed_dir, n_jobs, batch_size, @@ -37,6 +38,7 @@ class GeneralizedFunnelling: dataset_name, probabilistic, aggfunc, + load_meta, ): # Setting VFGs ----------- self.posteriors_vgf = posterior @@ -44,7 +46,7 @@ class GeneralizedFunnelling: self.multilingual_vgf = multilingual self.trasformer_vgf = transformer self.probabilistic = probabilistic - self.num_labels = 73 # TODO: hard-coded + self.num_labels = num_labels # ------------------------ self.langs = langs self.embed_dir = embed_dir @@ -68,6 +70,10 @@ class GeneralizedFunnelling: self.metaclassifier = None self.aggfunc = aggfunc self.load_trained = load_trained + self.load_first_tier = ( + True # TODO: i guess we're always going to load at least the fitst tier + ) + self.load_meta = load_meta self.dataset_name = dataset_name self._init() @@ -77,11 +83,37 @@ class GeneralizedFunnelling: self.aggfunc == "mean" and self.probabilistic is False ), "When using averaging aggreagation function probabilistic must be True" if self.load_trained is not None: - print("- loading trained VGFs, metaclassifer and vectorizer") - self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load( - self.load_trained + # TODO: clean up this code here + print( + "- loading trained VGFs, metaclassifer and vectorizer" + if self.load_meta + else "- loading trained VGFs and vectorizer" ) - # TODO: config like aggfunc, device, n_jobs, etc + self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load( + self.load_trained, + load_first_tier=self.load_first_tier, + load_meta=self.load_meta, + ) + if self.metaclassifier is None: + self.metaclassifier = MetaClassifier( + meta_learner=get_learner(calibrate=True, kernel="rbf"), + meta_parameters=get_params(self.optimc), + n_jobs=self.n_jobs, + ) + + if "attn" in self.aggfunc: + attn_stacking = self.aggfunc.split("_")[1] + self.attn_aggregator = AttentionAggregator( + embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking), + out_dim=self.num_labels, + lr=self.lr_transformer, + patience=self.patience, + num_heads=1, + device=self.device, + epochs=self.epochs, + attn_stacking_type=attn_stacking, + ) + return self if self.posteriors_vgf: fun = VanillaFunGen( @@ -112,7 +144,7 @@ class GeneralizedFunnelling: epochs=self.epochs, batch_size=self.batch_size_transformer, max_length=self.max_length, - 
device="cuda", + device=self.device, print_steps=50, probabilistic=self.probabilistic, evaluate_step=self.evaluate_step, @@ -121,13 +153,17 @@ class GeneralizedFunnelling: ) self.first_tier_learners.append(transformer_vgf) - if self.aggfunc == "attn": + if "attn" in self.aggfunc: + attn_stacking = self.aggfunc.split("_")[1] self.attn_aggregator = AttentionAggregator( embed_dim=self.get_attn_agg_dim(), out_dim=self.num_labels, + lr=self.lr_transformer, + patience=self.patience, num_heads=1, device=self.device, epochs=self.epochs, + attn_stacking_type=attn_stacking, ) self.metaclassifier = MetaClassifier( @@ -141,6 +177,7 @@ class GeneralizedFunnelling: self.multilingual_vgf, self.wce_vgf, self.trasformer_vgf, + self.aggfunc, ) print(f"- model id: {self._model_id}") return self @@ -153,11 +190,19 @@ class GeneralizedFunnelling: def fit(self, lX, lY): print("[Fitting GeneralizedFunnelling]") if self.load_trained is not None: - print(f"- loaded trained model! Skipping training...") - # TODO: add support to load only the first tier learners while re-training the metaclassifier - load_only_first_tier = False - if load_only_first_tier: - raise NotImplementedError + print( + "- loaded first tier learners!" + if self.load_meta is False + else "- loaded trained model!" + ) + if self.load_first_tier is True and self.load_meta is False: + # TODO: clean up this code here + projections = [] + for vgf in self.first_tier_learners: + l_posteriors = vgf.transform(lX) + projections.append(l_posteriors) + agg = self.aggregate(projections, lY) + self.metaclassifier.fit(agg, lY) return self self.vectorizer.fit(lX) @@ -191,7 +236,8 @@ class GeneralizedFunnelling: aggregated = self._aggregate_mean(first_tier_projections) elif self.aggfunc == "concat": aggregated = self._aggregate_concat(first_tier_projections) - elif self.aggfunc == "attn": + # elif self.aggfunc == "attn": + elif "attn" in self.aggfunc: aggregated = self._aggregate_attn(first_tier_projections, lY) else: raise NotImplementedError @@ -238,27 +284,41 @@ class GeneralizedFunnelling: print(vgf) print("-" * 50) - def save(self): + def save(self, save_first_tier=True, save_meta=True): print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}") - # TODO: save only the first tier learners? what about some model config + sanity checks before loading? 
- for vgf in self.first_tier_learners: - vgf.save_vgf(model_id=self._model_id) - os.makedirs(os.path.join("models", "metaclassifier"), exist_ok=True) - with open( - os.path.join("models", "metaclassifier", f"meta_{self._model_id}.pkl"), "wb" - ) as f: - pickle.dump(self.metaclassifier, f) + os.makedirs(os.path.join("models", "vectorizer"), exist_ok=True) with open( os.path.join("models", "vectorizer", f"vectorizer_{self._model_id}.pkl"), "wb", ) as f: pickle.dump(self.vectorizer, f) + + if save_first_tier: + self.save_first_tier_learners(model_id=self._model_id) + + if save_meta: + with open( + os.path.join("models", "metaclassifier", f"meta_{self._model_id}.pkl"), + "wb", + ) as f: + pickle.dump(self.metaclassifier, f) return - def load(self, model_id): + def save_first_tier_learners(self, model_id): + for vgf in self.first_tier_learners: + vgf.save_vgf(model_id=self._model_id) + return self + + def load(self, model_id, load_first_tier=True, load_meta=True): print(f"- loading model id: {model_id}") first_tier_learners = [] + + with open( + os.path.join("models", "vectorizer", f"vectorizer_{model_id}.pkl"), "rb" + ) as f: + vectorizer = pickle.load(f) + if self.posteriors_vgf: with open( os.path.join( @@ -291,20 +351,43 @@ class GeneralizedFunnelling: "rb", ) as vgf: first_tier_learners.append(pickle.load(vgf)) - with open( - os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb" - ) as f: - metaclassifier = pickle.load(f) - with open( - os.path.join("models", "vectorizer", f"vectorizer_{model_id}.pkl"), "rb" - ) as f: - vectorizer = pickle.load(f) + + if load_meta: + with open( + os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb" + ) as f: + metaclassifier = pickle.load(f) + else: + metaclassifier = None return first_tier_learners, metaclassifier, vectorizer - def get_attn_agg_dim(self): - # TODO: hardcoded for now - print("\n[NB: ATTN AGGREGATOR DIM HARD-CODED TO 146]\n") - return 146 + def _load_meta(self): + raise NotImplementedError + + def _load_posterior(self): + raise NotImplementedError + + def _load_multilingual(self): + raise NotImplementedError + + def _load_wce(self): + raise NotImplementedError + + def _load_transformer(self): + raise NotImplementedError + + def get_attn_agg_dim(self, attn_stacking_type="concat"): + if self.probabilistic and "attn" not in self.aggfunc: + return len(self.first_tier_learners) * self.num_labels + elif self.probabilistic and "attn" in self.aggfunc: + if attn_stacking_type == "concat": + return len(self.first_tier_learners) * self.num_labels + elif attn_stacking_type == "mean": + return self.num_labels + else: + raise NotImplementedError + else: + raise NotImplementedError def get_params(optimc=False): @@ -315,7 +398,7 @@ def get_params(optimc=False): return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}] -def get_unique_id(posterior, multilingual, wce, transformer): +def get_unique_id(posterior, multilingual, wce, transformer, aggfunc): from datetime import datetime now = datetime.now().strftime("%y%m%d") @@ -324,4 +407,5 @@ def get_unique_id(posterior, multilingual, wce, transformer): model_id += "m" if multilingual else "" model_id += "w" if wce else "" model_id += "t" if transformer else "" + model_id += f"_{aggfunc}" return f"{model_id}_{now}" diff --git a/gfun/vgfs/commons.py b/gfun/vgfs/commons.py index 28752fc..574b628 100644 --- a/gfun/vgfs/commons.py +++ b/gfun/vgfs/commons.py @@ -10,6 +10,7 @@ from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.preprocessing import 
normalize from torch.optim import AdamW from transformers.modeling_outputs import SequenceClassifierOutput +from sklearn.model_selection import train_test_split from evaluation.evaluate import evaluate, log_eval @@ -158,7 +159,6 @@ class Trainer: self.device ) break - # TODO: maybe a lower lr? self.train_epoch(eval_dataloader, epoch=epoch) print(f"\n- last swipe on eval set") self.earlystopping.save_model(self.model) @@ -176,7 +176,7 @@ class Trainer: loss.backward() self.optimizer.step() if (epoch + 1) % PRINT_ON_EPOCH == 0: - if b_idx % self.print_steps == 0: + if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0: print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}") return self @@ -209,7 +209,6 @@ class Trainer: class EarlyStopping: - # TODO: add checkpointing + restore model if early stopping + last swipe on validation set def __init__( self, patience, @@ -247,8 +246,8 @@ class EarlyStopping: return True def save_model(self, model): + os.makedirs(self.checkpoint_path, exist_ok=True) _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name) - os.makedirs(_checkpoint_dir, exist_ok=True) model.save_pretrained(_checkpoint_dir) def load_model(self, model): @@ -257,51 +256,97 @@ class EarlyStopping: class AttentionModule(nn.Module): - def __init__(self, embed_dim, num_heads, out_dim): + def __init__(self, embed_dim, num_heads, h_dim, out_dim): super().__init__() - self.attn = nn.MultiheadAttention(embed_dim, num_heads) + self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.1) + self.layer_norm = nn.LayerNorm(embed_dim) self.linear = nn.Linear(embed_dim, out_dim) def __call__(self, X): - attn_out, attn_weights = self.attn(query=X, key=X, value=X) - out = self.linear(attn_out) + out, attn_weights = self.attn(query=X, key=X, value=X) + out = self.layer_norm(out) + out = self.linear(out) + # out = self.sigmoid(out) return out + # out = self.relu(out) + # out = self.linear2(out) + # out = self.sigmoid(out) def transform(self, X): - attn_out, attn_weights = self.attn(query=X, key=X, value=X) - return attn_out + return self.__call__(X) + # out, attn_weights = self.attn(query=X, key=X, value=X) + # out = self.layer_norm(out) + # out = self.linear(out) + # out = self.sigmoid(out) + # return out + # out = self.relu(out) + # out = self.linear2(out) + # out = self.sigmoid(out) def save_pretrained(self, path): - torch.save(self.state_dict(), f"{path}.pt") + torch.save(self, f"{path}.pt") + # torch.save(self.state_dict(), f"{path}.pt") - def _wtf(self): - print("wtf") + def from_pretrained(self, path): + return torch.load(f"{path}.pt") class AttentionAggregator: - def __init__(self, embed_dim, out_dim, epochs, num_heads=1, device="cpu"): + def __init__( + self, + embed_dim, + out_dim, + epochs, + lr, + patience, + attn_stacking_type, + h_dim=512, + num_heads=1, + device="cpu", + ): self.embed_dim = embed_dim + self.h_dim = h_dim + self.out_dim = out_dim + self.patience = patience self.num_heads = num_heads self.device = device self.epochs = epochs - self.attn = AttentionModule(embed_dim, num_heads, out_dim).to(self.device) + self.lr = lr + self.stacking_type = attn_stacking_type + self.tr_batch_size = 512 + self.eval_batch_size = 1024 + self.attn = AttentionModule( + self.embed_dim, self.num_heads, self.h_dim, self.out_dim + ).to(self.device) def fit(self, X, Y): print("- fitting Attention-based aggregating function") hstacked_X = self.stack(X) - dataset = AggregatorDatasetTorch(hstacked_X, Y) - tra_dataloader = DataLoader(dataset, batch_size=32, shuffle=True) + 
tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data( + hstacked_X, Y, split=0.2, seed=42 + ) + + tra_dataloader = DataLoader( + AggregatorDatasetTorch(tr_lX, tr_lY, split="train"), + batch_size=self.tr_batch_size, + shuffle=True, + ) + eval_dataloader = DataLoader( + AggregatorDatasetTorch(val_lX, val_lY, split="eval"), + batch_size=self.eval_batch_size, + shuffle=False, + ) experiment_name = "attention_aggregator" trainer = Trainer( self.attn, optimizer_name="adamW", - lr=1e-3, + lr=self.lr, loss_fn=torch.nn.CrossEntropyLoss(), - print_steps=100, - evaluate_step=1000, - patience=10, + print_steps=25, + evaluate_step=50, + patience=self.patience, experiment_name=experiment_name, device=self.device, checkpoint_path="models/aggregator", @@ -309,15 +354,14 @@ class AttentionAggregator: trainer.train( train_dataloader=tra_dataloader, - eval_dataloader=tra_dataloader, + eval_dataloader=eval_dataloader, epochs=self.epochs, ) return self def transform(self, X): - # TODO: implement transform - h_stacked = self.stack(X) - dataset = AggregatorDatasetTorch(h_stacked, lY=None, split="whole") + hstacked_X = self.stack(X) + dataset = AggregatorDatasetTorch(hstacked_X, lY=None, split="whole") dataloader = DataLoader(dataset, batch_size=32, shuffle=False) _embeds = [] @@ -339,10 +383,13 @@ class AttentionAggregator: return l_embeds def stack(self, data): - hstack = self._hstack(data) + if self.stacking_type == "concat": + hstack = self._concat_stack(data) + elif self.stacking_type == "mean": + hstack = self._mean_stack(data) return hstack - def _hstack(self, data): + def _concat_stack(self, data): _langs = data[0].keys() l_projections = {} for l in _langs: @@ -351,8 +398,31 @@ class AttentionAggregator: ) return l_projections - def _vstack(self, data): - return torch.vstack() + def _mean_stack(self, data): + # TODO: double check this mess + aggregated = {lang: torch.zeros(d.shape) for lang, d in data[0].items()} + for lang_projections in data: + for lang, projection in lang_projections.items(): + aggregated[lang] += projection + + for lang, projection in aggregated.items(): + aggregated[lang] = (aggregated[lang] / len(data)).float() + + return aggregated + + def get_train_val_data(self, lX, lY, split=0.2, seed=42): + tr_lX, tr_lY, val_lX, val_lY = {}, {}, {}, {} + + for lang in lX.keys(): + tr_X, val_X, tr_Y, val_Y = train_test_split( + lX[lang], lY[lang], test_size=split, random_state=seed, shuffle=False + ) + tr_lX[lang] = tr_X + tr_lY[lang] = tr_Y + val_lX[lang] = val_X + val_lY[lang] = val_Y + + return tr_lX, tr_lY, val_lX, val_lY class AggregatorDatasetTorch(Dataset): diff --git a/gfun/vgfs/visualTransformerGen.py b/gfun/vgfs/visualTransformerGen.py index 8025be9..1f9915d 100644 --- a/gfun/vgfs/visualTransformerGen.py +++ b/gfun/vgfs/visualTransformerGen.py @@ -16,7 +16,6 @@ transformers.logging.set_verbosity_error() class VisualTransformerGen(ViewGen, TransformerGen): - # TODO: probabilistic behaviour def __init__( self, model_name, diff --git a/main.py b/main.py index 3612c76..3d039cf 100644 --- a/main.py +++ b/main.py @@ -13,18 +13,23 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling TODO: - add documentations sphinx - zero-shot setup - - set probabilistic behaviour in Transformer parent-class - - pooling / attention aggregation + - load pre-trained VGFs while retaining ability to train new ones (self.fitted = True in loaded? 
or smt like that) - test split in MultiNews dataset + - when we load a model and change its config (eg change the agg func, re-train meta), we should store this model as a new one (save it) """ def get_dataset(datasetname): assert datasetname in ["multinews", "amazon", "rcv1-2"], "dataset not supported" + RCV_DATAPATH = expanduser( "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle" ) + JRC_DATAPATH = expanduser( + "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle" + ) MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/") + if datasetname == "multinews": dataset = MultiNewsDataset( expanduser(MULTINEWS_DATAPATH), @@ -38,11 +43,9 @@ def get_dataset(datasetname): max_labels=args.max_labels, ) elif datasetname == "rcv1-2": - dataset = ( - MultilingualDataset(dataset_name="rcv1-2") - .load(RCV_DATAPATH) - .reduce_data(langs=["en", "it", "fr"], maxn=args.nrows) - ) + dataset = MultilingualDataset(dataset_name="rcv1-2").load(RCV_DATAPATH) + if args.nrows is not None: + dataset.reduce_data(langs=["en", "it", "fr"], maxn=args.nrows) else: raise NotImplementedError return dataset @@ -55,11 +58,9 @@ def main(args): ): lX, lY = dataset.training() lX_te, lY_te = dataset.test() - # print("[NB: for debug purposes, training set is also used as test set]\n") - # lX_te, lY_te = dataset.training() else: - _lX = dataset.dX - _lY = dataset.dY + lX = dataset.dX + lY = dataset.dY tinit = time() @@ -78,6 +79,7 @@ def main(args): # dataset params ---------------------- dataset_name=args.dataset, langs=dataset.langs(), + num_labels=dataset.num_labels(), # Posterior VGF params ---------------- posterior=args.posteriors, # Multilingual VGF params ------------- @@ -100,22 +102,20 @@ def main(args): aggfunc=args.aggfunc, optimc=args.optimc, load_trained=args.load_trained, + load_meta=args.meta, n_jobs=args.n_jobs, ) # gfun.get_config() gfun.fit(lX, lY) - if args.load_trained is not None: - gfun.save() - - # if not args.load_model: - # gfun.save() + if args.load_trained is None: + gfun.save(save_first_tier=True, save_meta=True) preds = gfun.transform(lX) - train_eval = evaluate(lY, preds) - log_eval(train_eval, phase="train") + # train_eval = evaluate(lY, preds) + # log_eval(train_eval, phase="train") timetr = time() print(f"- training completed in {timetr - tinit:.2f} seconds") @@ -130,10 +130,11 @@ def main(args): if __name__ == "__main__": parser = ArgumentParser() parser.add_argument("-l", "--load_trained", type=str, default=None) + parser.add_argument("--meta", action="store_true") # Dataset parameters ------------------- parser.add_argument("-d", "--dataset", type=str, default="multinews") parser.add_argument("--domains", type=str, default="all") - parser.add_argument("--nrows", type=int, default=100) + parser.add_argument("--nrows", type=int, default=None) parser.add_argument("--min_count", type=int, default=10) parser.add_argument("--max_labels", type=int, default=50) # gFUN parameters ---------------------- @@ -148,7 +149,7 @@ if __name__ == "__main__": # transformer parameters --------------- parser.add_argument("--transformer_name", type=str, default="mbert") parser.add_argument("--batch_size", type=int, default=32) - parser.add_argument("--epochs", type=int, default=10) + parser.add_argument("--epochs", type=int, default=1000) parser.add_argument("--lr", type=float, default=1e-5) parser.add_argument("--max_length", type=int, default=512) parser.add_argument("--patience", type=int, default=5)
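
Usage sketch for the new loading options introduced above (illustrative only: the model id "pmwt_attn_mean_230213" and the `config` dict standing in for the remaining constructor arguments are hypothetical, not taken from this patch):

    # Reload the trained first-tier VGFs and the vectorizer, but rebuild the metaclassifier.
    # With load_meta=False, fit() does not re-train the VGFs: it only projects lX through them,
    # aggregates the projections, and fits a fresh metaclassifier on top.
    gfun = GeneralizedFunnelling(**config, load_trained="pmwt_attn_mean_230213", load_meta=False)
    gfun.fit(lX, lY)
    preds = gfun.transform(lX_te)

    # Reload the full model, metaclassifier included
    # (main.py equivalent: python main.py -l <model_id> --meta ...).
    gfun = GeneralizedFunnelling(**config, load_trained="pmwt_attn_mean_230213", load_meta=True)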