implemented zero-shot experiment code for VanillaFunGen and WordClassGen

This commit is contained in:
andrea 2021-02-04 13:00:18 +01:00
parent f3fafd0f00
commit 495a0b6af9
4 changed files with 44 additions and 19 deletions

54
main.py
View File

@@ -17,7 +17,7 @@ def main(args):
print('Running generalized funnelling...') print('Running generalized funnelling...')
data = MultilingualDataset.load(args.dataset) data = MultilingualDataset.load(args.dataset)
data.set_view(languages=['da', 'nl', 'it']) data.set_view(languages=['da'])
data.show_dimensions() data.show_dimensions()
lX, ly = data.training() lX, ly = data.training()
lXte, lyte = data.test() lXte, lyte = data.test()
@@ -32,40 +32,64 @@ def main(args):
embedder_list = [] embedder_list = []
if args.post_embedder: if args.post_embedder:
posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True),
zero_shot=args.zero_shot, train_langs=args.zscl_langs, zero_shot=args.zero_shot,
train_langs=args.zscl_langs,
n_jobs=args.n_jobs) n_jobs=args.n_jobs)
embedder_list.append(posteriorEmbedder) embedder_list.append(posteriorEmbedder)
if args.muse_embedder: if args.muse_embedder:
museEmbedder = MuseGen(muse_dir=args.muse_dir, n_jobs=args.n_jobs, museEmbedder = MuseGen(muse_dir=args.muse_dir,
zero_shot=args.zero_shot, train_langs=args.zscl_langs) zero_shot=args.zero_shot,
train_langs=args.zscl_langs,
n_jobs=args.n_jobs)
embedder_list.append(museEmbedder) embedder_list.append(museEmbedder)
if args.wce_embedder: if args.wce_embedder:
wceEmbedder = WordClassGen(n_jobs=args.n_jobs, wceEmbedder = WordClassGen(zero_shot=args.zero_shot,
zero_shot=args.zero_shot, train_langs=args.zscl_langs) train_langs=args.zscl_langs,
n_jobs=args.n_jobs)
embedder_list.append(wceEmbedder) embedder_list.append(wceEmbedder)
if args.gru_embedder: if args.gru_embedder:
rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=args.rnn_wce, rnnEmbedder = RecurrentGen(multilingualIndex,
batch_size=args.batch_rnn, nepochs=args.nepochs_rnn, patience=args.patience_rnn, pretrained_embeddings=lMuse,
zero_shot=args.zero_shot, train_langs=args.zscl_langs, wce=args.rnn_wce,
gpus=args.gpus, n_jobs=args.n_jobs) batch_size=args.batch_rnn,
nepochs=args.nepochs_rnn,
patience=args.patience_rnn,
zero_shot=args.zero_shot,
train_langs=args.zscl_langs,
gpus=args.gpus,
n_jobs=args.n_jobs)
embedder_list.append(rnnEmbedder) embedder_list.append(rnnEmbedder)
if args.bert_embedder: if args.bert_embedder:
bertEmbedder = BertGen(multilingualIndex, batch_size=args.batch_bert, nepochs=args.nepochs_bert, bertEmbedder = BertGen(multilingualIndex,
zero_shot=args.zero_shot, train_langs=args.zscl_langs, batch_size=args.batch_bert,
patience=args.patience_bert, gpus=args.gpus, n_jobs=args.n_jobs) nepochs=args.nepochs_bert,
patience=args.patience_bert,
zero_shot=args.zero_shot,
train_langs=args.zscl_langs,
gpus=args.gpus,
n_jobs=args.n_jobs)
embedder_list.append(bertEmbedder) embedder_list.append(bertEmbedder)
# Init DocEmbedderList (i.e., first-tier learners or view generators) and metaclassifier # Init DocEmbedderList (i.e., first-tier learners or view generators) and metaclassifier
docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True) docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True)
meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'), meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'),
meta_parameters=get_params(optimc=args.optimc), n_jobs=args.n_jobs) meta_parameters=get_params(optimc=args.optimc),
n_jobs=args.n_jobs)
# Init Funnelling Architecture # Init Funnelling Architecture
gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta) gfun = Funnelling(first_tier=docEmbedders,
meta_classifier=meta,
n_jobs=args.n_jobs)
# Training --------------------------------------- # Training ---------------------------------------
print('\n[Training Generalized Funnelling]') print('\n[Training Generalized Funnelling]')

View File

@@ -116,9 +116,9 @@ class Funnelling:
self.n_jobs = n_jobs self.n_jobs = n_jobs
def fit(self, lX, ly): def fit(self, lX, ly):
print('## Fitting first-tier learners!') print('\n## Fitting first-tier learners!')
lZ = self.first_tier.fit_transform(lX, ly) lZ = self.first_tier.fit_transform(lX, ly)
print('## Fitting meta-learner!') print('\n## Fitting meta-learner!')
self.meta.fit(lZ, ly) self.meta.fit(lZ, ly)
def predict(self, lX): def predict(self, lX):

View File

@@ -74,7 +74,7 @@ class NaivePolylingualClassifier:
_sort_if_sparse(lX[lang]) _sort_if_sparse(lX[lang])
models = Parallel(n_jobs=self.n_jobs)\ models = Parallel(n_jobs=self.n_jobs)\
(delayed(MonolingualClassifier(self.base_learner, parameters=self.parameters).fit)((lX[lang]), ly[lang]) for (delayed(MonolingualClassifier(self.base_learner, parameters=self.parameters, n_jobs=self.n_jobs).fit)((lX[lang]), ly[lang]) for
lang in langs) lang in langs)
self.model = {lang: models[i] for i, lang in enumerate(langs)} self.model = {lang: models[i] for i, lang in enumerate(langs)}

View File

@@ -69,7 +69,8 @@ class VanillaFunGen(ViewGen):
self.first_tier_parameters = first_tier_parameters self.first_tier_parameters = first_tier_parameters
self.n_jobs = n_jobs self.n_jobs = n_jobs
self.doc_projector = NaivePolylingualClassifier(base_learner=self.learners, self.doc_projector = NaivePolylingualClassifier(base_learner=self.learners,
parameters=self.first_tier_parameters, n_jobs=self.n_jobs) parameters=self.first_tier_parameters,
n_jobs=self.n_jobs)
self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True) self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
# Zero shot parameters # Zero shot parameters
self.zero_shot = zero_shot self.zero_shot = zero_shot