todo updates

This commit is contained in:
Andrea Pedrotti 2023-03-17 10:44:45 +01:00
parent 41647f974a
commit ab7a310b34
7 changed files with 2 additions and 24 deletions

View File

@ -118,7 +118,6 @@ class gFunDataset:
if self.data_langs is None:
data_langs = sorted(train_split.geo.unique().tolist())
# TODO: if data langs is NOT none then we have a problem where we filter df by langs
if self.labels is None:
labels = train_split.category_name.unique().tolist()

View File

@ -82,7 +82,7 @@ class GeneralizedFunnelling:
self.aggfunc = aggfunc
self.load_trained = load_trained
self.load_first_tier = (
True # TODO: i guess we're always going to load at least the fitst tier
True # TODO: i guess we're always going to load at least the first tier
)
self.load_meta = load_meta
self.dataset_name = dataset_name

View File

@ -104,10 +104,6 @@ class MultilingualGen(ViewGen):
pickle.dump(self, f)
return self
def __str__(self):
_str = f"[Multilingual VGF (m)]\n- embed_dir: {self.embed_dir}\n- langs: {self.langs}\n- n_jobs: {self.n_jobs}\n- cached: {self.cached}\n- sif: {self.sif}\n- probabilistic: {self.probabilistic}\n"
return _str
def load_MUSEs(langs, l_vocab, dir_path, cached=False):
dir_path = expanduser(dir_path)

View File

@ -193,7 +193,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
)
trainer.train(
train_dataloader=tra_dataloader,
eval_dataloader=val_dataloader, # TODO: debug setting
eval_dataloader=val_dataloader,
epochs=self.epochs,
)
@ -275,10 +275,6 @@ class TextualTransformerGen(ViewGen, TransformerGen):
else:
return model_name
def __str__(self):
str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
return str
def get_config(self):
c = super().get_config()
return {"textual_trf": c}

View File

@ -65,8 +65,3 @@ class VanillaFunGen(ViewGen):
with open(_path, "wb") as f:
pickle.dump(self, f)
return self
def __str__(self):
_str = f"[VanillaFunGen (-p)]\n- base learner: {self.learners}\n- n_jobs: {self.n_jobs}\n"
# - parameters: {self.first_tier_parameters}
return _str

View File

@ -185,9 +185,5 @@ class VisualTransformerGen(ViewGen, TransformerGen):
pickle.dump(self, f)
return self
def __str__(self):
str = f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
return str
def get_config(self):
return {"visual_trf": super().get_config()}

View File

@ -40,10 +40,6 @@ class WceGen(ViewGen):
"sif": self.sif,
}
def __str__(self):
_str = f"[WordClass VGF (w)]\n- sif: {self.sif}\n- n_jobs: {self.n_jobs}\n"
return _str
def save_vgf(self, model_id):
import pickle
from os.path import join