todo updates
parent 41647f974a
commit ab7a310b34
@@ -118,7 +118,6 @@ class gFunDataset:
         if self.data_langs is None:
             data_langs = sorted(train_split.geo.unique().tolist())
-        # TODO: if data langs is NOT none then we have a problem where we filter df by langs
         if self.labels is None:
             labels = train_split.category_name.unique().tolist()
 
 
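The TODO removed in this hunk flags an unhandled case: when `data_langs` is passed explicitly, the split is not filtered down to those languages. The sketch below is a minimal illustration of what such filtering could look like, assuming `train_split` is a pandas DataFrame with a `geo` column as in the hunk; the helper name and signature are hypothetical and not part of this commit.

import pandas as pd

def select_langs(train_split: pd.DataFrame, data_langs=None):
    """Hypothetical helper: choose the languages for a split.

    If no languages are given, infer them from the data (as the hunk does);
    otherwise keep only the rows whose `geo` value is in the requested list,
    which is the case the removed TODO worries about.
    """
    if data_langs is None:
        return sorted(train_split.geo.unique().tolist()), train_split
    filtered = train_split[train_split.geo.isin(data_langs)]
    return sorted(data_langs), filtered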
@@ -82,7 +82,7 @@ class GeneralizedFunnelling:
         self.aggfunc = aggfunc
         self.load_trained = load_trained
         self.load_first_tier = (
-            True  # TODO: i guess we're always going to load at least the fitst tier
+            True  # TODO: i guess we're always going to load at least the first tier
         )
         self.load_meta = load_meta
         self.dataset_name = dataset_name
@@ -104,10 +104,6 @@ class MultilingualGen(ViewGen):
             pickle.dump(self, f)
         return self
 
-    def __str__(self):
-        _str = f"[Multilingual VGF (m)]\n- embed_dir: {self.embed_dir}\n- langs: {self.langs}\n- n_jobs: {self.n_jobs}\n- cached: {self.cached}\n- sif: {self.sif}\n- probabilistic: {self.probabilistic}\n"
-        return _str
-
 
 def load_MUSEs(langs, l_vocab, dir_path, cached=False):
     dir_path = expanduser(dir_path)
@@ -193,7 +193,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         )
         trainer.train(
             train_dataloader=tra_dataloader,
-            eval_dataloader=val_dataloader,  # TODO: debug setting
+            eval_dataloader=val_dataloader,
             epochs=self.epochs,
         )
 
@@ -275,10 +275,6 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         else:
             return model_name
 
-    def __str__(self):
-        str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
-        return str
-
     def get_config(self):
         c = super().get_config()
         return {"textual_trf": c}
@@ -65,8 +65,3 @@ class VanillaFunGen(ViewGen):
         with open(_path, "wb") as f:
             pickle.dump(self, f)
         return self
-
-    def __str__(self):
-        _str = f"[VanillaFunGen (-p)]\n- base learner: {self.learners}\n- n_jobs: {self.n_jobs}\n"
-        # - parameters: {self.first_tier_parameters}
-        return _str
@@ -185,9 +185,5 @@ class VisualTransformerGen(ViewGen, TransformerGen):
             pickle.dump(self, f)
         return self
 
-    def __str__(self):
-        str = f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
-        return str
-
     def get_config(self):
         return {"visual_trf": super().get_config()}
@@ -40,10 +40,6 @@ class WceGen(ViewGen):
             "sif": self.sif,
         }
 
-    def __str__(self):
-        _str = f"[WordClass VGF (w)]\n- sif: {self.sif}\n- n_jobs: {self.n_jobs}\n"
-        return _str
-
     def save_vgf(self, model_id):
         import pickle
         from os.path import join
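The `save_vgf` method whose signature closes the last hunk appears to follow the same pickle-to-disk pattern visible as context in the other hunks (`with open(_path, "wb") as f: pickle.dump(self, f); return self`). A self-contained sketch of that pattern is shown below; the class name, output directory, and file naming are hypothetical illustrations, not the repository's actual code.

import os
import pickle
from os.path import join


class WceGenSketch:
    """Illustrative stand-in mirroring the save pattern seen in the diff context."""

    def save_vgf(self, model_id, out_dir="models/vgfs"):
        # hypothetical path layout; the real repository layout is not shown in this diff
        os.makedirs(out_dir, exist_ok=True)
        _path = join(out_dir, f"wceGen_{model_id}.pkl")
        with open(_path, "wb") as f:
            pickle.dump(self, f)
        return self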