From bef086ab50ffba0d02fec1b233b0441913e77b54 Mon Sep 17 00:00:00 2001 From: andreapdr Date: Mon, 12 Jun 2023 15:55:38 +0200 Subject: [PATCH] setting gfun config when loading pre-trained model --- gfun/generalizedFunnelling.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/gfun/generalizedFunnelling.py b/gfun/generalizedFunnelling.py index 3d668e7..d1addeb 100644 --- a/gfun/generalizedFunnelling.py +++ b/gfun/generalizedFunnelling.py @@ -124,6 +124,16 @@ class GeneralizedFunnelling: epochs=self.epochs, attn_stacking_type=attn_stacking, ) + + self._model_id = get_unique_id( + self.dataset_name, + self.posteriors_vgf, + self.multilingual_vgf, + self.wce_vgf, + self.textual_trf_vgf, + self.visual_trf_vgf, + self.aggfunc, + ) return self if self.posteriors_vgf: @@ -372,6 +382,7 @@ class GeneralizedFunnelling: "rb", ) as vgf: first_tier_learners.append(pickle.load(vgf)) + print(f"- loaded trained VanillaFun VGF") if self.multilingual_vgf: with open( os.path.join( @@ -380,6 +391,7 @@ class GeneralizedFunnelling: "rb", ) as vgf: first_tier_learners.append(pickle.load(vgf)) + print(f"- loaded trained Multilingual VGF") if self.wce_vgf: with open( os.path.join( @@ -388,20 +400,38 @@ class GeneralizedFunnelling: "rb", ) as vgf: first_tier_learners.append(pickle.load(vgf)) + print(f"- loaded trained WCE VGF") if self.textual_trf_vgf: with open( os.path.join( - "models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl" + "models", + "vgfs", + "textual_transformer", + f"textualTransformerGen_{model_id}.pkl", ), "rb", ) as vgf: first_tier_learners.append(pickle.load(vgf)) + print(f"- loaded trained Textual Transformer VGF") + if self.visual_trf_vgf: + with open( + os.path.join( + "models", + "vgfs", + "visual_transformer", + f"visualTransformerGen_{model_id}.pkl", + ), + "rb", + print(f"- loaded trained Visual Transformer VGF"), + ) as vgf: + first_tier_learners.append(pickle.load(vgf)) if load_meta: with open( os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb" ) as f: metaclassifier = pickle.load(f) + print(f"- loaded trained metaclassifier") else: metaclassifier = None return first_tier_learners, metaclassifier, vectorizer