implemented zero-shot experiment code for VanillaFunGen and WordClassGen

This commit is contained in:
andrea 2021-02-04 13:00:18 +01:00
parent f3fafd0f00
commit 495a0b6af9
4 changed files with 44 additions and 19 deletions

54
main.py
View File

@@ -17,7 +17,7 @@ def main(args):
print('Running generalized funnelling...')
data = MultilingualDataset.load(args.dataset)
data.set_view(languages=['da', 'nl', 'it'])
data.set_view(languages=['da'])
data.show_dimensions()
lX, ly = data.training()
lXte, lyte = data.test()
@@ -32,40 +32,64 @@ def main(args):
embedder_list = []
if args.post_embedder:
posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True),
zero_shot=args.zero_shot, train_langs=args.zscl_langs,
zero_shot=args.zero_shot,
train_langs=args.zscl_langs,
n_jobs=args.n_jobs)
embedder_list.append(posteriorEmbedder)
if args.muse_embedder:
museEmbedder = MuseGen(muse_dir=args.muse_dir, n_jobs=args.n_jobs,
zero_shot=args.zero_shot, train_langs=args.zscl_langs)
museEmbedder = MuseGen(muse_dir=args.muse_dir,
zero_shot=args.zero_shot,
train_langs=args.zscl_langs,
n_jobs=args.n_jobs)
embedder_list.append(museEmbedder)
if args.wce_embedder:
wceEmbedder = WordClassGen(n_jobs=args.n_jobs,
zero_shot=args.zero_shot, train_langs=args.zscl_langs)
wceEmbedder = WordClassGen(zero_shot=args.zero_shot,
train_langs=args.zscl_langs,
n_jobs=args.n_jobs)
embedder_list.append(wceEmbedder)
if args.gru_embedder:
rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=args.rnn_wce,
batch_size=args.batch_rnn, nepochs=args.nepochs_rnn, patience=args.patience_rnn,
zero_shot=args.zero_shot, train_langs=args.zscl_langs,
gpus=args.gpus, n_jobs=args.n_jobs)
rnnEmbedder = RecurrentGen(multilingualIndex,
pretrained_embeddings=lMuse,
wce=args.rnn_wce,
batch_size=args.batch_rnn,
nepochs=args.nepochs_rnn,
patience=args.patience_rnn,
zero_shot=args.zero_shot,
train_langs=args.zscl_langs,
gpus=args.gpus,
n_jobs=args.n_jobs)
embedder_list.append(rnnEmbedder)
if args.bert_embedder:
bertEmbedder = BertGen(multilingualIndex, batch_size=args.batch_bert, nepochs=args.nepochs_bert,
zero_shot=args.zero_shot, train_langs=args.zscl_langs,
patience=args.patience_bert, gpus=args.gpus, n_jobs=args.n_jobs)
bertEmbedder = BertGen(multilingualIndex,
batch_size=args.batch_bert,
nepochs=args.nepochs_bert,
patience=args.patience_bert,
zero_shot=args.zero_shot,
train_langs=args.zscl_langs,
gpus=args.gpus,
n_jobs=args.n_jobs)
embedder_list.append(bertEmbedder)
# Init DocEmbedderList (i.e., first-tier learners or view generators) and metaclassifier
docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True)
meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'),
meta_parameters=get_params(optimc=args.optimc), n_jobs=args.n_jobs)
meta_parameters=get_params(optimc=args.optimc),
n_jobs=args.n_jobs)
# Init Funnelling Architecture
gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta)
gfun = Funnelling(first_tier=docEmbedders,
meta_classifier=meta,
n_jobs=args.n_jobs)
# Training ---------------------------------------
print('\n[Training Generalized Funnelling]')

View File

@@ -116,9 +116,9 @@ class Funnelling:
self.n_jobs = n_jobs
def fit(self, lX, ly):
print('## Fitting first-tier learners!')
print('\n## Fitting first-tier learners!')
lZ = self.first_tier.fit_transform(lX, ly)
print('## Fitting meta-learner!')
print('\n## Fitting meta-learner!')
self.meta.fit(lZ, ly)
def predict(self, lX):

View File

@@ -74,7 +74,7 @@ class NaivePolylingualClassifier:
_sort_if_sparse(lX[lang])
models = Parallel(n_jobs=self.n_jobs)\
(delayed(MonolingualClassifier(self.base_learner, parameters=self.parameters).fit)((lX[lang]), ly[lang]) for
(delayed(MonolingualClassifier(self.base_learner, parameters=self.parameters, n_jobs=self.n_jobs).fit)((lX[lang]), ly[lang]) for
lang in langs)
self.model = {lang: models[i] for i, lang in enumerate(langs)}

View File

@@ -69,7 +69,8 @@ class VanillaFunGen(ViewGen):
self.first_tier_parameters = first_tier_parameters
self.n_jobs = n_jobs
self.doc_projector = NaivePolylingualClassifier(base_learner=self.learners,
parameters=self.first_tier_parameters, n_jobs=self.n_jobs)
parameters=self.first_tier_parameters,
n_jobs=self.n_jobs)
self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
# Zero shot parameters
self.zero_shot = zero_shot