Implemented funnelling architecture

This commit is contained in:
andrea 2021-01-25 17:46:03 +01:00
parent 111f759cd4
commit 94866e5ad8
3 changed files with 49 additions and 16 deletions

View File

@ -77,10 +77,9 @@ class FeatureSet2Posteriors:
class Funnelling: class Funnelling:
def __init__(self, first_tier: DocEmbedderList, n_jobs=-1): def __init__(self, first_tier: DocEmbedderList, meta_classifier: MetaClassifier, n_jobs=-1):
self.first_tier = first_tier self.first_tier = first_tier
self.meta = MetaClassifier( self.meta = meta_classifier
SVC(kernel='rbf', gamma='auto', probability=True, cache_size=1000, random_state=1), n_jobs=n_jobs)
self.n_jobs = n_jobs self.n_jobs = n_jobs
def fit(self, lX, ly): def fit(self, lX, ly):

View File

@ -2,13 +2,14 @@ from argparse import ArgumentParser
from funnelling import * from funnelling import *
from view_generators import * from view_generators import *
from data.dataset_builder import MultilingualDataset from data.dataset_builder import MultilingualDataset
from util.common import MultilingualIndex from util.common import MultilingualIndex, get_params
from util.evaluation import evaluate from util.evaluation import evaluate
from util.results_csv import CSVlog from util.results_csv import CSVlog
from time import time from time import time
def main(args): def main(args):
OPTIMC = True # TODO
N_JOBS = 8 N_JOBS = 8
print('Running refactored...') print('Running refactored...')
@ -27,16 +28,36 @@ def main(args):
lMuse = MuseLoader(langs=sorted(lX.keys()), cache=EMBEDDINGS_PATH) lMuse = MuseLoader(langs=sorted(lX.keys()), cache=EMBEDDINGS_PATH)
multilingualIndex.index(lX, ly, lXte, lyte, l_pretrained_vocabulary=lMuse.vocabulary()) multilingualIndex.index(lX, ly, lXte, lyte, l_pretrained_vocabulary=lMuse.vocabulary())
# posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS) embedder_list = []
museEmbedder = MuseGen(muse_dir=EMBEDDINGS_PATH, n_jobs=N_JOBS) if args.X:
wceEmbedder = WordClassGen(n_jobs=N_JOBS) posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
# rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256, embedder_list.append(posteriorEmbedder)
# nepochs=250, gpus=args.gpus, n_jobs=N_JOBS)
# bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS)
docEmbedders = DocEmbedderList([museEmbedder, wceEmbedder]) if args.M:
museEmbedder = MuseGen(muse_dir=EMBEDDINGS_PATH, n_jobs=N_JOBS)
embedder_list.append(museEmbedder)
gfun = Funnelling(first_tier=docEmbedders) if args.W:
wceEmbedder = WordClassGen(n_jobs=N_JOBS)
embedder_list.append(wceEmbedder)
if args.G:
rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256,
nepochs=250, gpus=args.gpus, n_jobs=N_JOBS)
embedder_list.append(rnnEmbedder)
if args.B:
bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS)
embedder_list.append(bertEmbedder)
# Init DocEmbedderList
docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True)
meta_parameters = None if not OPTIMC else [{'C': [1, 1e3, 1e2, 1e1, 1e-1]}]
meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf', C=meta_parameters),
meta_parameters=get_params(optimc=True))
# Init Funnelling Architecture
gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta)
# Training --------------------------------------- # Training ---------------------------------------
print('\n[Training Generalized Funnelling]') print('\n[Training Generalized Funnelling]')
@ -64,9 +85,9 @@ def main(args):
print(f'Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}') print(f'Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}')
results.add_row(method='gfun', results.add_row(method='gfun',
setting='TODO', setting='TODO',
sif='TODO', sif='True',
zscore='TRUE', zscore='True',
l2='TRUE', l2='True',
dataset='TODO', dataset='TODO',
time_tr=time_tr, time_tr=time_tr,
time_te=time_te, time_te=time_te,
@ -84,6 +105,11 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument('--X')
parser.add_argument('--M')
parser.add_argument('--W')
parser.add_argument('--G')
parser.add_argument('--B')
parser.add_argument('--gpus', default=None) parser.add_argument('--gpus', default=None)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@ -360,4 +360,12 @@ def pad(index_list, pad_index, max_pad_length=None):
pad_length = min(pad_length, max_pad_length) pad_length = min(pad_length, max_pad_length)
for i, indexes in enumerate(index_list): for i, indexes in enumerate(index_list):
index_list[i] = [pad_index] * (pad_length - len(indexes)) + indexes[:pad_length] index_list[i] = [pad_index] * (pad_length - len(indexes)) + indexes[:pad_length]
return index_list return index_list
def get_params(optimc=False):
if not optimc:
return None
c_range = [1e4, 1e3, 1e2, 1e1, 1, 1e-1]
kernel = 'rbf'
return [{'kernel': [kernel], 'C': c_range, 'gamma':['auto']}]