Implemented zero-shot experiment code for VanillaFunGen and WordClassGen.
This commit is contained in:
parent
f3fafd0f00
commit
495a0b6af9
54
main.py
54
main.py
|
|
@ -17,7 +17,7 @@ def main(args):
|
||||||
print('Running generalized funnelling...')
|
print('Running generalized funnelling...')
|
||||||
|
|
||||||
data = MultilingualDataset.load(args.dataset)
|
data = MultilingualDataset.load(args.dataset)
|
||||||
data.set_view(languages=['da', 'nl', 'it'])
|
data.set_view(languages=['da'])
|
||||||
data.show_dimensions()
|
data.show_dimensions()
|
||||||
lX, ly = data.training()
|
lX, ly = data.training()
|
||||||
lXte, lyte = data.test()
|
lXte, lyte = data.test()
|
||||||
|
|
@ -32,40 +32,64 @@ def main(args):
|
||||||
embedder_list = []
|
embedder_list = []
|
||||||
if args.post_embedder:
|
if args.post_embedder:
|
||||||
posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True),
|
posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True),
|
||||||
zero_shot=args.zero_shot, train_langs=args.zscl_langs,
|
zero_shot=args.zero_shot,
|
||||||
|
train_langs=args.zscl_langs,
|
||||||
n_jobs=args.n_jobs)
|
n_jobs=args.n_jobs)
|
||||||
|
|
||||||
embedder_list.append(posteriorEmbedder)
|
embedder_list.append(posteriorEmbedder)
|
||||||
|
|
||||||
if args.muse_embedder:
|
if args.muse_embedder:
|
||||||
museEmbedder = MuseGen(muse_dir=args.muse_dir, n_jobs=args.n_jobs,
|
museEmbedder = MuseGen(muse_dir=args.muse_dir,
|
||||||
zero_shot=args.zero_shot, train_langs=args.zscl_langs)
|
zero_shot=args.zero_shot,
|
||||||
|
train_langs=args.zscl_langs,
|
||||||
|
n_jobs=args.n_jobs)
|
||||||
|
|
||||||
embedder_list.append(museEmbedder)
|
embedder_list.append(museEmbedder)
|
||||||
|
|
||||||
if args.wce_embedder:
|
if args.wce_embedder:
|
||||||
wceEmbedder = WordClassGen(n_jobs=args.n_jobs,
|
wceEmbedder = WordClassGen(zero_shot=args.zero_shot,
|
||||||
zero_shot=args.zero_shot, train_langs=args.zscl_langs)
|
train_langs=args.zscl_langs,
|
||||||
|
n_jobs=args.n_jobs)
|
||||||
|
|
||||||
embedder_list.append(wceEmbedder)
|
embedder_list.append(wceEmbedder)
|
||||||
|
|
||||||
if args.gru_embedder:
|
if args.gru_embedder:
|
||||||
rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=args.rnn_wce,
|
rnnEmbedder = RecurrentGen(multilingualIndex,
|
||||||
batch_size=args.batch_rnn, nepochs=args.nepochs_rnn, patience=args.patience_rnn,
|
pretrained_embeddings=lMuse,
|
||||||
zero_shot=args.zero_shot, train_langs=args.zscl_langs,
|
wce=args.rnn_wce,
|
||||||
gpus=args.gpus, n_jobs=args.n_jobs)
|
batch_size=args.batch_rnn,
|
||||||
|
nepochs=args.nepochs_rnn,
|
||||||
|
patience=args.patience_rnn,
|
||||||
|
zero_shot=args.zero_shot,
|
||||||
|
train_langs=args.zscl_langs,
|
||||||
|
gpus=args.gpus,
|
||||||
|
n_jobs=args.n_jobs)
|
||||||
|
|
||||||
embedder_list.append(rnnEmbedder)
|
embedder_list.append(rnnEmbedder)
|
||||||
|
|
||||||
if args.bert_embedder:
|
if args.bert_embedder:
|
||||||
bertEmbedder = BertGen(multilingualIndex, batch_size=args.batch_bert, nepochs=args.nepochs_bert,
|
bertEmbedder = BertGen(multilingualIndex,
|
||||||
zero_shot=args.zero_shot, train_langs=args.zscl_langs,
|
batch_size=args.batch_bert,
|
||||||
patience=args.patience_bert, gpus=args.gpus, n_jobs=args.n_jobs)
|
nepochs=args.nepochs_bert,
|
||||||
|
patience=args.patience_bert,
|
||||||
|
zero_shot=args.zero_shot,
|
||||||
|
train_langs=args.zscl_langs,
|
||||||
|
gpus=args.gpus,
|
||||||
|
n_jobs=args.n_jobs)
|
||||||
|
|
||||||
embedder_list.append(bertEmbedder)
|
embedder_list.append(bertEmbedder)
|
||||||
|
|
||||||
# Init DocEmbedderList (i.e., first-tier learners or view generators) and metaclassifier
|
# Init DocEmbedderList (i.e., first-tier learners or view generators) and metaclassifier
|
||||||
docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True)
|
docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True)
|
||||||
|
|
||||||
meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'),
|
meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'),
|
||||||
meta_parameters=get_params(optimc=args.optimc), n_jobs=args.n_jobs)
|
meta_parameters=get_params(optimc=args.optimc),
|
||||||
|
n_jobs=args.n_jobs)
|
||||||
|
|
||||||
# Init Funnelling Architecture
|
# Init Funnelling Architecture
|
||||||
gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta)
|
gfun = Funnelling(first_tier=docEmbedders,
|
||||||
|
meta_classifier=meta,
|
||||||
|
n_jobs=args.n_jobs)
|
||||||
|
|
||||||
# Training ---------------------------------------
|
# Training ---------------------------------------
|
||||||
print('\n[Training Generalized Funnelling]')
|
print('\n[Training Generalized Funnelling]')
|
||||||
|
|
|
||||||
|
|
@ -116,9 +116,9 @@ class Funnelling:
|
||||||
self.n_jobs = n_jobs
|
self.n_jobs = n_jobs
|
||||||
|
|
||||||
def fit(self, lX, ly):
|
def fit(self, lX, ly):
|
||||||
print('## Fitting first-tier learners!')
|
print('\n## Fitting first-tier learners!')
|
||||||
lZ = self.first_tier.fit_transform(lX, ly)
|
lZ = self.first_tier.fit_transform(lX, ly)
|
||||||
print('## Fitting meta-learner!')
|
print('\n## Fitting meta-learner!')
|
||||||
self.meta.fit(lZ, ly)
|
self.meta.fit(lZ, ly)
|
||||||
|
|
||||||
def predict(self, lX):
|
def predict(self, lX):
|
||||||
|
|
|
||||||
|
|
@ -74,7 +74,7 @@ class NaivePolylingualClassifier:
|
||||||
_sort_if_sparse(lX[lang])
|
_sort_if_sparse(lX[lang])
|
||||||
|
|
||||||
models = Parallel(n_jobs=self.n_jobs)\
|
models = Parallel(n_jobs=self.n_jobs)\
|
||||||
(delayed(MonolingualClassifier(self.base_learner, parameters=self.parameters).fit)((lX[lang]), ly[lang]) for
|
(delayed(MonolingualClassifier(self.base_learner, parameters=self.parameters, n_jobs=self.n_jobs).fit)((lX[lang]), ly[lang]) for
|
||||||
lang in langs)
|
lang in langs)
|
||||||
|
|
||||||
self.model = {lang: models[i] for i, lang in enumerate(langs)}
|
self.model = {lang: models[i] for i, lang in enumerate(langs)}
|
||||||
|
|
|
||||||
|
|
@ -69,7 +69,8 @@ class VanillaFunGen(ViewGen):
|
||||||
self.first_tier_parameters = first_tier_parameters
|
self.first_tier_parameters = first_tier_parameters
|
||||||
self.n_jobs = n_jobs
|
self.n_jobs = n_jobs
|
||||||
self.doc_projector = NaivePolylingualClassifier(base_learner=self.learners,
|
self.doc_projector = NaivePolylingualClassifier(base_learner=self.learners,
|
||||||
parameters=self.first_tier_parameters, n_jobs=self.n_jobs)
|
parameters=self.first_tier_parameters,
|
||||||
|
n_jobs=self.n_jobs)
|
||||||
self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
|
self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
|
||||||
# Zero shot parameters
|
# Zero shot parameters
|
||||||
self.zero_shot = zero_shot
|
self.zero_shot = zero_shot
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue