diff --git a/gfun/generalizedFunnelling.py b/gfun/generalizedFunnelling.py index 88009b4..5044583 100644 --- a/gfun/generalizedFunnelling.py +++ b/gfun/generalizedFunnelling.py @@ -47,8 +47,8 @@ class GeneralizedFunnelling: self.posteriors_vgf = posterior self.wce_vgf = wce self.multilingual_vgf = multilingual - self.textual_trasformer_vgf = textual_transformer - self.visual_transformer_vgf = visual_transformer + self.textual_trf_vgf = textual_transformer + self.visual_trf_vgf = visual_transformer self.probabilistic = probabilistic self.num_labels = num_labels # ------------------------ @@ -56,7 +56,7 @@ class GeneralizedFunnelling: self.embed_dir = embed_dir self.cached = True # Textual Transformer VGF params ---------- - self.textaul_transformer_name = textual_transformer_name + self.textual_trf_name = textual_transformer_name self.epochs = epochs self.lr_transformer = lr self.batch_size_transformer = batch_size @@ -66,7 +66,7 @@ class GeneralizedFunnelling: self.evaluate_step = evaluate_step self.device = device # Visual Transformer VGF params ---------- - self.visual_transformer_name = visual_transformer_name + self.visual_trf_name = visual_transformer_name # Metaclassifier params ------------ self.optimc = optimc # ------------------- @@ -142,10 +142,10 @@ class GeneralizedFunnelling: wce_vgf = WceGen(n_jobs=self.n_jobs) self.first_tier_learners.append(wce_vgf) - if self.textual_trasformer_vgf: + if self.textual_trf_vgf: transformer_vgf = TextualTransformerGen( dataset_name=self.dataset_name, - model_name=self.textaul_transformer_name, + model_name=self.textual_trf_name, lr=self.lr_transformer, epochs=self.epochs, batch_size=self.batch_size_transformer, @@ -159,7 +159,7 @@ class GeneralizedFunnelling: ) self.first_tier_learners.append(transformer_vgf) - if self.visual_transformer_vgf: + if self.visual_trf_vgf: visual_trasformer_vgf = VisualTransformerGen( dataset_name=self.dataset_name, model_name="vit", @@ -198,8 +198,8 @@ class GeneralizedFunnelling: self.posteriors_vgf, self.multilingual_vgf, self.wce_vgf, - self.textual_trasformer_vgf, - self.visual_transformer_vgf, + self.textual_trf_vgf, + self.visual_trf_vgf, self.aggfunc, ) print(f"- model id: {self._model_id}") @@ -373,7 +373,7 @@ class GeneralizedFunnelling: "rb", ) as vgf: first_tier_learners.append(pickle.load(vgf)) - if self.textual_trasformer_vgf: + if self.textual_trf_vgf: with open( os.path.join( "models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"