diff --git a/gfun/generalizedFunnelling.py b/gfun/generalizedFunnelling.py
index a969596..71c9054 100644
--- a/gfun/generalizedFunnelling.py
+++ b/gfun/generalizedFunnelling.py
@@ -309,15 +309,21 @@ class GeneralizedFunnelling:
         return aggregated
 
     def get_config(self):
-        print("\n")
-        print("-" * 50)
-        print("[GeneralizedFunnelling config]")
-        print(f"- model trained on langs: {self.langs}")
-        print("-- View Generating Functions configurations:\n")
+        c = {}
 
         for vgf in self.first_tier_learners:
-            print(vgf)
-        print("-" * 50)
+            vgf_config = vgf.get_config()
+            c.update(vgf_config)
+
+        gfun_config = {
+            "id": self._model_id,
+            "aggfunc": self.aggfunc,
+            "optimc": self.optimc,
+            "dataset": self.dataset_name,
+        }
+
+        c["gFun"] = gfun_config
+        return c
 
     def save(self, save_first_tier=True, save_meta=True):
         print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}")
diff --git a/gfun/vgfs/textualTransformerGen.py b/gfun/vgfs/textualTransformerGen.py
index 5bfb5c1..16b70ed 100644
--- a/gfun/vgfs/textualTransformerGen.py
+++ b/gfun/vgfs/textualTransformerGen.py
@@ -277,3 +277,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
     def __str__(self):
         str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
         return str
+
+    def get_config(self):
+        c = super().get_config()
+        return {"textual_trf": c}
diff --git a/gfun/vgfs/transformerGen.py b/gfun/vgfs/transformerGen.py
index f7f4e41..d3d09a3 100644
--- a/gfun/vgfs/transformerGen.py
+++ b/gfun/vgfs/transformerGen.py
@@ -94,3 +94,21 @@ class TransformerGen:
             val_lY[lang] = val_Y
 
         return tr_lX, tr_lY, val_lX, val_lY
+
+    def get_config(self):
+        return {
+            "model_name": self.model_name,
+            "dataset_name": self.dataset_name,
+            "epochs": self.epochs,
+            "lr": self.lr,
+            "batch_size": self.batch_size,
+            "batch_size_eval": self.batch_size_eval,
+            "max_length": self.max_length,
+            "print_steps": self.print_steps,
+            "device": self.device,
+            "probabilistic": self.probabilistic,
+            "n_jobs": self.n_jobs,
+            "evaluate_step": self.evaluate_step,
+            "verbose": self.verbose,
+            "patience": self.patience,
+        }
\ No newline at end of file
diff --git a/gfun/vgfs/visualTransformerGen.py b/gfun/vgfs/visualTransformerGen.py
index ae8b914..c3994e8 100644
--- a/gfun/vgfs/visualTransformerGen.py
+++ b/gfun/vgfs/visualTransformerGen.py
@@ -186,3 +186,6 @@ class VisualTransformerGen(ViewGen, TransformerGen):
     def __str__(self):
         str = f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
         return str
+
+    def get_config(self):
+        return {"visual_trf": super().get_config()}
diff --git a/main.py b/main.py
index f272490..ea89bd0 100644
--- a/main.py
+++ b/main.py
@@ -13,13 +13,10 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
 """
 TODO:
     - Transformers VGFs:
-        - save/load for MT5ForSqeuenceClassification
         - freeze params method
-        - log on step rather than epoch?
     - General:
         [!] zero-shot setup
        - CLS dataset is loading only "books" domain data
-        - log on wandb also the other VGF results + final results
        - documents should be trimmed to the same length (for SVMs we are using way too long tokens)
    - Attention Aggregator:
        - experiment with weight init of Attention-aggregator
@@ -106,9 +103,10 @@ def main(args):
         n_jobs=args.n_jobs,
     )
 
-    wandb.init(
-        project="gfun", name=f"gFun-{get_config_name(args)}"
-    )  # TODO: Add config to log
+    config = gfun.get_config()
+
+    wandb.init(project="gfun", name=f"gFun-{get_config_name(args)}", config=config)
+
    gfun.fit(lX, lY)
 
     if args.load_trained is None and not args.nosave:
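
For reference (not part of the patch): a minimal sketch of the nested dictionary that GeneralizedFunnelling.get_config() now assembles and that main.py passes to wandb.init. Only the key names come from the diff above; the values are illustrative placeholders, not actual defaults from the codebase.

from pprint import pprint

# Sketch of the dict returned by gfun.get_config(); values below are placeholders,
# only the key names are taken from the patch.
config = {
    "textual_trf": {  # TextualTransformerGen.get_config() -> TransformerGen fields
        "model_name": "mbert",  # placeholder
        "dataset_name": "cls",  # placeholder
        "epochs": 100,
        "lr": 1e-5,
        "batch_size": 32,
        # ... remaining TransformerGen keys: batch_size_eval, max_length, print_steps,
        # device, probabilistic, n_jobs, evaluate_step, verbose, patience
    },
    "visual_trf": {  # VisualTransformerGen.get_config(), same field names
        "model_name": "vit",  # placeholder
    },
    "gFun": {  # gfun-level settings added in GeneralizedFunnelling.get_config()
        "id": "some-model-id",  # placeholder
        "aggfunc": "mean",      # placeholder
        "optimc": True,         # placeholder
        "dataset": "cls",       # placeholder
    },
}

pprint(config)  # this is the object handed to wandb.init(..., config=config)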