updated get_config of vgfs + restore model fn for mt5

This commit is contained in:
Andrea Pedrotti 2023-06-12 12:11:38 +02:00
parent 770e8e62be
commit b3b7c69263
4 changed files with 10 additions and 6 deletions

View File

@ -317,7 +317,7 @@ class GeneralizedFunnelling:
for vgf in self.first_tier_learners:
vgf_config = vgf.get_config()
c.update(vgf_config)
c.update({vgf_config["name"]: vgf_config})
gfun_config = {
"id": self._model_id,

View File

@ -45,11 +45,12 @@ class MT5ForSequenceClassification(nn.Module):
def save_pretrained(self, checkpoint_dir):
torch.save(self.state_dict(), checkpoint_dir + ".pt")
return
return self
def from_pretrained(self, checkpoint_dir):
checkpoint_dir += ".pt"
return self.load_state_dict(torch.load(checkpoint_dir))
self.load_state_dict(torch.load(checkpoint_dir))
return self
class TextualTransformerGen(ViewGen, TransformerGen):
@ -183,7 +184,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
checkpoint_path=os.path.join(
"models",
"vgfs",
"transformer",
"trained_transformer",
self._format_model_name(self.model_name),
),
vgf_name="textual_trf",
@ -277,4 +278,4 @@ class TextualTransformerGen(ViewGen, TransformerGen):
def get_config(self):
c = super().get_config()
return {"textual_trf": c}
return {"name": "textual-trasnformer VGF", "textual_trf": c}

View File

@ -65,3 +65,6 @@ class VanillaFunGen(ViewGen):
with open(_path, "wb") as f:
pickle.dump(self, f)
return self
def get_config(self):
return {"name": "Vanilla Funnelling VGF"}

View File

@ -186,4 +186,4 @@ class VisualTransformerGen(ViewGen, TransformerGen):
return self
def get_config(self):
return {"visual_trf": super().get_config()}
return {"name": "visual-transformer VGF", "visual_trf": super().get_config()}