diff --git a/gfun/generalizedFunnelling.py b/gfun/generalizedFunnelling.py index c9710f7..0d713e3 100644 --- a/gfun/generalizedFunnelling.py +++ b/gfun/generalizedFunnelling.py @@ -9,7 +9,7 @@ import numpy as np from vgfs.commons import TfidfVectorizerMultilingual from vgfs.learners.svms import MetaClassifier, get_learner from vgfs.multilingualGen import MultilingualGen -from vgfs.transformerGen import TransformerGen +from gfun.vgfs.textualTransformerGen import TextualTransformerGen from vgfs.vanillaFun import VanillaFunGen from vgfs.wceGen import WceGen @@ -98,7 +98,7 @@ class GeneralizedFunnelling: self.first_tier_learners.append(wce_vgf) if self.trasformer_vgf: - transformer_vgf = TransformerGen( + transformer_vgf = TextualTransformerGen( model_name=self.transformer_name, lr=self.lr_transformer, epochs=self.epochs, diff --git a/gfun/vgfs/transformerGen.py b/gfun/vgfs/textualTransformerGen.py similarity index 99% rename from gfun/vgfs/transformerGen.py rename to gfun/vgfs/textualTransformerGen.py index c0e31fb..eb134bd 100644 --- a/gfun/vgfs/transformerGen.py +++ b/gfun/vgfs/textualTransformerGen.py @@ -21,7 +21,7 @@ transformers.logging.set_verbosity_error() # TODO: add support to loggers -class TransformerGen: +class TextualTransformerGen: def __init__( self, model_name,