diff --git a/gfun/generalizedFunnelling.py b/gfun/generalizedFunnelling.py index c4746cd..3d668e7 100644 --- a/gfun/generalizedFunnelling.py +++ b/gfun/generalizedFunnelling.py @@ -317,7 +317,7 @@ class GeneralizedFunnelling: for vgf in self.first_tier_learners: vgf_config = vgf.get_config() - c.update(vgf_config) + c.update({vgf_config["name"]: vgf_config}) gfun_config = { "id": self._model_id, diff --git a/gfun/vgfs/textualTransformerGen.py b/gfun/vgfs/textualTransformerGen.py index 78eb852..5ed9d16 100644 --- a/gfun/vgfs/textualTransformerGen.py +++ b/gfun/vgfs/textualTransformerGen.py @@ -45,11 +45,12 @@ class MT5ForSequenceClassification(nn.Module): def save_pretrained(self, checkpoint_dir): torch.save(self.state_dict(), checkpoint_dir + ".pt") - return + return self def from_pretrained(self, checkpoint_dir): checkpoint_dir += ".pt" - return self.load_state_dict(torch.load(checkpoint_dir)) + self.load_state_dict(torch.load(checkpoint_dir)) + return self class TextualTransformerGen(ViewGen, TransformerGen): @@ -183,7 +184,7 @@ class TextualTransformerGen(ViewGen, TransformerGen): checkpoint_path=os.path.join( "models", "vgfs", - "transformer", + "trained_transformer", self._format_model_name(self.model_name), ), vgf_name="textual_trf", @@ -277,4 +278,4 @@ class TextualTransformerGen(ViewGen, TransformerGen): def get_config(self): c = super().get_config() - return {"textual_trf": c} + return {"name": "textual-trasnformer VGF", "textual_trf": c} diff --git a/gfun/vgfs/vanillaFun.py b/gfun/vgfs/vanillaFun.py index d8cb334..f4f25b4 100644 --- a/gfun/vgfs/vanillaFun.py +++ b/gfun/vgfs/vanillaFun.py @@ -65,3 +65,6 @@ class VanillaFunGen(ViewGen): with open(_path, "wb") as f: pickle.dump(self, f) return self + + def get_config(self): + return {"name": "Vanilla Funnelling VGF"} diff --git a/gfun/vgfs/visualTransformerGen.py b/gfun/vgfs/visualTransformerGen.py index fa02886..84514e4 100644 --- a/gfun/vgfs/visualTransformerGen.py +++ b/gfun/vgfs/visualTransformerGen.py @@ -186,4 +186,4 @@ class VisualTransformerGen(ViewGen, TransformerGen): return self def get_config(self): - return {"visual_trf": super().get_config()} + return {"name": "visual-transformer VGF", "visual_trf": super().get_config()}