updated get_config of vgfs + restore model fn for mt5
This commit is contained in:
parent
770e8e62be
commit
b3b7c69263
|
|
@ -317,7 +317,7 @@ class GeneralizedFunnelling:
|
|||
|
||||
for vgf in self.first_tier_learners:
|
||||
vgf_config = vgf.get_config()
|
||||
c.update(vgf_config)
|
||||
c.update({vgf_config["name"]: vgf_config})
|
||||
|
||||
gfun_config = {
|
||||
"id": self._model_id,
|
||||
|
|
|
|||
|
|
@ -45,11 +45,12 @@ class MT5ForSequenceClassification(nn.Module):
|
|||
|
||||
def save_pretrained(self, checkpoint_dir):
|
||||
torch.save(self.state_dict(), checkpoint_dir + ".pt")
|
||||
return
|
||||
return self
|
||||
|
||||
def from_pretrained(self, checkpoint_dir):
|
||||
checkpoint_dir += ".pt"
|
||||
return self.load_state_dict(torch.load(checkpoint_dir))
|
||||
self.load_state_dict(torch.load(checkpoint_dir))
|
||||
return self
|
||||
|
||||
|
||||
class TextualTransformerGen(ViewGen, TransformerGen):
|
||||
|
|
@ -183,7 +184,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
checkpoint_path=os.path.join(
|
||||
"models",
|
||||
"vgfs",
|
||||
"transformer",
|
||||
"trained_transformer",
|
||||
self._format_model_name(self.model_name),
|
||||
),
|
||||
vgf_name="textual_trf",
|
||||
|
|
@ -277,4 +278,4 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
|
||||
def get_config(self):
|
||||
c = super().get_config()
|
||||
return {"textual_trf": c}
|
||||
return {"name": "textual-trasnformer VGF", "textual_trf": c}
|
||||
|
|
|
|||
|
|
@ -65,3 +65,6 @@ class VanillaFunGen(ViewGen):
|
|||
with open(_path, "wb") as f:
|
||||
pickle.dump(self, f)
|
||||
return self
|
||||
|
||||
def get_config(self):
|
||||
return {"name": "Vanilla Funnelling VGF"}
|
||||
|
|
|
|||
|
|
@ -186,4 +186,4 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
|||
return self
|
||||
|
||||
def get_config(self):
|
||||
return {"visual_trf": super().get_config()}
|
||||
return {"name": "visual-transformer VGF", "visual_trf": super().get_config()}
|
||||
|
|
|
|||
Loading…
Reference in New Issue