Implemented funnelling architecture

This commit is contained in:
andrea 2021-01-25 17:46:03 +01:00
parent 111f759cd4
commit 94866e5ad8
3 changed files with 49 additions and 16 deletions

View File

@ -77,10 +77,9 @@ class FeatureSet2Posteriors:
class Funnelling:
def __init__(self, first_tier: DocEmbedderList, n_jobs=-1):
def __init__(self, first_tier: DocEmbedderList, meta_classifier: MetaClassifier, n_jobs=-1):
self.first_tier = first_tier
self.meta = MetaClassifier(
SVC(kernel='rbf', gamma='auto', probability=True, cache_size=1000, random_state=1), n_jobs=n_jobs)
self.meta = meta_classifier
self.n_jobs = n_jobs
def fit(self, lX, ly):

View File

@ -2,13 +2,14 @@ from argparse import ArgumentParser
from funnelling import *
from view_generators import *
from data.dataset_builder import MultilingualDataset
from util.common import MultilingualIndex
from util.common import MultilingualIndex, get_params
from util.evaluation import evaluate
from util.results_csv import CSVlog
from time import time
def main(args):
OPTIMC = True # TODO
N_JOBS = 8
print('Running refactored...')
@ -27,16 +28,36 @@ def main(args):
lMuse = MuseLoader(langs=sorted(lX.keys()), cache=EMBEDDINGS_PATH)
multilingualIndex.index(lX, ly, lXte, lyte, l_pretrained_vocabulary=lMuse.vocabulary())
# posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
museEmbedder = MuseGen(muse_dir=EMBEDDINGS_PATH, n_jobs=N_JOBS)
wceEmbedder = WordClassGen(n_jobs=N_JOBS)
# rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256,
# nepochs=250, gpus=args.gpus, n_jobs=N_JOBS)
# bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS)
embedder_list = []
if args.X:
posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
embedder_list.append(posteriorEmbedder)
docEmbedders = DocEmbedderList([museEmbedder, wceEmbedder])
if args.M:
museEmbedder = MuseGen(muse_dir=EMBEDDINGS_PATH, n_jobs=N_JOBS)
embedder_list.append(museEmbedder)
gfun = Funnelling(first_tier=docEmbedders)
if args.W:
wceEmbedder = WordClassGen(n_jobs=N_JOBS)
embedder_list.append(wceEmbedder)
if args.G:
rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256,
nepochs=250, gpus=args.gpus, n_jobs=N_JOBS)
embedder_list.append(rnnEmbedder)
if args.B:
bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS)
embedder_list.append(bertEmbedder)
# Init DocEmbedderList
docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True)
meta_parameters = None if not OPTIMC else [{'C': [1, 1e3, 1e2, 1e1, 1e-1]}]
meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf', C=meta_parameters),
meta_parameters=get_params(optimc=True))
# Init Funnelling Architecture
gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta)
# Training ---------------------------------------
print('\n[Training Generalized Funnelling]')
@ -64,9 +85,9 @@ def main(args):
print(f'Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}')
results.add_row(method='gfun',
setting='TODO',
sif='TODO',
zscore='TRUE',
l2='TRUE',
sif='True',
zscore='True',
l2='True',
dataset='TODO',
time_tr=time_tr,
time_te=time_te,
@ -84,6 +105,11 @@ def main(args):
if __name__ == '__main__':
    def _as_flag(value):
        """Parse the optional value given to a view-generator switch.

        Explicit negative spellings (and the empty string) yield False; any
        other value yields True, so invocations that used to pass an arbitrary
        non-empty value keep enabling the corresponding view generator.
        """
        return value.strip().lower() not in {'', '0', 'false', 'f', 'no', 'n', 'off'}

    parser = ArgumentParser()
    # Each switch is tested for truthiness in main() to decide which view
    # generator joins the first tier. A bare "--X" now enables it, and
    # "--X False" no longer does (previously the string 'False' was truthy).
    view_generator_flags = (
        ('--X', 'add the posterior probabilities view generator (VanillaFunGen)'),
        ('--M', 'add the MUSE embeddings view generator (MuseGen)'),
        ('--W', 'add the word-class embeddings view generator (WordClassGen)'),
        ('--G', 'add the recurrent view generator (RecurrentGen)'),
        ('--B', 'add the BERT view generator (BertGen)'),
    )
    for flag, description in view_generator_flags:
        parser.add_argument(flag, nargs='?', const=True, default=False,
                            type=_as_flag, help=description)
    # Forwarded untouched to the recurrent and BERT view generators.
    parser.add_argument('--gpus', default=None)
    args = parser.parse_args()
    main(args)

View File

@ -360,4 +360,12 @@ def pad(index_list, pad_index, max_pad_length=None):
pad_length = min(pad_length, max_pad_length)
for i, indexes in enumerate(index_list):
index_list[i] = [pad_index] * (pad_length - len(indexes)) + indexes[:pad_length]
return index_list
return index_list
def get_params(optimc=False, kernel='rbf', c_range=None):
    """Build the hyper-parameter grid used to optimise the classifier.

    :param optimc: when falsy no optimisation is requested and None is returned.
    :param kernel: kernel name placed (as the single candidate) in the grid.
    :param c_range: iterable of candidate values for C; when None the default
        range [1e4, 1e3, 1e2, 1e1, 1, 1e-1] is used.
    :return: None if ``optimc`` is falsy, otherwise a one-element list holding
        a dict with the candidate values for 'kernel', 'C' and 'gamma'.
    """
    # NOTE(review): the list-of-dicts shape looks like a scikit-learn
    # param_grid -- confirm against how MetaClassifier consumes it.
    if not optimc:
        return None
    if c_range is None:
        c_range = [1e4, 1e3, 1e2, 1e1, 1, 1e-1]
    # Copy the range so callers can mutate the returned grid without
    # affecting the sequence they passed in.
    return [{'kernel': [kernel], 'C': list(c_range), 'gamma': ['auto']}]