setting up zero-shot experiments (implemented for Recurrent and Bert but not tested for Bert)
This commit is contained in:
parent
6361a4eba0
commit
ab3bacb29c
|
@ -363,6 +363,8 @@ class RecurrentGen(ViewGen):
|
||||||
:param lX: dict {lang: indexed documents}
|
:param lX: dict {lang: indexed documents}
|
||||||
:return: documents projected to the common latent space.
|
:return: documents projected to the common latent space.
|
||||||
"""
|
"""
|
||||||
|
if self.zero_shot:
|
||||||
|
lX = self.zero_shot_experiments(lX)
|
||||||
data = {}
|
data = {}
|
||||||
for lang in lX.keys():
|
for lang in lX.keys():
|
||||||
indexed = index(data=lX[lang],
|
indexed = index(data=lX[lang],
|
||||||
|
@ -381,6 +383,12 @@ class RecurrentGen(ViewGen):
|
||||||
def fit_transform(self, lX, ly):
|
def fit_transform(self, lX, ly):
|
||||||
return self.fit(lX, ly).transform(lX)
|
return self.fit(lX, ly).transform(lX)
|
||||||
|
|
||||||
|
def zero_shot_experiments(self, lX):
|
||||||
|
for lang in sorted(lX.keys()):
|
||||||
|
if lang not in self.train_langs:
|
||||||
|
lX.pop(lang)
|
||||||
|
return lX
|
||||||
|
|
||||||
def set_zero_shot(self, val: bool):
|
def set_zero_shot(self, val: bool):
|
||||||
self.zero_shot = val
|
self.zero_shot = val
|
||||||
return
|
return
|
||||||
|
@ -458,6 +466,8 @@ class BertGen(ViewGen):
|
||||||
:param lX: dict {lang: indexed documents}
|
:param lX: dict {lang: indexed documents}
|
||||||
:return: documents projected to the common latent space.
|
:return: documents projected to the common latent space.
|
||||||
"""
|
"""
|
||||||
|
if self.zero_shot:
|
||||||
|
lX = self.zero_shot_experiments(lX)
|
||||||
data = tokenize(lX, max_len=512)
|
data = tokenize(lX, max_len=512)
|
||||||
self.model.to('cuda' if self.gpus else 'cpu')
|
self.model.to('cuda' if self.gpus else 'cpu')
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
@ -468,6 +478,12 @@ class BertGen(ViewGen):
|
||||||
# we can assume that we have already indexed data for transform() since we are first calling fit()
|
# we can assume that we have already indexed data for transform() since we are first calling fit()
|
||||||
return self.fit(lX, ly).transform(lX)
|
return self.fit(lX, ly).transform(lX)
|
||||||
|
|
||||||
|
def zero_shot_experiments(self, lX):
|
||||||
|
for lang in sorted(lX.keys()):
|
||||||
|
if lang not in self.train_langs:
|
||||||
|
lX.pop(lang)
|
||||||
|
return lX
|
||||||
|
|
||||||
def set_zero_shot(self, val: bool):
|
def set_zero_shot(self, val: bool):
|
||||||
self.zero_shot = val
|
self.zero_shot = val
|
||||||
return
|
return
|
||||||
|
|
Loading…
Reference in New Issue