setting gfun config when loading pre-trained model

This commit is contained in:
Andrea Pedrotti 2023-06-12 15:55:38 +02:00
parent 732ffbefb1
commit bef086ab50
1 changed files with 31 additions and 1 deletions

View File

@ -124,6 +124,16 @@ class GeneralizedFunnelling:
epochs=self.epochs,
attn_stacking_type=attn_stacking,
)
self._model_id = get_unique_id(
self.dataset_name,
self.posteriors_vgf,
self.multilingual_vgf,
self.wce_vgf,
self.textual_trf_vgf,
self.visual_trf_vgf,
self.aggfunc,
)
return self
if self.posteriors_vgf:
@ -372,6 +382,7 @@ class GeneralizedFunnelling:
"rb",
) as vgf:
first_tier_learners.append(pickle.load(vgf))
print(f"- loaded trained VanillaFun VGF")
if self.multilingual_vgf:
with open(
os.path.join(
@ -380,6 +391,7 @@ class GeneralizedFunnelling:
"rb",
) as vgf:
first_tier_learners.append(pickle.load(vgf))
print(f"- loaded trained Multilingual VGF")
if self.wce_vgf:
with open(
os.path.join(
@ -388,20 +400,38 @@ class GeneralizedFunnelling:
"rb",
) as vgf:
first_tier_learners.append(pickle.load(vgf))
print(f"- loaded trained WCE VGF")
if self.textual_trf_vgf:
with open(
os.path.join(
"models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"
"models",
"vgfs",
"textual_transformer",
f"textualTransformerGen_{model_id}.pkl",
),
"rb",
) as vgf:
first_tier_learners.append(pickle.load(vgf))
print(f"- loaded trained Textual Transformer VGF")
if self.visual_trf_vgf:
with open(
os.path.join(
"models",
"vgfs",
"visual_transformer",
f"visualTransformerGen_{model_id}.pkl",
),
"rb",
print(f"- loaded trained Visual Transformer VGF"),
) as vgf:
first_tier_learners.append(pickle.load(vgf))
if load_meta:
with open(
os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb"
) as f:
metaclassifier = pickle.load(f)
print(f"- loaded trained metaclassifier")
else:
metaclassifier = None
return first_tier_learners, metaclassifier, vectorizer