test commit
This commit is contained in:
parent
8325262972
commit
4485d97e03
|
@ -9,7 +9,7 @@ import numpy as np
|
||||||
from vgfs.commons import TfidfVectorizerMultilingual
|
from vgfs.commons import TfidfVectorizerMultilingual
|
||||||
from vgfs.learners.svms import MetaClassifier, get_learner
|
from vgfs.learners.svms import MetaClassifier, get_learner
|
||||||
from vgfs.multilingualGen import MultilingualGen
|
from vgfs.multilingualGen import MultilingualGen
|
||||||
from vgfs.transformerGen import TransformerGen
|
from gfun.vgfs.textualTransformerGen import TextualTransformerGen
|
||||||
from vgfs.vanillaFun import VanillaFunGen
|
from vgfs.vanillaFun import VanillaFunGen
|
||||||
from vgfs.wceGen import WceGen
|
from vgfs.wceGen import WceGen
|
||||||
|
|
||||||
|
@ -98,7 +98,7 @@ class GeneralizedFunnelling:
|
||||||
self.first_tier_learners.append(wce_vgf)
|
self.first_tier_learners.append(wce_vgf)
|
||||||
|
|
||||||
if self.trasformer_vgf:
|
if self.trasformer_vgf:
|
||||||
transformer_vgf = TransformerGen(
|
transformer_vgf = TextualTransformerGen(
|
||||||
model_name=self.transformer_name,
|
model_name=self.transformer_name,
|
||||||
lr=self.lr_transformer,
|
lr=self.lr_transformer,
|
||||||
epochs=self.epochs,
|
epochs=self.epochs,
|
||||||
|
|
|
@ -21,7 +21,7 @@ transformers.logging.set_verbosity_error()
|
||||||
# TODO: add support to loggers
|
# TODO: add support to loggers
|
||||||
|
|
||||||
|
|
||||||
class TransformerGen:
|
class TextualTransformerGen:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_name,
|
model_name,
|
Loading…
Reference in New Issue