implemented zero-shot experiment code for VanillaFunGen and WordClassGen
This commit is contained in:
parent
8968570d82
commit
f3fafd0f00
4
main.py
4
main.py
|
@@ -62,7 +62,7 @@ def main(args):
|
||||||
# Init DocEmbedderList (i.e., first-tier learners or view generators) and metaclassifier
|
# Init DocEmbedderList (i.e., first-tier learners or view generators) and metaclassifier
|
||||||
docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True)
|
docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True)
|
||||||
meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'),
|
meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'),
|
||||||
meta_parameters=get_params(optimc=args.optimc))
|
meta_parameters=get_params(optimc=args.optimc), n_jobs=args.n_jobs)
|
||||||
|
|
||||||
# Init Funnelling Architecture
|
# Init Funnelling Architecture
|
||||||
gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta)
|
gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta)
|
||||||
|
@@ -80,7 +80,7 @@ def main(args):
|
||||||
if args.zero_shot:
|
if args.zero_shot:
|
||||||
gfun.set_zero_shot(val=False)
|
gfun.set_zero_shot(val=False)
|
||||||
ly_ = gfun.predict(lXte)
|
ly_ = gfun.predict(lXte)
|
||||||
l_eval = evaluate(ly_true=lyte, ly_pred=ly_)
|
l_eval = evaluate(ly_true=lyte, ly_pred=ly_, n_jobs=args.n_jobs)
|
||||||
time_te = round(time.time() - time_te, 3)
|
time_te = round(time.time() - time_te, 3)
|
||||||
print(f'Testing completed in {time_te} seconds!')
|
print(f'Testing completed in {time_te} seconds!')
|
||||||
|
|
||||||
|
|
|
@@ -23,7 +23,7 @@ class DocEmbedderList:
|
||||||
if isinstance(embedder, VanillaFunGen):
|
if isinstance(embedder, VanillaFunGen):
|
||||||
_tmp.append(embedder)
|
_tmp.append(embedder)
|
||||||
else:
|
else:
|
||||||
_tmp.append(FeatureSet2Posteriors(embedder))
|
_tmp.append(FeatureSet2Posteriors(embedder, n_jobs=embedder.n_jobs))
|
||||||
self.embedders = _tmp
|
self.embedders = _tmp
|
||||||
|
|
||||||
def fit(self, lX, ly):
|
def fit(self, lX, ly):
|
||||||
|
|
|
@@ -78,7 +78,7 @@ class VanillaFunGen(ViewGen):
|
||||||
self.train_langs = train_langs
|
self.train_langs = train_langs
|
||||||
|
|
||||||
def fit(self, lX, ly):
|
def fit(self, lX, ly):
|
||||||
print('# Fitting VanillaFunGen (X)...')
|
print('\n# Fitting VanillaFunGen (X)...')
|
||||||
if self.zero_shot:
|
if self.zero_shot:
|
||||||
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
|
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
|
||||||
self.langs = sorted(self.train_langs)
|
self.langs = sorted(self.train_langs)
|
||||||
|
@@ -150,7 +150,7 @@ class MuseGen(ViewGen):
|
||||||
:param ly: dict {lang: target vectors}
|
:param ly: dict {lang: target vectors}
|
||||||
:return: self.
|
:return: self.
|
||||||
"""
|
"""
|
||||||
print('# Fitting MuseGen (M)...')
|
print('\n# Fitting MuseGen (M)...')
|
||||||
if self.zero_shot:
|
if self.zero_shot:
|
||||||
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
|
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
|
||||||
self.vectorizer.fit(lX)
|
self.vectorizer.fit(lX)
|
||||||
|
@@ -226,7 +226,7 @@ class WordClassGen(ViewGen):
|
||||||
:param ly: dict {lang: target vectors}
|
:param ly: dict {lang: target vectors}
|
||||||
:return: self.
|
:return: self.
|
||||||
"""
|
"""
|
||||||
print('# Fitting WordClassGen (W)...')
|
print('\n# Fitting WordClassGen (W)...')
|
||||||
if self.zero_shot:
|
if self.zero_shot:
|
||||||
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
|
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
|
||||||
self.langs = sorted(self.train_langs)
|
self.langs = sorted(self.train_langs)
|
||||||
|
@@ -359,7 +359,7 @@ class RecurrentGen(ViewGen):
|
||||||
:param ly: dict {lang: target vectors}
|
:param ly: dict {lang: target vectors}
|
||||||
:return: self.
|
:return: self.
|
||||||
"""
|
"""
|
||||||
print('# Fitting RecurrentGen (G)...')
|
print('\n# Fitting RecurrentGen (G)...')
|
||||||
create_if_not_exist(self.logger.save_dir)
|
create_if_not_exist(self.logger.save_dir)
|
||||||
recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs,
|
recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs,
|
||||||
zero_shot=self.zero_shot, zscl_langs=self.train_langs)
|
zero_shot=self.zero_shot, zscl_langs=self.train_langs)
|
||||||
|
@@ -468,7 +468,7 @@ class BertGen(ViewGen):
|
||||||
:param ly: dict {lang: target vectors}
|
:param ly: dict {lang: target vectors}
|
||||||
:return: self.
|
:return: self.
|
||||||
"""
|
"""
|
||||||
print('# Fitting BertGen (M)...')
|
print('\n# Fitting BertGen (M)...')
|
||||||
create_if_not_exist(self.logger.save_dir)
|
create_if_not_exist(self.logger.save_dir)
|
||||||
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
|
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
|
||||||
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,
|
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,
|
||||||
|
|
Loading…
Reference in New Issue