setting gfun config when loading pre-trained model
This commit is contained in:
parent
732ffbefb1
commit
bef086ab50
|
@ -124,6 +124,16 @@ class GeneralizedFunnelling:
|
|||
epochs=self.epochs,
|
||||
attn_stacking_type=attn_stacking,
|
||||
)
|
||||
|
||||
self._model_id = get_unique_id(
|
||||
self.dataset_name,
|
||||
self.posteriors_vgf,
|
||||
self.multilingual_vgf,
|
||||
self.wce_vgf,
|
||||
self.textual_trf_vgf,
|
||||
self.visual_trf_vgf,
|
||||
self.aggfunc,
|
||||
)
|
||||
return self
|
||||
|
||||
if self.posteriors_vgf:
|
||||
|
@ -372,6 +382,7 @@ class GeneralizedFunnelling:
|
|||
"rb",
|
||||
) as vgf:
|
||||
first_tier_learners.append(pickle.load(vgf))
|
||||
print(f"- loaded trained VanillaFun VGF")
|
||||
if self.multilingual_vgf:
|
||||
with open(
|
||||
os.path.join(
|
||||
|
@ -380,6 +391,7 @@ class GeneralizedFunnelling:
|
|||
"rb",
|
||||
) as vgf:
|
||||
first_tier_learners.append(pickle.load(vgf))
|
||||
print(f"- loaded trained Multilingual VGF")
|
||||
if self.wce_vgf:
|
||||
with open(
|
||||
os.path.join(
|
||||
|
@ -388,20 +400,38 @@ class GeneralizedFunnelling:
|
|||
"rb",
|
||||
) as vgf:
|
||||
first_tier_learners.append(pickle.load(vgf))
|
||||
print(f"- loaded trained WCE VGF")
|
||||
if self.textual_trf_vgf:
|
||||
with open(
|
||||
os.path.join(
|
||||
"models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"
|
||||
"models",
|
||||
"vgfs",
|
||||
"textual_transformer",
|
||||
f"textualTransformerGen_{model_id}.pkl",
|
||||
),
|
||||
"rb",
|
||||
) as vgf:
|
||||
first_tier_learners.append(pickle.load(vgf))
|
||||
print(f"- loaded trained Textual Transformer VGF")
|
||||
if self.visual_trf_vgf:
|
||||
with open(
|
||||
os.path.join(
|
||||
"models",
|
||||
"vgfs",
|
||||
"visual_transformer",
|
||||
f"visualTransformerGen_{model_id}.pkl",
|
||||
),
|
||||
"rb",
|
||||
print(f"- loaded trained Visual Transformer VGF"),
|
||||
) as vgf:
|
||||
first_tier_learners.append(pickle.load(vgf))
|
||||
|
||||
if load_meta:
|
||||
with open(
|
||||
os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb"
|
||||
) as f:
|
||||
metaclassifier = pickle.load(f)
|
||||
print(f"- loaded trained metaclassifier")
|
||||
else:
|
||||
metaclassifier = None
|
||||
return first_tier_learners, metaclassifier, vectorizer
|
||||
|
|
Loading…
Reference in New Issue