diff --git a/refactor/main.py b/refactor/main.py index bea0067..e44adcd 100644 --- a/refactor/main.py +++ b/refactor/main.py @@ -28,7 +28,7 @@ def main(args): # gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS) # gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS) # gFun = WordClassGen(n_jobs=N_JOBS) - gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256, nepochs=50, + gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256, nepochs=100, gpus=args.gpus, n_jobs=N_JOBS) # gFun = BertGen(multilingualIndex, gpus=args.gpus, batch_size=128, n_jobs=N_JOBS)