getter for gFun and VGFs config

Author: Andrea Pedrotti 2023-03-16 11:41:40 +01:00
parent 9d43ebb23b
commit 17d0003e48
5 changed files with 42 additions and 13 deletions


@@ -309,15 +309,21 @@ class GeneralizedFunnelling:
         return aggregated
 
     def get_config(self):
-        print("\n")
-        print("-" * 50)
-        print("[GeneralizedFunnelling config]")
-        print(f"- model trained on langs: {self.langs}")
-        print("-- View Generating Functions configurations:\n")
+        c = {}
         for vgf in self.first_tier_learners:
-            print(vgf)
-            print("-" * 50)
+            vgf_config = vgf.get_config()
+            c.update(vgf_config)
+        gfun_config = {
+            "id": self._model_id,
+            "aggfunc": self.aggfunc,
+            "optimc": self.optimc,
+            "dataset": self.dataset_name,
+        }
+        c["gFun"] = gfun_config
+        return c
 
     def save(self, save_first_tier=True, save_meta=True):
         print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}")

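Read together with the VGF hunks below, the dict returned by GeneralizedFunnelling.get_config() should come out roughly as in the following sketch when one textual and one visual transformer VGF are among the first-tier learners. The keys follow this commit; every value here is a made-up placeholder, not taken from the repository.

# Illustrative only: approximate shape of gfun.get_config(); all values are invented.
config = {
    "textual_trf": {  # nested dict produced by TextualTransformerGen.get_config()
        "model_name": "bert-base-multilingual-cased",
        "epochs": 50,
        "lr": 1e-5,
        # ... the remaining TransformerGen fields (batch sizes, device, patience, ...)
    },
    "visual_trf": {  # same TransformerGen fields, reported by the visual VGF
        # ...
    },
    "gFun": {
        "id": "some-model-id",
        "aggfunc": "mean",
        "optimc": True,
        "dataset": "some-dataset",
    },
}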

@@ -277,3 +277,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
     def __str__(self):
         str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
         return str
+
+    def get_config(self):
+        c = super().get_config()
+        return {"textual_trf": c}


@@ -94,3 +94,21 @@ class TransformerGen:
             val_lY[lang] = val_Y
         return tr_lX, tr_lY, val_lX, val_lY
+
+    def get_config(self):
+        return {
+            "model_name": self.model_name,
+            "dataset_name": self.dataset_name,
+            "epochs": self.epochs,
+            "lr": self.lr,
+            "batch_size": self.batch_size,
+            "batch_size_eval": self.batch_size_eval,
+            "max_length": self.max_length,
+            "print_steps": self.print_steps,
+            "device": self.device,
+            "probabilistic": self.probabilistic,
+            "n_jobs": self.n_jobs,
+            "evaluate_step": self.evaluate_step,
+            "verbose": self.verbose,
+            "patience": self.patience,
+        }


@@ -186,3 +186,6 @@ class VisualTransformerGen(ViewGen, TransformerGen):
     def __str__(self):
         str = f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
         return str
+
+    def get_config(self):
+        return {"visual_trf": super().get_config()}

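A note on the pattern above: both transformer VGFs reuse the shared TransformerGen.get_config() and namespace it under their own key ("textual_trf" / "visual_trf"), so the c.update(vgf_config) loop in GeneralizedFunnelling.get_config() cannot have one view overwrite the other even though the underlying field names are identical. A minimal self-contained sketch of that pattern, with simplified stand-in classes (the real ones take many more constructor arguments):

# Minimal sketch of the namespacing pattern; BaseGen/TextualGen/VisualGen are
# simplified stand-ins for TransformerGen and its textual/visual subclasses.
class BaseGen:
    def __init__(self, lr):
        self.lr = lr

    def get_config(self):
        return {"lr": self.lr}


class TextualGen(BaseGen):
    def get_config(self):
        return {"textual_trf": super().get_config()}


class VisualGen(BaseGen):
    def get_config(self):
        return {"visual_trf": super().get_config()}


c = {}
for vgf in (TextualGen(lr=1e-5), VisualGen(lr=1e-4)):
    c.update(vgf.get_config())  # no key collision: each VGF lives under its own key
print(c)  # {'textual_trf': {'lr': 1e-05}, 'visual_trf': {'lr': 0.0001}}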
main.py

@@ -13,13 +13,10 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
 """
 TODO:
     - Transformers VGFs:
-        - save/load for MT5ForSqeuenceClassification
         - freeze params method
-        - log on step rather than epoch?
     - General:
         [!] zero-shot setup
         - CLS dataset is loading only "books" domain data
-        - log on wandb also the other VGF results + final results
         - documents should be trimmed to the same length (for SVMs we are using way too long tokens)
     - Attention Aggregator:
         - experiment with weight init of Attention-aggregator
@@ -106,9 +103,10 @@ def main(args):
         n_jobs=args.n_jobs,
     )
 
-    wandb.init(
-        project="gfun", name=f"gFun-{get_config_name(args)}"
-    )  # TODO: Add config to log
+    config = gfun.get_config()
+
+    wandb.init(project="gfun", name=f"gFun-{get_config_name(args)}", config=config)
 
     gfun.fit(lX, lY)
 
     if args.load_trained is None and not args.nosave:
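For context on the main.py change: passing config= to wandb.init records the dictionary in the run's configuration, so the gFun and VGF hyperparameters become filterable in the dashboard and readable from the run object, which is what the removed "TODO: Add config to log" was asking for. A small standalone sketch, assuming wandb's standard behavior; example_config stands in for gfun.get_config(), and offline mode avoids needing a login:

# Standalone sketch, not part of the commit: a nested dict passed as config=
# is stored on the run and can be read back; values are illustrative.
import wandb

example_config = {"gFun": {"aggfunc": "mean", "optimc": True, "dataset": "example"}}
run = wandb.init(project="gfun", name="gFun-example", config=example_config, mode="offline")
print(run.config["gFun"])  # the nested dict is available on run.config
run.finish()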