Implemented funnelling architecture

This commit is contained in:
andrea 2021-01-25 17:20:52 +01:00
parent 93436fc596
commit 111f759cd4
1 changed files with 3 additions and 5 deletions

View File

@ -28,14 +28,12 @@ def main(args):
multilingualIndex.index(lX, ly, lXte, lyte, l_pretrained_vocabulary=lMuse.vocabulary()) multilingualIndex.index(lX, ly, lXte, lyte, l_pretrained_vocabulary=lMuse.vocabulary())
# posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS) # posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
# museEmbedder = MuseGen(muse_dir=EMBEDDINGS_PATH, n_jobs=N_JOBS) museEmbedder = MuseGen(muse_dir=EMBEDDINGS_PATH, n_jobs=N_JOBS)
# wceEmbedder = WordClassGen(n_jobs=N_JOBS) wceEmbedder = WordClassGen(n_jobs=N_JOBS)
# rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256, # rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256,
# nepochs=250, gpus=args.gpus, n_jobs=N_JOBS) # nepochs=250, gpus=args.gpus, n_jobs=N_JOBS)
bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS) # bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS)
bertEmbedder.transform(lX)
exit()
docEmbedders = DocEmbedderList([museEmbedder, wceEmbedder]) docEmbedders = DocEmbedderList([museEmbedder, wceEmbedder])
gfun = Funnelling(first_tier=docEmbedders) gfun = Funnelling(first_tier=docEmbedders)