Setting up zero-shot experiments (done and tested for WordClassGen)
This commit is contained in:
parent
7f493da0f8
commit
5821325c86
5
main.py
5
main.py
|
|
@ -41,7 +41,8 @@ def main(args):
|
||||||
embedder_list.append(museEmbedder)
|
embedder_list.append(museEmbedder)
|
||||||
|
|
||||||
if args.wce_embedder:
|
if args.wce_embedder:
|
||||||
wceEmbedder = WordClassGen(n_jobs=args.n_jobs)
|
wceEmbedder = WordClassGen(n_jobs=args.n_jobs,
|
||||||
|
zero_shot=zero_shot, train_langs=zscl_train_langs) # Todo: testing zero shot
|
||||||
embedder_list.append(wceEmbedder)
|
embedder_list.append(wceEmbedder)
|
||||||
|
|
||||||
if args.gru_embedder:
|
if args.gru_embedder:
|
||||||
|
|
@ -74,7 +75,7 @@ def main(args):
|
||||||
# Testing ----------------------------------------
|
# Testing ----------------------------------------
|
||||||
print('\n[Testing Generalized Funnelling]')
|
print('\n[Testing Generalized Funnelling]')
|
||||||
time_te = time.time()
|
time_te = time.time()
|
||||||
# Zero shot scenario -> setting first tier learners zero_shot param to False
|
# TODO: Zero shot scenario -> setting first tier learners zero_shot param to False
|
||||||
gfun.set_zero_shot(val=False)
|
gfun.set_zero_shot(val=False)
|
||||||
ly_ = gfun.predict(lXte)
|
ly_ = gfun.predict(lXte)
|
||||||
l_eval = evaluate(ly_true=lyte, ly_pred=ly_)
|
l_eval = evaluate(ly_true=lyte, ly_pred=ly_)
|
||||||
|
|
|
||||||
|
|
@ -181,7 +181,7 @@ class WordClassGen(ViewGen):
|
||||||
View Generator (w): generates document representation via Word-Class-Embeddings.
|
View Generator (w): generates document representation via Word-Class-Embeddings.
|
||||||
Document embeddings are obtained via weighted sum of document's constituent embeddings.
|
Document embeddings are obtained via weighted sum of document's constituent embeddings.
|
||||||
"""
|
"""
|
||||||
def __init__(self, n_jobs=-1):
|
def __init__(self, zero_shot=False, train_langs: list = None, n_jobs=-1):
|
||||||
"""
|
"""
|
||||||
Init WordClassGen.
|
Init WordClassGen.
|
||||||
:param n_jobs: int, number of concurrent workers
|
:param n_jobs: int, number of concurrent workers
|
||||||
|
|
@ -191,6 +191,11 @@ class WordClassGen(ViewGen):
|
||||||
self.langs = None
|
self.langs = None
|
||||||
self.lWce = None
|
self.lWce = None
|
||||||
self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
|
self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
|
||||||
|
# Zero shot parameters
|
||||||
|
self.zero_shot = zero_shot
|
||||||
|
if train_langs is None:
|
||||||
|
train_langs = ['it']
|
||||||
|
self.train_langs = train_langs
|
||||||
|
|
||||||
def fit(self, lX, ly):
|
def fit(self, lX, ly):
|
||||||
"""
|
"""
|
||||||
|
|
@ -215,19 +220,34 @@ class WordClassGen(ViewGen):
|
||||||
:param lX: dict {lang: indexed documents}
|
:param lX: dict {lang: indexed documents}
|
||||||
:return: document projection to the common latent space.
|
:return: document projection to the common latent space.
|
||||||
"""
|
"""
|
||||||
lX = self.vectorizer.transform(lX)
|
# Testing zero-shot experiments
|
||||||
|
if self.zero_shot:
|
||||||
|
lX = self.zero_shot_experiments(lX)
|
||||||
|
lX = {l: self.vectorizer.vectorizer[l].transform(lX[l]) for l in self.langs if lX[l] is not None}
|
||||||
|
else:
|
||||||
|
lX = self.vectorizer.transform(lX)
|
||||||
XdotWce = Parallel(n_jobs=self.n_jobs)(
|
XdotWce = Parallel(n_jobs=self.n_jobs)(
|
||||||
delayed(XdotM)(lX[lang], self.lWce[lang], sif=True) for lang in self.langs)
|
delayed(XdotM)(lX[lang], self.lWce[lang], sif=True) for lang in sorted(lX.keys()))
|
||||||
lWce = {l: XdotWce[i] for i, l in enumerate(self.langs)}
|
lWce = {l: XdotWce[i] for i, l in enumerate(sorted(lX.keys()))}
|
||||||
lWce = _normalize(lWce, l2=True)
|
lWce = _normalize(lWce, l2=True)
|
||||||
return lWce
|
return lWce
|
||||||
|
|
||||||
def fit_transform(self, lX, ly):
|
def fit_transform(self, lX, ly):
|
||||||
return self.fit(lX, ly).transform(lX)
|
return self.fit(lX, ly).transform(lX)
|
||||||
|
|
||||||
|
def zero_shot_experiments(self, lX):
|
||||||
|
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
|
||||||
|
_lX = {}
|
||||||
|
for lang in self.langs:
|
||||||
|
if lang in self.train_langs:
|
||||||
|
_lX[lang] = lX[lang]
|
||||||
|
else:
|
||||||
|
_lX[lang] = None
|
||||||
|
lX = _lX
|
||||||
|
return lX
|
||||||
|
|
||||||
def set_zero_shot(self, val: bool):
|
def set_zero_shot(self, val: bool):
|
||||||
self.zero_shot = val
|
self.zero_shot = val
|
||||||
print('# TODO: WordClassGen has not been configured for zero-shot experiments')
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue