Implemented zero-shot experiment code for VanillaFunGen and WordClassGen
This commit is contained in:
parent
f3fafd0f00
commit
495a0b6af9
54
main.py
54
main.py
|
|
@ -17,7 +17,7 @@ def main(args):
|
|||
print('Running generalized funnelling...')
|
||||
|
||||
data = MultilingualDataset.load(args.dataset)
|
||||
data.set_view(languages=['da', 'nl', 'it'])
|
||||
data.set_view(languages=['da'])
|
||||
data.show_dimensions()
|
||||
lX, ly = data.training()
|
||||
lXte, lyte = data.test()
|
||||
|
|
@ -32,40 +32,64 @@ def main(args):
|
|||
embedder_list = []
|
||||
if args.post_embedder:
|
||||
posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True),
|
||||
zero_shot=args.zero_shot, train_langs=args.zscl_langs,
|
||||
zero_shot=args.zero_shot,
|
||||
train_langs=args.zscl_langs,
|
||||
n_jobs=args.n_jobs)
|
||||
|
||||
embedder_list.append(posteriorEmbedder)
|
||||
|
||||
if args.muse_embedder:
|
||||
museEmbedder = MuseGen(muse_dir=args.muse_dir, n_jobs=args.n_jobs,
|
||||
zero_shot=args.zero_shot, train_langs=args.zscl_langs)
|
||||
museEmbedder = MuseGen(muse_dir=args.muse_dir,
|
||||
zero_shot=args.zero_shot,
|
||||
train_langs=args.zscl_langs,
|
||||
n_jobs=args.n_jobs)
|
||||
|
||||
embedder_list.append(museEmbedder)
|
||||
|
||||
if args.wce_embedder:
|
||||
wceEmbedder = WordClassGen(n_jobs=args.n_jobs,
|
||||
zero_shot=args.zero_shot, train_langs=args.zscl_langs)
|
||||
wceEmbedder = WordClassGen(zero_shot=args.zero_shot,
|
||||
train_langs=args.zscl_langs,
|
||||
n_jobs=args.n_jobs)
|
||||
|
||||
embedder_list.append(wceEmbedder)
|
||||
|
||||
if args.gru_embedder:
|
||||
rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=args.rnn_wce,
|
||||
batch_size=args.batch_rnn, nepochs=args.nepochs_rnn, patience=args.patience_rnn,
|
||||
zero_shot=args.zero_shot, train_langs=args.zscl_langs,
|
||||
gpus=args.gpus, n_jobs=args.n_jobs)
|
||||
rnnEmbedder = RecurrentGen(multilingualIndex,
|
||||
pretrained_embeddings=lMuse,
|
||||
wce=args.rnn_wce,
|
||||
batch_size=args.batch_rnn,
|
||||
nepochs=args.nepochs_rnn,
|
||||
patience=args.patience_rnn,
|
||||
zero_shot=args.zero_shot,
|
||||
train_langs=args.zscl_langs,
|
||||
gpus=args.gpus,
|
||||
n_jobs=args.n_jobs)
|
||||
|
||||
embedder_list.append(rnnEmbedder)
|
||||
|
||||
if args.bert_embedder:
|
||||
bertEmbedder = BertGen(multilingualIndex, batch_size=args.batch_bert, nepochs=args.nepochs_bert,
|
||||
zero_shot=args.zero_shot, train_langs=args.zscl_langs,
|
||||
patience=args.patience_bert, gpus=args.gpus, n_jobs=args.n_jobs)
|
||||
bertEmbedder = BertGen(multilingualIndex,
|
||||
batch_size=args.batch_bert,
|
||||
nepochs=args.nepochs_bert,
|
||||
patience=args.patience_bert,
|
||||
zero_shot=args.zero_shot,
|
||||
train_langs=args.zscl_langs,
|
||||
gpus=args.gpus,
|
||||
n_jobs=args.n_jobs)
|
||||
|
||||
embedder_list.append(bertEmbedder)
|
||||
|
||||
# Init DocEmbedderList (i.e., first-tier learners or view generators) and metaclassifier
|
||||
docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True)
|
||||
|
||||
meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'),
|
||||
meta_parameters=get_params(optimc=args.optimc), n_jobs=args.n_jobs)
|
||||
meta_parameters=get_params(optimc=args.optimc),
|
||||
n_jobs=args.n_jobs)
|
||||
|
||||
# Init Funnelling Architecture
|
||||
gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta)
|
||||
gfun = Funnelling(first_tier=docEmbedders,
|
||||
meta_classifier=meta,
|
||||
n_jobs=args.n_jobs)
|
||||
|
||||
# Training ---------------------------------------
|
||||
print('\n[Training Generalized Funnelling]')
|
||||
|
|
|
|||
|
|
@ -116,9 +116,9 @@ class Funnelling:
|
|||
self.n_jobs = n_jobs
|
||||
|
||||
def fit(self, lX, ly):
|
||||
print('## Fitting first-tier learners!')
|
||||
print('\n## Fitting first-tier learners!')
|
||||
lZ = self.first_tier.fit_transform(lX, ly)
|
||||
print('## Fitting meta-learner!')
|
||||
print('\n## Fitting meta-learner!')
|
||||
self.meta.fit(lZ, ly)
|
||||
|
||||
def predict(self, lX):
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class NaivePolylingualClassifier:
|
|||
_sort_if_sparse(lX[lang])
|
||||
|
||||
models = Parallel(n_jobs=self.n_jobs)\
|
||||
(delayed(MonolingualClassifier(self.base_learner, parameters=self.parameters).fit)((lX[lang]), ly[lang]) for
|
||||
(delayed(MonolingualClassifier(self.base_learner, parameters=self.parameters, n_jobs=self.n_jobs).fit)((lX[lang]), ly[lang]) for
|
||||
lang in langs)
|
||||
|
||||
self.model = {lang: models[i] for i, lang in enumerate(langs)}
|
||||
|
|
|
|||
|
|
@ -69,7 +69,8 @@ class VanillaFunGen(ViewGen):
|
|||
self.first_tier_parameters = first_tier_parameters
|
||||
self.n_jobs = n_jobs
|
||||
self.doc_projector = NaivePolylingualClassifier(base_learner=self.learners,
|
||||
parameters=self.first_tier_parameters, n_jobs=self.n_jobs)
|
||||
parameters=self.first_tier_parameters,
|
||||
n_jobs=self.n_jobs)
|
||||
self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
|
||||
# Zero shot parameters
|
||||
self.zero_shot = zero_shot
|
||||
|
|
|
|||
Loading…
Reference in New Issue