From 94866e5ad81e848973c2a54f8970dc86a8a0d009 Mon Sep 17 00:00:00 2001 From: andrea Date: Mon, 25 Jan 2021 17:46:03 +0100 Subject: [PATCH] Implemented funnelling architecture --- refactor/funnelling.py | 5 ++--- refactor/main.py | 50 +++++++++++++++++++++++++++++++---------- refactor/util/common.py | 10 ++++++++- 3 files changed, 49 insertions(+), 16 deletions(-) diff --git a/refactor/funnelling.py b/refactor/funnelling.py index 33fcce3..6c79ae9 100644 --- a/refactor/funnelling.py +++ b/refactor/funnelling.py @@ -77,10 +77,9 @@ class FeatureSet2Posteriors: class Funnelling: - def __init__(self, first_tier: DocEmbedderList, n_jobs=-1): + def __init__(self, first_tier: DocEmbedderList, meta_classifier: MetaClassifier, n_jobs=-1): self.first_tier = first_tier - self.meta = MetaClassifier( - SVC(kernel='rbf', gamma='auto', probability=True, cache_size=1000, random_state=1), n_jobs=n_jobs) + self.meta = meta_classifier self.n_jobs = n_jobs def fit(self, lX, ly): diff --git a/refactor/main.py b/refactor/main.py index d2ab71b..a1f5eef 100644 --- a/refactor/main.py +++ b/refactor/main.py @@ -2,13 +2,14 @@ from argparse import ArgumentParser from funnelling import * from view_generators import * from data.dataset_builder import MultilingualDataset -from util.common import MultilingualIndex +from util.common import MultilingualIndex, get_params from util.evaluation import evaluate from util.results_csv import CSVlog from time import time def main(args): + OPTIMC = True # TODO N_JOBS = 8 print('Running refactored...') @@ -27,16 +28,36 @@ def main(args): lMuse = MuseLoader(langs=sorted(lX.keys()), cache=EMBEDDINGS_PATH) multilingualIndex.index(lX, ly, lXte, lyte, l_pretrained_vocabulary=lMuse.vocabulary()) - # posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS) - museEmbedder = MuseGen(muse_dir=EMBEDDINGS_PATH, n_jobs=N_JOBS) - wceEmbedder = WordClassGen(n_jobs=N_JOBS) - # rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256, - # nepochs=250, gpus=args.gpus, n_jobs=N_JOBS) - # bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS) + embedder_list = [] + if args.X: + posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS) + embedder_list.append(posteriorEmbedder) - docEmbedders = DocEmbedderList([museEmbedder, wceEmbedder]) + if args.M: + museEmbedder = MuseGen(muse_dir=EMBEDDINGS_PATH, n_jobs=N_JOBS) + embedder_list.append(museEmbedder) - gfun = Funnelling(first_tier=docEmbedders) + if args.W: + wceEmbedder = WordClassGen(n_jobs=N_JOBS) + embedder_list.append(wceEmbedder) + + if args.G: + rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256, + nepochs=250, gpus=args.gpus, n_jobs=N_JOBS) + embedder_list.append(rnnEmbedder) + + if args.B: + bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS) + embedder_list.append(bertEmbedder) + + # Init DocEmbedderList + docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True) + meta_parameters = None if not OPTIMC else [{'C': [1, 1e3, 1e2, 1e1, 1e-1]}] + meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf', C=meta_parameters), + meta_parameters=get_params(optimc=True)) + + # Init Funnelling Architecture + gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta) # Training --------------------------------------- print('\n[Training Generalized Funnelling]') @@ -64,9 +85,9 @@ def main(args): print(f'Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}') results.add_row(method='gfun', setting='TODO', - sif='TODO', - zscore='TRUE', - l2='TRUE', + sif='True', + zscore='True', + l2='True', dataset='TODO', time_tr=time_tr, time_te=time_te, @@ -84,6 +105,11 @@ def main(args): if __name__ == '__main__': parser = ArgumentParser() + parser.add_argument('--X') + parser.add_argument('--M') + parser.add_argument('--W') + parser.add_argument('--G') + parser.add_argument('--B') parser.add_argument('--gpus', default=None) args = parser.parse_args() main(args) diff --git a/refactor/util/common.py b/refactor/util/common.py index 575570a..3ffda78 100644 --- a/refactor/util/common.py +++ b/refactor/util/common.py @@ -360,4 +360,12 @@ def pad(index_list, pad_index, max_pad_length=None): pad_length = min(pad_length, max_pad_length) for i, indexes in enumerate(index_list): index_list[i] = [pad_index] * (pad_length - len(indexes)) + indexes[:pad_length] - return index_list \ No newline at end of file + return index_list + + +def get_params(optimc=False): + if not optimc: + return None + c_range = [1e4, 1e3, 1e2, 1e1, 1, 1e-1] + kernel = 'rbf' + return [{'kernel': [kernel], 'C': c_range, 'gamma':['auto']}] \ No newline at end of file