import os import pickle import numpy as np from gfun.vgfs.commons import AttentionAggregator, TfidfVectorizerMultilingual, predict from gfun.vgfs.learners.svms import MetaClassifier, get_learner from gfun.vgfs.multilingualGen import MultilingualGen from gfun.vgfs.textualTransformerGen import TextualTransformerGen from gfun.vgfs.vanillaFun import VanillaFunGen from gfun.vgfs.visualTransformerGen import VisualTransformerGen from gfun.vgfs.wceGen import WceGen class GeneralizedFunnelling: def __init__( self, posterior, wce, multilingual, textual_transformer, visual_transformer, langs, num_labels, classification_type, embed_dir, n_jobs, batch_size, eval_batch_size, max_length, textual_lr, visual_lr, epochs, patience, evaluate_step, textual_transformer_name, visual_transformer_name, optimc, device, load_trained, dataset_name, probabilistic, aggfunc, load_meta, ): # Setting VFGs ----------- self.posteriors_vgf = posterior self.wce_vgf = wce self.multilingual_vgf = multilingual self.textual_trf_vgf = textual_transformer self.visual_trf_vgf = visual_transformer self.probabilistic = probabilistic self.num_labels = num_labels self.clf_type = classification_type # ------------------------ self.langs = langs self.embed_dir = embed_dir self.cached = True # Textual Transformer VGF params ---------- self.textual_trf_name = textual_transformer_name self.epochs = epochs self.textual_trf_lr = textual_lr self.textual_scheduler = "ReduceLROnPlateau" self.batch_size_trf = batch_size self.eval_batch_size_trf = eval_batch_size self.max_length = max_length self.early_stopping = True self.patience = patience self.evaluate_step = evaluate_step self.device = device # Visual Transformer VGF params ---------- self.visual_trf_name = visual_transformer_name self.visual_trf_lr = visual_lr self.visual_scheduler = "ReduceLROnPlateau" # Metaclassifier params ------------ self.optimc = optimc # ------------------- self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True) self.n_jobs = n_jobs self.first_tier_learners = [] self.metaclassifier = None self.aggfunc = aggfunc self.load_trained = load_trained self.load_first_tier = ( True # TODO: i guess we're always going to load at least the first tier ) self.load_meta = load_meta self.dataset_name = dataset_name self._init() def _init(self): print("\n[Init GeneralizedFunnelling]") assert not ( self.aggfunc == "mean" and self.probabilistic is False ), "When using averaging aggreagation function probabilistic must be True" if self.load_trained is not None: # TODO: clean up this code here print( "- loading trained VGFs, metaclassifer and vectorizer" if self.load_meta else "- loading trained VGFs and vectorizer" ) self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load( self.load_trained, load_first_tier=self.load_first_tier, load_meta=self.load_meta, ) if self.metaclassifier is None: self.metaclassifier = MetaClassifier( meta_learner=get_learner(calibrate=True, kernel="rbf"), meta_parameters=get_params(self.optimc), n_jobs=self.n_jobs, ) if "attn" in self.aggfunc: attn_stacking = self.aggfunc.split("_")[1] self.attn_aggregator = AttentionAggregator( embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking), out_dim=self.num_labels, lr=self.textual_trf_lr, patience=self.patience, num_heads=1, device=self.device, epochs=self.epochs, attn_stacking_type=attn_stacking, ) self._model_id = get_unique_id( self.dataset_name, self.posteriors_vgf, self.multilingual_vgf, self.wce_vgf, self.textual_trf_vgf, self.visual_trf_vgf, self.aggfunc, ) return self if self.posteriors_vgf: fun = VanillaFunGen( base_learner=get_learner(calibrate=True), n_jobs=self.n_jobs, ) self.first_tier_learners.append(fun) if self.multilingual_vgf: multilingual_vgf = MultilingualGen( embed_dir=self.embed_dir, langs=self.langs, n_jobs=self.n_jobs, cached=self.cached, probabilistic=self.probabilistic, ) self.first_tier_learners.append(multilingual_vgf) if self.wce_vgf: wce_vgf = WceGen(n_jobs=self.n_jobs) self.first_tier_learners.append(wce_vgf) if self.textual_trf_vgf: transformer_vgf = TextualTransformerGen( dataset_name=self.dataset_name, model_name=self.textual_trf_name, lr=self.textual_trf_lr, scheduler=self.textual_scheduler, epochs=self.epochs, batch_size=self.batch_size_trf, batch_size_eval=self.eval_batch_size_trf, max_length=self.max_length, print_steps=50, probabilistic=self.probabilistic, evaluate_step=self.evaluate_step, verbose=True, patience=self.patience, device=self.device, classification_type=self.clf_type, ) self.first_tier_learners.append(transformer_vgf) if self.visual_trf_vgf: visual_trasformer_vgf = VisualTransformerGen( dataset_name=self.dataset_name, model_name="vit", lr=self.visual_trf_lr, scheduler=self.visual_scheduler, epochs=self.epochs, batch_size=self.batch_size_trf, batch_size_eval=self.eval_batch_size_trf, probabilistic=self.probabilistic, evaluate_step=self.evaluate_step, patience=self.patience, device=self.device, classification_type=self.clf_type, ) self.first_tier_learners.append(visual_trasformer_vgf) if "attn" in self.aggfunc: attn_stacking = self.aggfunc.split("_")[1] self.attn_aggregator = AttentionAggregator( embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking), out_dim=self.num_labels, lr=self.textual_trf_lr, patience=self.patience, num_heads=1, device=self.device, epochs=self.epochs, attn_stacking_type=attn_stacking, ) self.metaclassifier = MetaClassifier( meta_learner=get_learner(calibrate=True, kernel="rbf"), meta_parameters=get_params(self.optimc), n_jobs=self.n_jobs, ) self._model_id = get_unique_id( self.dataset_name, self.posteriors_vgf, self.multilingual_vgf, self.wce_vgf, self.textual_trf_vgf, self.visual_trf_vgf, self.aggfunc, ) print(f"- model id: {self._model_id}") return self def init_vgfs_vectorizers(self): for vgf in self.first_tier_learners: if isinstance(vgf, (VanillaFunGen, MultilingualGen, WceGen)): vgf.vectorizer = self.vectorizer def fit(self, lX, lY): print("\n[Fitting GeneralizedFunnelling]") if self.load_trained is not None: print( "- loaded first tier learners!" if self.load_meta is False else "- loaded trained model!" ) """ if we are only loading the first tier, we need to transform the training data to train the meta-classifier """ if self.load_first_tier is True and self.load_meta is False: projections = [] for vgf in self.first_tier_learners: l_posteriors = vgf.transform(lX) projections.append(l_posteriors) agg = self.aggregate(projections, lY) self.metaclassifier.fit(agg, lY) return self self.vectorizer.fit(lX) self.init_vgfs_vectorizers() projections = [] print("- fitting first tier learners") for vgf in self.first_tier_learners: l_posteriors = vgf.fit_transform(lX, lY) projections.append(l_posteriors) agg = self.aggregate(projections, lY) self.metaclassifier.fit(agg, lY) return self def transform(self, lX): projections = [] for vgf in self.first_tier_learners: l_posteriors = vgf.transform(lX) projections.append(l_posteriors) agg = self.aggregate(projections) l_out = self.metaclassifier.predict_proba(agg) if self.clf_type == "singlelabel": for lang, preds in l_out.items(): l_out[lang] = predict(preds, clf_type=self.clf_type) return l_out def fit_transform(self, lX, lY): return self.fit(lX, lY).transform(lX) def aggregate(self, first_tier_projections, lY=None): if self.aggfunc == "mean": aggregated = self._aggregate_mean(first_tier_projections) elif self.aggfunc == "concat": aggregated = self._aggregate_concat(first_tier_projections) # elif self.aggfunc == "attn": elif "attn" in self.aggfunc: aggregated = self._aggregate_attn(first_tier_projections, lY) else: raise NotImplementedError return aggregated def _aggregate_attn(self, first_tier_projections, lY=None): if lY is None: # at prediction time aggregated = self.attn_aggregator.transform(first_tier_projections) else: # at training time we must fit the attention layer self.attn_aggregator.fit(first_tier_projections, lY) aggregated = self.attn_aggregator.transform(first_tier_projections) return aggregated def _aggregate_concat(self, first_tier_projections): aggregated = {} for lang in self.langs: aggregated[lang] = np.hstack([v[lang] for v in first_tier_projections]) return aggregated def _aggregate_mean(self, first_tier_projections): aggregated = { lang: np.zeros(data.shape) for lang, data in first_tier_projections[0].items() } for lang_projections in first_tier_projections: for lang, projection in lang_projections.items(): aggregated[lang] += projection for lang, projection in aggregated.items(): aggregated[lang] /= len(first_tier_projections) return aggregated def get_config(self): c = {} for vgf in self.first_tier_learners: vgf_config = vgf.get_config() c.update({vgf_config["name"]: vgf_config}) gfun_config = { "id": self._model_id, "aggfunc": self.aggfunc, "optimc": self.optimc, "dataset": self.dataset_name, } c["gFun"] = gfun_config return c def save(self, save_first_tier=True, save_meta=True): print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}") os.makedirs(os.path.join("models", "vectorizer"), exist_ok=True) with open( os.path.join("models", "vectorizer", f"vectorizer_{self._model_id}.pkl"), "wb", ) as f: pickle.dump(self.vectorizer, f) if save_first_tier: self.save_first_tier_learners(model_id=self._model_id) if save_meta: with open( os.path.join("models", "metaclassifier", f"meta_{self._model_id}.pkl"), "wb", ) as f: pickle.dump(self.metaclassifier, f) return def save_first_tier_learners(self, model_id): for vgf in self.first_tier_learners: vgf.save_vgf(model_id=self._model_id) return self def load(self, model_id, load_first_tier=True, load_meta=True): print(f"- loading model id: {model_id}") first_tier_learners = [] with open( os.path.join("models", "vectorizer", f"vectorizer_{model_id}.pkl"), "rb" ) as f: vectorizer = pickle.load(f) if self.posteriors_vgf: with open( os.path.join( "models", "vgfs", "posterior", f"vanillaFunGen_{model_id}.pkl" ), "rb", ) as vgf: first_tier_learners.append(pickle.load(vgf)) print(f"- loaded trained VanillaFun VGF") if self.multilingual_vgf: with open( os.path.join( "models", "vgfs", "multilingual", f"multilingualGen_{model_id}.pkl" ), "rb", ) as vgf: first_tier_learners.append(pickle.load(vgf)) print(f"- loaded trained Multilingual VGF") if self.wce_vgf: with open( os.path.join( "models", "vgfs", "wordclass", f"wordClassGen_{model_id}.pkl" ), "rb", ) as vgf: first_tier_learners.append(pickle.load(vgf)) print(f"- loaded trained WCE VGF") if self.textual_trf_vgf: with open( os.path.join( "models", "vgfs", "textual_transformer", f"textualTransformerGen_{model_id}.pkl", ), "rb", ) as vgf: first_tier_learners.append(pickle.load(vgf)) print(f"- loaded trained Textual Transformer VGF") if self.visual_trf_vgf: with open( os.path.join( "models", "vgfs", "visual_transformer", f"visualTransformerGen_{model_id}.pkl", ), "rb", print(f"- loaded trained Visual Transformer VGF"), ) as vgf: first_tier_learners.append(pickle.load(vgf)) if load_meta: with open( os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb" ) as f: metaclassifier = pickle.load(f) print(f"- loaded trained metaclassifier") else: metaclassifier = None return first_tier_learners, metaclassifier, vectorizer def _load_meta(self): raise NotImplementedError def _load_posterior(self): raise NotImplementedError def _load_multilingual(self): raise NotImplementedError def _load_wce(self): raise NotImplementedError def _load_transformer(self): raise NotImplementedError def get_attn_agg_dim(self, attn_stacking_type): if self.probabilistic and "attn" not in self.aggfunc: return len(self.first_tier_learners) * self.num_labels elif self.probabilistic and "attn" in self.aggfunc: if attn_stacking_type == "concat": return len(self.first_tier_learners) * self.num_labels elif attn_stacking_type == "mean": return self.num_labels else: raise NotImplementedError else: raise NotImplementedError def get_params(optimc=False): if not optimc: return None c_range = [1e4, 1e3, 1e2, 1e1, 1, 1e-1] kernel = "rbf" return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}] def get_unique_id( dataset_name, posterior, multilingual, wce, textual_transformer, visual_transformer, aggfunc, ): from datetime import datetime now = datetime.now().strftime("%y%m%d") model_id = f"{dataset_name}_" model_id += "p" if posterior else "" model_id += "m" if multilingual else "" model_id += "w" if wce else "" model_id += "t" if textual_transformer else "" model_id += "v" if visual_transformer else "" model_id += f"_{aggfunc}" return f"{model_id}_{now}"