From ab7a310b3443154115dcd77b96d2bbad3c859e21 Mon Sep 17 00:00:00 2001 From: andreapdr Date: Fri, 17 Mar 2023 10:44:45 +0100 Subject: [PATCH] todo updates --- dataManager/gFunDataset.py | 1 - gfun/generalizedFunnelling.py | 2 +- gfun/vgfs/multilingualGen.py | 4 ---- gfun/vgfs/textualTransformerGen.py | 6 +----- gfun/vgfs/vanillaFun.py | 5 ----- gfun/vgfs/visualTransformerGen.py | 4 ---- gfun/vgfs/wceGen.py | 4 ---- 7 files changed, 2 insertions(+), 24 deletions(-) diff --git a/dataManager/gFunDataset.py b/dataManager/gFunDataset.py index 6b0b8b2..679c362 100644 --- a/dataManager/gFunDataset.py +++ b/dataManager/gFunDataset.py @@ -118,7 +118,6 @@ class gFunDataset: if self.data_langs is None: data_langs = sorted(train_split.geo.unique().tolist()) - # TODO: if data langs is NOT none then we have a problem where we filter df by langs if self.labels is None: labels = train_split.category_name.unique().tolist() diff --git a/gfun/generalizedFunnelling.py b/gfun/generalizedFunnelling.py index 14efc6f..c4746cd 100644 --- a/gfun/generalizedFunnelling.py +++ b/gfun/generalizedFunnelling.py @@ -82,7 +82,7 @@ class GeneralizedFunnelling: self.aggfunc = aggfunc self.load_trained = load_trained self.load_first_tier = ( - True # TODO: i guess we're always going to load at least the fitst tier + True # TODO: i guess we're always going to load at least the first tier ) self.load_meta = load_meta self.dataset_name = dataset_name diff --git a/gfun/vgfs/multilingualGen.py b/gfun/vgfs/multilingualGen.py index 231fb18..485b0ac 100644 --- a/gfun/vgfs/multilingualGen.py +++ b/gfun/vgfs/multilingualGen.py @@ -104,10 +104,6 @@ class MultilingualGen(ViewGen): pickle.dump(self, f) return self - def __str__(self): - _str = f"[Multilingual VGF (m)]\n- embed_dir: {self.embed_dir}\n- langs: {self.langs}\n- n_jobs: {self.n_jobs}\n- cached: {self.cached}\n- sif: {self.sif}\n- probabilistic: {self.probabilistic}\n" - return _str - def load_MUSEs(langs, l_vocab, dir_path, cached=False): dir_path = expanduser(dir_path) diff --git a/gfun/vgfs/textualTransformerGen.py b/gfun/vgfs/textualTransformerGen.py index 135f9cc..78eb852 100644 --- a/gfun/vgfs/textualTransformerGen.py +++ b/gfun/vgfs/textualTransformerGen.py @@ -193,7 +193,7 @@ class TextualTransformerGen(ViewGen, TransformerGen): ) trainer.train( train_dataloader=tra_dataloader, - eval_dataloader=val_dataloader, # TODO: debug setting + eval_dataloader=val_dataloader, epochs=self.epochs, ) @@ -275,10 +275,6 @@ class TextualTransformerGen(ViewGen, TransformerGen): else: return model_name - def __str__(self): - str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n" - return str - def get_config(self): c = super().get_config() return {"textual_trf": c} diff --git a/gfun/vgfs/vanillaFun.py b/gfun/vgfs/vanillaFun.py index 551e5c9..d8cb334 100644 --- a/gfun/vgfs/vanillaFun.py +++ b/gfun/vgfs/vanillaFun.py @@ -65,8 +65,3 @@ class VanillaFunGen(ViewGen): with open(_path, "wb") as f: pickle.dump(self, f) return self - - def __str__(self): - _str = f"[VanillaFunGen (-p)]\n- base learner: {self.learners}\n- n_jobs: {self.n_jobs}\n" - # - parameters: {self.first_tier_parameters} - return _str diff --git a/gfun/vgfs/visualTransformerGen.py b/gfun/vgfs/visualTransformerGen.py index eb6fdf7..fa02886 100644 --- a/gfun/vgfs/visualTransformerGen.py +++ b/gfun/vgfs/visualTransformerGen.py @@ -185,9 +185,5 @@ class VisualTransformerGen(ViewGen, TransformerGen): pickle.dump(self, f) return self - def __str__(self): - str = f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n" - return str - def get_config(self): return {"visual_trf": super().get_config()} diff --git a/gfun/vgfs/wceGen.py b/gfun/vgfs/wceGen.py index a7889df..f2cd9ee 100644 --- a/gfun/vgfs/wceGen.py +++ b/gfun/vgfs/wceGen.py @@ -40,10 +40,6 @@ class WceGen(ViewGen): "sif": self.sif, } - def __str__(self): - _str = f"[WordClass VGF (w)]\n- sif: {self.sif}\n- n_jobs: {self.n_jobs}\n" - return _str - def save_vgf(self, model_id): import pickle from os.path import join