Implemented funnelling architecture

parent 111f759cd4
commit 94866e5ad8
@@ -77,10 +77,9 @@ class FeatureSet2Posteriors:
 
 
 class Funnelling:
-    def __init__(self, first_tier: DocEmbedderList, n_jobs=-1):
+    def __init__(self, first_tier: DocEmbedderList, meta_classifier: MetaClassifier, n_jobs=-1):
         self.first_tier = first_tier
-        self.meta = MetaClassifier(
-            SVC(kernel='rbf', gamma='auto', probability=True, cache_size=1000, random_state=1), n_jobs=n_jobs)
+        self.meta = meta_classifier
         self.n_jobs = n_jobs
 
     def fit(self, lX, ly):
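With this change the meta-classifier is injected into Funnelling rather than being built internally around a hard-coded SVC. A minimal construction sketch (not part of the commit), using the names that appear in the new main() later in this diff; lX/ly are the per-language training dicts, get_learner is assumed to be provided by one of the wildcard imports, and calling it without a C argument assumes a sensible default:

from funnelling import *
from view_generators import *
from util.common import get_params

# Build a one-view first tier and an externally configured meta classifier,
# then wire them together through the new Funnelling constructor.
embedders = DocEmbedderList(embedder_list=[WordClassGen(n_jobs=8)], probabilistic=True)
meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'),
                      meta_parameters=get_params(optimc=True))
gfun = Funnelling(first_tier=embedders, meta_classifier=meta, n_jobs=8)
gfun.fit(lX, ly)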
@@ -2,13 +2,14 @@ from argparse import ArgumentParser
 from funnelling import *
 from view_generators import *
 from data.dataset_builder import MultilingualDataset
-from util.common import MultilingualIndex
+from util.common import MultilingualIndex, get_params
 from util.evaluation import evaluate
 from util.results_csv import CSVlog
 from time import time
 
 
 def main(args):
+    OPTIMC = True  # TODO
     N_JOBS = 8
     print('Running refactored...')
 
@@ -27,16 +28,36 @@ def main(args):
     lMuse = MuseLoader(langs=sorted(lX.keys()), cache=EMBEDDINGS_PATH)
     multilingualIndex.index(lX, ly, lXte, lyte, l_pretrained_vocabulary=lMuse.vocabulary())
 
-    # posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
-    museEmbedder = MuseGen(muse_dir=EMBEDDINGS_PATH, n_jobs=N_JOBS)
-    wceEmbedder = WordClassGen(n_jobs=N_JOBS)
-    # rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256,
-    #                            nepochs=250, gpus=args.gpus, n_jobs=N_JOBS)
-    # bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS)
-
-    docEmbedders = DocEmbedderList([museEmbedder, wceEmbedder])
-
-    gfun = Funnelling(first_tier=docEmbedders)
+    embedder_list = []
+    if args.X:
+        posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
+        embedder_list.append(posteriorEmbedder)
+
+    if args.M:
+        museEmbedder = MuseGen(muse_dir=EMBEDDINGS_PATH, n_jobs=N_JOBS)
+        embedder_list.append(museEmbedder)
+
+    if args.W:
+        wceEmbedder = WordClassGen(n_jobs=N_JOBS)
+        embedder_list.append(wceEmbedder)
+
+    if args.G:
+        rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256,
+                                   nepochs=250, gpus=args.gpus, n_jobs=N_JOBS)
+        embedder_list.append(rnnEmbedder)
+
+    if args.B:
+        bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS)
+        embedder_list.append(bertEmbedder)
+
+    # Init DocEmbedderList
+    docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True)
+    meta_parameters = None if not OPTIMC else [{'C': [1, 1e3, 1e2, 1e1, 1e-1]}]
+    meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf', C=meta_parameters),
+                          meta_parameters=get_params(optimc=True))
+
+    # Init Funnelling Architecture
+    gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta)
 
     # Training ---------------------------------------
     print('\n[Training Generalized Funnelling]')
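For reference, a rough sketch (not part of this commit) of the training/evaluation flow that follows in the unchanged part of main() and produces the time_tr/time_te and per-language F1 values logged below; gfun.predict() and the evaluate() call signature are assumptions, since only fit() appears in this diff:

tinit = time()
gfun.fit(lX, ly)                  # first tier (view generators) + meta classifier
time_tr = time() - tinit

tinit = time()
ly_pred = gfun.predict(lXte)      # assumed API, not shown in this diff
time_te = time() - tinit

l_eval = evaluate(lyte, ly_pred)  # assumed signature for util.evaluation.evaluate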
@@ -64,9 +85,9 @@ def main(args):
         print(f'Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}')
     results.add_row(method='gfun',
                     setting='TODO',
-                    sif='TODO',
-                    zscore='TRUE',
-                    l2='TRUE',
+                    sif='True',
+                    zscore='True',
+                    l2='True',
                     dataset='TODO',
                     time_tr=time_tr,
                     time_te=time_te,
@@ -84,6 +105,11 @@ def main(args):
 
 if __name__ == '__main__':
     parser = ArgumentParser()
+    parser.add_argument('--X')
+    parser.add_argument('--M')
+    parser.add_argument('--W')
+    parser.add_argument('--G')
+    parser.add_argument('--B')
     parser.add_argument('--gpus', default=None)
     args = parser.parse_args()
     main(args)
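The five new flags select which first-tier view generators to build: --X enables the posterior-probability view (VanillaFunGen), --M the MUSE embeddings (MuseGen), --W the word-class embeddings (WordClassGen), --G the recurrent view (RecurrentGen) and --B the BERT view (BertGen). As declared, they are plain string arguments rather than action='store_true' switches, so any non-empty value enables a view (e.g. --X 1 --M 1 --gpus 0, a hypothetical invocation) while an omitted flag defaults to None and leaves that view out of the DocEmbedderList.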
@@ -360,4 +360,12 @@ def pad(index_list, pad_index, max_pad_length=None):
         pad_length = min(pad_length, max_pad_length)
     for i, indexes in enumerate(index_list):
         index_list[i] = [pad_index] * (pad_length - len(indexes)) + indexes[:pad_length]
     return index_list
+
+
+def get_params(optimc=False):
+    if not optimc:
+        return None
+    c_range = [1e4, 1e3, 1e2, 1e1, 1, 1e-1]
+    kernel = 'rbf'
+    return [{'kernel': [kernel], 'C': c_range, 'gamma': ['auto']}]
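The new get_params helper (also added to the main script's imports from util.common) returns a parameter grid in the list-of-dicts format used by scikit-learn grid searches; how MetaClassifier consumes it is not shown in this diff. A small illustration of the format, assuming a scikit-learn GridSearchCV as the eventual consumer:

from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
from util.common import get_params

param_grid = get_params(optimc=True)
# -> [{'kernel': ['rbf'], 'C': [10000.0, 1000.0, 100.0, 10.0, 1, 0.1], 'gamma': ['auto']}]
search = GridSearchCV(SVC(), param_grid=param_grid, cv=5, n_jobs=-1)
# search.fit(X, y) would then select the best C for the rbf kernel with gamma='auto'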