getter for gFun and VGFs config

commit 17d0003e48
parent 9d43ebb23b
@@ -309,15 +309,21 @@ class GeneralizedFunnelling:
         return aggregated
 
     def get_config(self):
-        print("\n")
-        print("-" * 50)
-        print("[GeneralizedFunnelling config]")
-        print(f"- model trained on langs: {self.langs}")
-        print("-- View Generating Functions configurations:\n")
+        c = {}
 
         for vgf in self.first_tier_learners:
-            print(vgf)
-            print("-" * 50)
+            vgf_config = vgf.get_config()
+            c.update(vgf_config)
+
+        gfun_config = {
+            "id": self._model_id,
+            "aggfunc": self.aggfunc,
+            "optimc": self.optimc,
+            "dataset": self.dataset_name,
+        }
+
+        c["gFun"] = gfun_config
+        return c
 
     def save(self, save_first_tier=True, save_meta=True):
         print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}")
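Note: the merged dictionary returned by `get_config()` now has one namespaced entry per first-tier VGF plus a top-level `"gFun"` entry. A sketch of its shape, with placeholder values rather than output from a real run:

```python
# Hypothetical shape of gfun.get_config(); the keys follow the diff,
# the values are made-up placeholders.
example_config = {
    "textual_trf": {"model_name": "mbert", "epochs": 50, "lr": 1e-5},
    "visual_trf": {"model_name": "vit", "epochs": 50, "lr": 1e-5},
    "gFun": {
        "id": "a1b2c3",      # self._model_id (placeholder)
        "aggfunc": "mean",   # self.aggfunc (placeholder)
        "optimc": True,      # self.optimc (placeholder)
        "dataset": "cls",    # self.dataset_name (placeholder)
    },
}
```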
@@ -277,3 +277,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
     def __str__(self):
         str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
         return str
+
+    def get_config(self):
+        c = super().get_config()
+        return {"textual_trf": c}
@@ -94,3 +94,21 @@ class TransformerGen:
             val_lY[lang] = val_Y
 
         return tr_lX, tr_lY, val_lX, val_lY
+
+    def get_config(self):
+        return {
+            "model_name": self.model_name,
+            "dataset_name": self.dataset_name,
+            "epochs": self.epochs,
+            "lr": self.lr,
+            "batch_size": self.batch_size,
+            "batch_size_eval": self.batch_size_eval,
+            "max_length": self.max_length,
+            "print_steps": self.print_steps,
+            "device": self.device,
+            "probabilistic": self.probabilistic,
+            "n_jobs": self.n_jobs,
+            "evaluate_step": self.evaluate_step,
+            "verbose": self.verbose,
+            "patience": self.patience,
+        }
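One subtlety in the subclass wrappers above and below: `TextualTransformerGen` and `VisualTransformerGen` inherit from `(ViewGen, TransformerGen)`, so `super().get_config()` is resolved through the MRO; it reaches `TransformerGen.get_config()` only because, presumably, `ViewGen` does not define a `get_config` of its own. A self-contained sketch of the pattern with stand-in class names:

```python
# Stand-in classes illustrating how super().get_config() resolves in a
# (ViewGen, TransformerGen)-style multiple-inheritance setup; only the
# get_config logic mirrors the diff, everything else is hypothetical.
class ViewGenLike:
    """Abstract-ish base with no get_config of its own."""

class TransformerGenLike:
    def get_config(self):
        return {"model_name": "mbert", "lr": 1e-5}  # placeholder values

class TextualLike(ViewGenLike, TransformerGenLike):
    def get_config(self):
        # super() walks the MRO (ViewGenLike first, then TransformerGenLike)
        return {"textual_trf": super().get_config()}

assert TextualLike().get_config() == {"textual_trf": {"model_name": "mbert", "lr": 1e-5}}
```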
@@ -186,3 +186,6 @@ class VisualTransformerGen(ViewGen, TransformerGen):
     def __str__(self):
         str = f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
         return str
+
+    def get_config(self):
+        return {"visual_trf": super().get_config()}
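The per-VGF namespacing is what keeps `GeneralizedFunnelling.get_config()` collision-free: the VGF dicts are merged with `dict.update()`, and both transformer VGFs expose the same `TransformerGen` keys. A hypothetical illustration of what would go wrong without the wrappers:

```python
# Both VGFs share keys such as model_name and lr (placeholder values here),
# so merging them flat would silently drop one config.
textual = {"model_name": "mbert", "lr": 1e-5}
visual = {"model_name": "vit", "lr": 1e-4}

flat = {}
flat.update(textual)
flat.update(visual)
print(flat)  # {'model_name': 'vit', 'lr': 0.0001} -- textual config lost

namespaced = {}
namespaced.update({"textual_trf": textual})
namespaced.update({"visual_trf": visual})
print(namespaced)  # both configs survive under their own keys
```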
main.py (10 changed lines)
@@ -13,13 +13,10 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
 """
 TODO:
     - Transformers VGFs:
-        - save/load for MT5ForSqeuenceClassification
         - freeze params method
-        - log on step rather than epoch?
     - General:
         [!] zero-shot setup
         - CLS dataset is loading only "books" domain data
-        - log on wandb also the other VGF results + final results
         - documents should be trimmed to the same length (for SVMs we are using way too long tokens)
     - Attention Aggregator:
         - experiment with weight init of Attention-aggregator
@@ -106,9 +103,10 @@ def main(args):
         n_jobs=args.n_jobs,
     )
 
-    wandb.init(
-        project="gfun", name=f"gFun-{get_config_name(args)}"
-    )  # TODO: Add config to log
+    config = gfun.get_config()
+
+    wandb.init(project="gfun", name=f"gFun-{get_config_name(args)}", config=config)
 
     gfun.fit(lX, lY)
 
     if args.load_trained is None and not args.nosave:
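With the config passed to `wandb.init`, the hyperparameters show up in the run's config tab and can be read back from the run object. A minimal sketch (project and run names are placeholders):

```python
import wandb

# wandb.init accepts a nested dict as `config`; values are then
# available on run.config and in the W&B UI.
config = {"gFun": {"aggfunc": "mean"}, "textual_trf": {"lr": 1e-5}}
run = wandb.init(project="gfun", name="gFun-example", config=config)
print(run.config["gFun"])  # nested sections are readable back
run.finish()
```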