setting up zero-shot experiments (done and tested for MuseGen)
parent 10bed81916
commit 7f493da0f8

main.py | 24
main.py

@@ -15,23 +15,14 @@ def main(args):
     print('Running generalized funnelling...')

     data = MultilingualDataset.load(args.dataset)
-    data.set_view(languages=['it', 'da'])
+    data.set_view(languages=['it', 'da', 'nl'])
     data.show_dimensions()
     lX, ly = data.training()

-    # Testing zero shot experiments
-    # zero_shot_setting = True
-    # if zero_shot_setting:
-    #     # _lX = {}
-    #     _ly = {}
-    #     train_langs = ['it']
-    #     for train_lang in train_langs:
-    #         # _lX[train_lang] = lX[train_lang]
-    #         _ly[train_lang] = ly[train_lang]
-    #     ly = _ly

     lXte, lyte = data.test()

+    zero_shot = True
+    zscl_train_langs = ['it']   # Todo: testing zero shot

     # Init multilingualIndex - mandatory when deploying Neural View Generators...
     if args.gru_embedder or args.bert_embedder:
         multilingualIndex = MultilingualIndex()

@@ -45,7 +36,8 @@ def main(args):
         embedder_list.append(posteriorEmbedder)

     if args.muse_embedder:
-        museEmbedder = MuseGen(muse_dir=args.muse_dir, n_jobs=args.n_jobs, zero_shot=True)
+        museEmbedder = MuseGen(muse_dir=args.muse_dir, n_jobs=args.n_jobs,
+                               zero_shot=zero_shot, train_langs=zscl_train_langs)   # Todo: testing zero shot
         embedder_list.append(museEmbedder)

     if args.wce_embedder:

@@ -82,6 +74,8 @@ def main(args):
     # Testing ----------------------------------------
     print('\n[Testing Generalized Funnelling]')
     time_te = time.time()
+    # Zero shot scenario -> setting first tier learners zero_shot param to False
+    gfun.set_zero_shot(val=False)
     ly_ = gfun.predict(lXte)
     l_eval = evaluate(ly_true=lyte, ly_pred=ly_)
     time_te = round(time.time() - time_te, 3)

@@ -111,7 +105,7 @@ def main(args):
                       microf1=microf1,
                       macrok=macrok,
                       microk=microk,
-                      notes=f'Train langs: {sorted(lX.keys())}')
+                      notes=f'Train langs: {sorted(zscl_train_langs)}' if zero_shot else '')
     print('Averages: MF1, mF1, MK, mK', np.round(np.mean(np.array(metrics), axis=0), 3))

     overall_time = round(time.time() - time_init, 3)
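Not part of the commit itself: a minimal, self-contained sketch of the train/test lifecycle that the main.py changes above set up. StubViewGen is a made-up stand-in for MuseGen that only mimics the zero_shot / train_langs / set_zero_shot() behaviour introduced here: fit on the restricted language set, flip the flag off, then predict on every test language, seen or unseen.

class StubViewGen:
    """Toy stand-in for a first-tier view generator with the new zero-shot switches."""

    def __init__(self, zero_shot=False, train_langs=None):
        self.zero_shot = zero_shot
        self.train_langs = train_langs if train_langs is not None else ['it']

    def set_zero_shot(self, val: bool):
        self.zero_shot = val
        return self

    def fit(self, lX, ly=None):
        if self.zero_shot:
            # mask every language that is not a training language (cf. MuseGen.zero_shot_experiments)
            lX = {lang: (docs if lang in self.train_langs else None) for lang, docs in lX.items()}
        seen = sorted(lang for lang, docs in lX.items() if docs is not None)
        print(f'fitting on: {seen}')
        return self

    def predict(self, lXte):
        # dummy predictions for every test language
        return {lang: [0] * len(docs) for lang, docs in lXte.items()}


lX = {'it': ['doc'] * 3, 'da': ['doc'] * 3, 'nl': ['doc'] * 3}      # training pool
lXte = {'it': ['doc'] * 2, 'da': ['doc'] * 2, 'nl': ['doc'] * 2}    # test pool

gen = StubViewGen(zero_shot=True, train_langs=['it'])
gen.fit(lX)                    # -> fitting on: ['it']
gen.set_zero_shot(val=False)   # disable masking before testing
ly_ = gen.predict(lXte)        # zero-shot evaluation: 'da' and 'nl' were never seen in training
print(sorted(ly_.keys()))      # -> ['da', 'it', 'nl']

In the actual pipeline the same flip happens through gfun.set_zero_shot(val=False) right before gfun.predict(lXte), as shown in the hunk above.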
@@ -48,14 +48,17 @@ class DocEmbedderList:

         for embedder in self.embedders:
             lZ = embedder.transform(lX)
-            for lang in langs:
+            for lang in sorted(lZ.keys()):
                 Z = lZ[lang]
                 if lZparts[lang] is None:
                     lZparts[lang] = Z
                 else:
                     lZparts[lang] += Z
         n_embedders = len(self.embedders)
-        return {lang: lZparts[lang]/n_embedders for lang in langs}   # Averaging feature spaces
+        # Zero shot experiments: removing k:v if v is None (i.e, it is a lang that will be used in zero shot setting)
+        lZparts = {k: v for k, v in lZparts.items() if v is not None}
+
+        return {lang: lZparts[lang]/n_embedders for lang in sorted(lZparts.keys())}   # Averaging feature spaces

     def fit_transform(self, lX, ly):
         return self.fit(lX, ly).transform(lX)

@@ -122,3 +125,9 @@ class Funnelling:
         lZ = self.first_tier.transform(lX)
         ly = self.meta.predict(lZ)
         return ly
+
+    def set_zero_shot(self, val: bool):
+        for embedder in self.first_tier.embedders:
+            embedder.embedder.set_zero_shot(val)
+        return
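Not part of the commit itself: a small NumPy sketch of the filtering-and-averaging step that DocEmbedderList.transform now performs. A language whose accumulated view is None (i.e. one reserved for the zero-shot evaluation) is dropped before the per-language feature spaces are averaged; the matrices and counts below are made up.

import numpy as np

n_embedders = 2
lZparts = {
    'it': np.ones((4, 3)) * 2,   # sum of two views for the training language
    'da': None,                  # masked: reserved for the zero-shot evaluation
    'nl': None,                  # masked: reserved for the zero-shot evaluation
}

# drop the masked languages, then average the summed feature spaces
lZparts = {k: v for k, v in lZparts.items() if v is not None}
lZ = {lang: lZparts[lang] / n_embedders for lang in sorted(lZparts.keys())}

print(sorted(lZ.keys()))   # -> ['it']
print(lZ['it'][0])         # -> [1. 1. 1.]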
@@ -93,13 +93,18 @@ class VanillaFunGen(ViewGen):
     def fit_transform(self, lX, ly):
         return self.fit(lX, ly).transform(lX)

+    def set_zero_shot(self, val: bool):
+        self.zero_shot = val
+        print('# TODO: PosteriorsGen has not been configured for zero-shot experiments')
+        return
+

 class MuseGen(ViewGen):
     """
     View Generator (m): generates document representation via MUSE embeddings (Fasttext multilingual word
     embeddings). Document embeddings are obtained via weighted sum of document's constituent embeddings.
     """
-    def __init__(self, muse_dir='../embeddings', zero_shot=False, n_jobs=-1):
+    def __init__(self, muse_dir='../embeddings', zero_shot=False, train_langs: list = None, n_jobs=-1):
         """
         Init the MuseGen.
         :param muse_dir: string, path to folder containing muse embeddings

@@ -111,7 +116,11 @@ class MuseGen(ViewGen):
         self.langs = None
         self.lMuse = None
         self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
+        # Zero shot parameters
         self.zero_shot = zero_shot
+        if train_langs is None:
+            train_langs = ['it']
+        self.train_langs = train_langs

     def fit(self, lX, ly):
         """

@@ -138,6 +147,7 @@ class MuseGen(ViewGen):
         """
         # Testing zero-shot experiments
         if self.zero_shot:
+            lX = self.zero_shot_experiments(lX)
             lX = {l: self.vectorizer.vectorizer[l].transform(lX[l]) for l in self.langs if lX[l] is not None}
         else:
             lX = self.vectorizer.transform(lX)

@@ -148,22 +158,23 @@ class MuseGen(ViewGen):
         return lZ

     def fit_transform(self, lX, ly):
-        print('## NB: Calling fit_transform!')
-        if self.zero_shot:
-            return self.fit(lX, ly).transform(self.zero_shot_experiments(lX))
         return self.fit(lX, ly).transform(lX)

-    def zero_shot_experiments(self, lX, train_langs: list = ['it']):
-        print(f'# Zero-shot setting! Training langs will be set to: {sorted(train_langs)}')
+    def zero_shot_experiments(self, lX):
+        print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
         _lX = {}
         for lang in self.langs:
-            if lang in train_langs:
+            if lang in self.train_langs:
                 _lX[lang] = lX[lang]
             else:
                 _lX[lang] = None
         lX = _lX
         return lX

+    def set_zero_shot(self, val: bool):
+        self.zero_shot = val
+        return
+

 class WordClassGen(ViewGen):
     """

@@ -214,6 +225,11 @@ class WordClassGen(ViewGen):
     def fit_transform(self, lX, ly):
         return self.fit(lX, ly).transform(lX)

+    def set_zero_shot(self, val: bool):
+        self.zero_shot = val
+        print('# TODO: WordClassGen has not been configured for zero-shot experiments')
+        return
+

 class RecurrentGen(ViewGen):
     """

@@ -335,6 +351,11 @@ class RecurrentGen(ViewGen):
     def fit_transform(self, lX, ly):
         return self.fit(lX, ly).transform(lX)

+    def set_zero_shot(self, val: bool):
+        self.zero_shot = val
+        print('# TODO: RecurrentGen has not been configured for zero-shot experiments')
+        return
+

 class BertGen(ViewGen):
     """

@@ -405,3 +426,8 @@ class BertGen(ViewGen):
     def fit_transform(self, lX, ly):
         # we can assume that we have already indexed data for transform() since we are first calling fit()
         return self.fit(lX, ly).transform(lX)
+
+    def set_zero_shot(self, val: bool):
+        self.zero_shot = val
+        print('# TODO: BertGen has not been configured for zero-shot experiments')
+        return
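Not part of the commit itself: a standalone restatement of the masking that MuseGen.zero_shot_experiments applies at fit time, written as a plain function over dicts so the behaviour can be checked in isolation. zero_shot_mask is a made-up name; only the logic mirrors the committed code.

def zero_shot_mask(lX: dict, train_langs=None) -> dict:
    """Keep the documents of the training languages; map every other language to None."""
    if train_langs is None:
        train_langs = ['it']   # same default the new MuseGen constructor falls back to
    return {lang: (docs if lang in train_langs else None) for lang, docs in lX.items()}


lX = {'it': ['doc1', 'doc2'], 'da': ['doc3'], 'nl': ['doc4']}
print(zero_shot_mask(lX, train_langs=['it']))
# -> {'it': ['doc1', 'doc2'], 'da': None, 'nl': None}

Downstream, MuseGen only vectorizes the non-None entries (for l in self.langs if lX[l] is not None), which produces the None views that DocEmbedderList filters out before averaging.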