import os
import sys

# Make the sibling "gfun" package importable when running from the repo root.
sys.path.append(os.path.join(os.getcwd(), "gfun"))

import pickle

import numpy as np

from vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
from vgfs.learners.svms import MetaClassifier, get_learner
from vgfs.multilingualGen import MultilingualGen
from gfun.vgfs.textualTransformerGen import TextualTransformerGen
from vgfs.vanillaFun import VanillaFunGen
from vgfs.wceGen import WceGen


class GeneralizedFunnelling:
    """Two-tier "generalized funnelling" ensemble for multilingual classification.

    First tier: a set of View Generating Functions (VGFs) — posterior-based
    (VanillaFunGen), multilingual embeddings (MultilingualGen), word-class
    embeddings (WceGen), and a transformer (TextualTransformerGen) — each
    mapping language-indexed input documents to a projection space.

    Second tier: a calibrated SVM meta-classifier fitted on the aggregated
    first-tier projections. Aggregation is "mean", "concat", or a learned
    attention layer ("attn_concat" / "attn_mean").
    """

    def __init__(
        self,
        posterior,
        wce,
        multilingual,
        transformer,
        langs,
        num_labels,
        embed_dir,
        n_jobs,
        batch_size,
        max_length,
        lr,
        epochs,
        patience,
        evaluate_step,
        transformer_name,
        optimc,
        device,
        load_trained,
        dataset_name,
        probabilistic,
        aggfunc,
        load_meta,
    ):
        # Which VGFs to instantiate ------------------------------------------
        self.posteriors_vgf = posterior
        self.wce_vgf = wce
        self.multilingual_vgf = multilingual
        # NOTE(review): attribute name keeps the original (misspelled)
        # spelling "trasformer" for backward compatibility with callers.
        self.trasformer_vgf = transformer
        self.probabilistic = probabilistic
        self.num_labels = num_labels
        # --------------------------------------------------------------------
        self.langs = langs
        self.embed_dir = embed_dir
        self.cached = True
        # Transformer VGF params ---------------------------------------------
        self.transformer_name = transformer_name
        self.epochs = epochs
        self.lr_transformer = lr
        self.batch_size_transformer = batch_size
        self.max_length = max_length
        self.early_stopping = True
        self.patience = patience
        self.evaluate_step = evaluate_step
        self.device = device
        # Metaclassifier params ----------------------------------------------
        self.optimc = optimc
        # --------------------------------------------------------------------
        self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
        self.n_jobs = n_jobs
        self.first_tier_learners = []
        self.metaclassifier = None
        self.aggfunc = aggfunc
        self.load_trained = load_trained
        # TODO: we currently always load at least the first tier when loading.
        self.load_first_tier = True
        self.load_meta = load_meta
        self.dataset_name = dataset_name
        self._init()

    def _init(self):
        """Build (or load from disk) the first-tier VGFs, the optional attention
        aggregator, and the meta-classifier. Returns self."""
        print("[Init GeneralizedFunnelling]")
        # Mean aggregation averages posterior probabilities across VGFs, so
        # every VGF must produce calibrated probabilities.
        assert not (
            self.aggfunc == "mean" and self.probabilistic is False
        ), "When using averaging aggregation function probabilistic must be True"

        if self.load_trained is not None:
            # TODO: clean up this code here
            print(
                "- loading trained VGFs, metaclassifer and vectorizer"
                if self.load_meta
                else "- loading trained VGFs and vectorizer"
            )
            self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load(
                self.load_trained,
                load_first_tier=self.load_first_tier,
                load_meta=self.load_meta,
            )
            if self.metaclassifier is None:
                # Meta-classifier was not serialized: build a fresh one to be
                # fitted on the loaded first-tier projections.
                self.metaclassifier = MetaClassifier(
                    meta_learner=get_learner(calibrate=True, kernel="rbf"),
                    meta_parameters=get_params(self.optimc),
                    n_jobs=self.n_jobs,
                )
            if "attn" in self.aggfunc:
                # aggfunc is e.g. "attn_concat" or "attn_mean".
                attn_stacking = self.aggfunc.split("_")[1]
                self.attn_aggregator = AttentionAggregator(
                    embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                    out_dim=self.num_labels,
                    lr=self.lr_transformer,
                    patience=self.patience,
                    num_heads=1,
                    device=self.device,
                    epochs=self.epochs,
                    attn_stacking_type=attn_stacking,
                )
            return self

        if self.posteriors_vgf:
            fun = VanillaFunGen(
                base_learner=get_learner(calibrate=True),
                n_jobs=self.n_jobs,
            )
            self.first_tier_learners.append(fun)

        if self.multilingual_vgf:
            multilingual_vgf = MultilingualGen(
                embed_dir=self.embed_dir,
                langs=self.langs,
                n_jobs=self.n_jobs,
                cached=self.cached,
                probabilistic=self.probabilistic,
            )
            self.first_tier_learners.append(multilingual_vgf)

        if self.wce_vgf:
            wce_vgf = WceGen(n_jobs=self.n_jobs)
            self.first_tier_learners.append(wce_vgf)

        if self.trasformer_vgf:
            transformer_vgf = TextualTransformerGen(
                dataset_name=self.dataset_name,
                model_name=self.transformer_name,
                lr=self.lr_transformer,
                epochs=self.epochs,
                batch_size=self.batch_size_transformer,
                max_length=self.max_length,
                device=self.device,
                print_steps=50,
                probabilistic=self.probabilistic,
                evaluate_step=self.evaluate_step,
                verbose=True,
                patience=self.patience,
            )
            self.first_tier_learners.append(transformer_vgf)

        if "attn" in self.aggfunc:
            attn_stacking = self.aggfunc.split("_")[1]
            self.attn_aggregator = AttentionAggregator(
                # FIX: forward the stacking type. The original called
                # get_attn_agg_dim() with its default ("concat") here, yielding
                # a wrong embed_dim when the requested stacking was "mean"
                # (the loading branch above already forwarded it correctly).
                embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                out_dim=self.num_labels,
                lr=self.lr_transformer,
                patience=self.patience,
                num_heads=1,
                device=self.device,
                epochs=self.epochs,
                attn_stacking_type=attn_stacking,
            )

        self.metaclassifier = MetaClassifier(
            meta_learner=get_learner(calibrate=True, kernel="rbf"),
            meta_parameters=get_params(self.optimc),
            n_jobs=self.n_jobs,
        )

        self._model_id = get_unique_id(
            self.posteriors_vgf,
            self.multilingual_vgf,
            self.wce_vgf,
            self.trasformer_vgf,
            self.aggfunc,
        )
        print(f"- model id: {self._model_id}")
        return self

    def init_vgfs_vectorizers(self):
        """Share the fitted multilingual TF-IDF vectorizer with the VGFs that need one."""
        for vgf in self.first_tier_learners:
            if isinstance(vgf, (VanillaFunGen, MultilingualGen, WceGen)):
                vgf.vectorizer = self.vectorizer

    def fit(self, lX, lY):
        """Fit the first tier on {lang: docs}/{lang: labels}, then fit the
        meta-classifier on the aggregated projections. Returns self."""
        print("[Fitting GeneralizedFunnelling]")
        if self.load_trained is not None:
            print(
                "- loaded first tier learners!"
                if self.load_meta is False
                else "- loaded trained model!"
            )
            if self.load_first_tier is True and self.load_meta is False:
                # TODO: clean up this code here
                # First tier is pre-trained: only project and fit the meta.
                projections = []
                for vgf in self.first_tier_learners:
                    l_posteriors = vgf.transform(lX)
                    projections.append(l_posteriors)
                agg = self.aggregate(projections, lY)
                self.metaclassifier.fit(agg, lY)
            return self

        self.vectorizer.fit(lX)
        self.init_vgfs_vectorizers()

        projections = []
        print("- fitting first tier learners")
        for vgf in self.first_tier_learners:
            l_posteriors = vgf.fit_transform(lX, lY)
            projections.append(l_posteriors)

        agg = self.aggregate(projections, lY)
        self.metaclassifier.fit(agg, lY)
        return self

    def transform(self, lX):
        """Project lX through every VGF, aggregate, and return the
        meta-classifier's {lang: probabilities}."""
        projections = []
        for vgf in self.first_tier_learners:
            l_posteriors = vgf.transform(lX)
            projections.append(l_posteriors)
        agg = self.aggregate(projections)
        l_out = self.metaclassifier.predict_proba(agg)
        return l_out

    def fit_transform(self, lX, lY):
        return self.fit(lX, lY).transform(lX)

    def aggregate(self, first_tier_projections, lY=None):
        """Combine the per-VGF projections according to self.aggfunc.

        lY is only needed (and only used) by the attention aggregator at
        training time.
        """
        if self.aggfunc == "mean":
            aggregated = self._aggregate_mean(first_tier_projections)
        elif self.aggfunc == "concat":
            aggregated = self._aggregate_concat(first_tier_projections)
        elif "attn" in self.aggfunc:
            aggregated = self._aggregate_attn(first_tier_projections, lY)
        else:
            raise NotImplementedError
        return aggregated

    def _aggregate_attn(self, first_tier_projections, lY=None):
        if lY is None:
            # At prediction time: just transform.
            aggregated = self.attn_aggregator.transform(first_tier_projections)
        else:
            # At training time we must fit the attention layer first.
            self.attn_aggregator.fit(first_tier_projections, lY)
            aggregated = self.attn_aggregator.transform(first_tier_projections)
        return aggregated

    def _aggregate_concat(self, first_tier_projections):
        """Horizontally stack the per-language projections of every VGF."""
        aggregated = {}
        for lang in self.langs:
            aggregated[lang] = np.hstack([v[lang] for v in first_tier_projections])
        return aggregated

    def _aggregate_mean(self, first_tier_projections):
        """Element-wise average of the per-language projections of every VGF."""
        aggregated = {
            lang: np.zeros(data.shape)
            for lang, data in first_tier_projections[0].items()
        }
        for lang_projections in first_tier_projections:
            for lang, projection in lang_projections.items():
                aggregated[lang] += projection
        for lang, projection in aggregated.items():
            aggregated[lang] /= len(first_tier_projections)
        return aggregated

    def get_config(self):
        """Print a human-readable summary of the model configuration."""
        print("\n")
        print("-" * 50)
        print("[GeneralizedFunnelling config]")
        print(f"- model trained on langs: {self.langs}")
        print("-- View Generating Functions configurations:\n")
        for vgf in self.first_tier_learners:
            print(vgf)
        print("-" * 50)

    def save(self, save_first_tier=True, save_meta=True):
        """Pickle the vectorizer and, optionally, the first-tier VGFs and the
        meta-classifier under models/ using this model's id."""
        print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}")
        os.makedirs(os.path.join("models", "vectorizer"), exist_ok=True)
        with open(
            os.path.join("models", "vectorizer", f"vectorizer_{self._model_id}.pkl"),
            "wb",
        ) as f:
            pickle.dump(self.vectorizer, f)
        if save_first_tier:
            self.save_first_tier_learners(model_id=self._model_id)
        if save_meta:
            # FIX: ensure the target directory exists — previously only the
            # vectorizer directory was created, so this open() could fail.
            os.makedirs(os.path.join("models", "metaclassifier"), exist_ok=True)
            with open(
                os.path.join("models", "metaclassifier", f"meta_{self._model_id}.pkl"),
                "wb",
            ) as f:
                pickle.dump(self.metaclassifier, f)
        return

    def save_first_tier_learners(self, model_id):
        """Delegate saving to each VGF. Returns self."""
        # FIX: honor the model_id argument (the original always used
        # self._model_id, silently ignoring the parameter).
        for vgf in self.first_tier_learners:
            vgf.save_vgf(model_id=model_id)
        return self

    def load(self, model_id, load_first_tier=True, load_meta=True):
        """Unpickle the vectorizer, the enabled VGFs and (optionally) the
        meta-classifier for model_id.

        Returns (first_tier_learners, metaclassifier_or_None, vectorizer).
        NOTE(review): pickle.load on these files trusts the model directory;
        only load model files you produced yourself.
        """
        print(f"- loading model id: {model_id}")
        first_tier_learners = []
        with open(
            os.path.join("models", "vectorizer", f"vectorizer_{model_id}.pkl"), "rb"
        ) as f:
            vectorizer = pickle.load(f)
        if self.posteriors_vgf:
            with open(
                os.path.join(
                    "models", "vgfs", "posterior", f"vanillaFunGen_{model_id}.pkl"
                ),
                "rb",
            ) as vgf:
                first_tier_learners.append(pickle.load(vgf))
        if self.multilingual_vgf:
            with open(
                os.path.join(
                    "models", "vgfs", "multilingual", f"multilingualGen_{model_id}.pkl"
                ),
                "rb",
            ) as vgf:
                first_tier_learners.append(pickle.load(vgf))
        if self.wce_vgf:
            with open(
                os.path.join(
                    "models", "vgfs", "wordclass", f"wordClassGen_{model_id}.pkl"
                ),
                "rb",
            ) as vgf:
                first_tier_learners.append(pickle.load(vgf))
        if self.trasformer_vgf:
            with open(
                os.path.join(
                    "models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"
                ),
                "rb",
            ) as vgf:
                first_tier_learners.append(pickle.load(vgf))
        if load_meta:
            with open(
                os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb"
            ) as f:
                metaclassifier = pickle.load(f)
        else:
            metaclassifier = None
        return first_tier_learners, metaclassifier, vectorizer

    def _load_meta(self):
        raise NotImplementedError

    def _load_posterior(self):
        raise NotImplementedError

    def _load_multilingual(self):
        raise NotImplementedError

    def _load_wce(self):
        raise NotImplementedError

    def _load_transformer(self):
        raise NotImplementedError

    def get_attn_agg_dim(self, attn_stacking_type="concat"):
        """Input dimensionality of the attention aggregator: one num_labels
        slab per VGF when stacking by concatenation, a single slab when
        stacking by mean."""
        if self.probabilistic and "attn" not in self.aggfunc:
            return len(self.first_tier_learners) * self.num_labels
        elif self.probabilistic and "attn" in self.aggfunc:
            if attn_stacking_type == "concat":
                return len(self.first_tier_learners) * self.num_labels
            elif attn_stacking_type == "mean":
                return self.num_labels
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError


def get_params(optimc=False):
    """Grid-search parameter grid for the meta SVM, or None when disabled."""
    if not optimc:
        return None
    c_range = [1e4, 1e3, 1e2, 1e1, 1, 1e-1]
    kernel = "rbf"
    return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}]


def get_unique_id(posterior, multilingual, wce, transformer, aggfunc):
    """Compact model id: one letter per enabled VGF, the aggregation function,
    and the current date (YYMMDD)."""
    from datetime import datetime

    now = datetime.now().strftime("%y%m%d")
    model_id = ""
    model_id += "p" if posterior else ""
    model_id += "m" if multilingual else ""
    model_id += "w" if wce else ""
    model_id += "t" if transformer else ""
    model_id += f"_{aggfunc}"
    return f"{model_id}_{now}"