setting gfun config when loading pre-trained model
This commit is contained in:
parent
732ffbefb1
commit
bef086ab50
|
|
@ -124,6 +124,16 @@ class GeneralizedFunnelling:
|
||||||
epochs=self.epochs,
|
epochs=self.epochs,
|
||||||
attn_stacking_type=attn_stacking,
|
attn_stacking_type=attn_stacking,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._model_id = get_unique_id(
|
||||||
|
self.dataset_name,
|
||||||
|
self.posteriors_vgf,
|
||||||
|
self.multilingual_vgf,
|
||||||
|
self.wce_vgf,
|
||||||
|
self.textual_trf_vgf,
|
||||||
|
self.visual_trf_vgf,
|
||||||
|
self.aggfunc,
|
||||||
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
if self.posteriors_vgf:
|
if self.posteriors_vgf:
|
||||||
|
|
@ -372,6 +382,7 @@ class GeneralizedFunnelling:
|
||||||
"rb",
|
"rb",
|
||||||
) as vgf:
|
) as vgf:
|
||||||
first_tier_learners.append(pickle.load(vgf))
|
first_tier_learners.append(pickle.load(vgf))
|
||||||
|
print(f"- loaded trained VanillaFun VGF")
|
||||||
if self.multilingual_vgf:
|
if self.multilingual_vgf:
|
||||||
with open(
|
with open(
|
||||||
os.path.join(
|
os.path.join(
|
||||||
|
|
@ -380,6 +391,7 @@ class GeneralizedFunnelling:
|
||||||
"rb",
|
"rb",
|
||||||
) as vgf:
|
) as vgf:
|
||||||
first_tier_learners.append(pickle.load(vgf))
|
first_tier_learners.append(pickle.load(vgf))
|
||||||
|
print(f"- loaded trained Multilingual VGF")
|
||||||
if self.wce_vgf:
|
if self.wce_vgf:
|
||||||
with open(
|
with open(
|
||||||
os.path.join(
|
os.path.join(
|
||||||
|
|
@ -388,20 +400,38 @@ class GeneralizedFunnelling:
|
||||||
"rb",
|
"rb",
|
||||||
) as vgf:
|
) as vgf:
|
||||||
first_tier_learners.append(pickle.load(vgf))
|
first_tier_learners.append(pickle.load(vgf))
|
||||||
|
print(f"- loaded trained WCE VGF")
|
||||||
if self.textual_trf_vgf:
|
if self.textual_trf_vgf:
|
||||||
with open(
|
with open(
|
||||||
os.path.join(
|
os.path.join(
|
||||||
"models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"
|
"models",
|
||||||
|
"vgfs",
|
||||||
|
"textual_transformer",
|
||||||
|
f"textualTransformerGen_{model_id}.pkl",
|
||||||
),
|
),
|
||||||
"rb",
|
"rb",
|
||||||
) as vgf:
|
) as vgf:
|
||||||
first_tier_learners.append(pickle.load(vgf))
|
first_tier_learners.append(pickle.load(vgf))
|
||||||
|
print(f"- loaded trained Textual Transformer VGF")
|
||||||
|
if self.visual_trf_vgf:
|
||||||
|
with open(
|
||||||
|
os.path.join(
|
||||||
|
"models",
|
||||||
|
"vgfs",
|
||||||
|
"visual_transformer",
|
||||||
|
f"visualTransformerGen_{model_id}.pkl",
|
||||||
|
),
|
||||||
|
"rb",
|
||||||
|
print(f"- loaded trained Visual Transformer VGF"),
|
||||||
|
) as vgf:
|
||||||
|
first_tier_learners.append(pickle.load(vgf))
|
||||||
|
|
||||||
if load_meta:
|
if load_meta:
|
||||||
with open(
|
with open(
|
||||||
os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb"
|
os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb"
|
||||||
) as f:
|
) as f:
|
||||||
metaclassifier = pickle.load(f)
|
metaclassifier = pickle.load(f)
|
||||||
|
print(f"- loaded trained metaclassifier")
|
||||||
else:
|
else:
|
||||||
metaclassifier = None
|
metaclassifier = None
|
||||||
return first_tier_learners, metaclassifier, vectorizer
|
return first_tier_learners, metaclassifier, vectorizer
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue