diff --git a/refactor/main.py b/refactor/main.py index 17f5a95..d2ab71b 100644 --- a/refactor/main.py +++ b/refactor/main.py @@ -28,14 +28,12 @@ def main(args): multilingualIndex.index(lX, ly, lXte, lyte, l_pretrained_vocabulary=lMuse.vocabulary()) # posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS) - # museEmbedder = MuseGen(muse_dir=EMBEDDINGS_PATH, n_jobs=N_JOBS) - # wceEmbedder = WordClassGen(n_jobs=N_JOBS) + museEmbedder = MuseGen(muse_dir=EMBEDDINGS_PATH, n_jobs=N_JOBS) + wceEmbedder = WordClassGen(n_jobs=N_JOBS) # rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256, # nepochs=250, gpus=args.gpus, n_jobs=N_JOBS) - bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS) - bertEmbedder.transform(lX) + # bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS) - exit() docEmbedders = DocEmbedderList([museEmbedder, wceEmbedder]) gfun = Funnelling(first_tier=docEmbedders)