Set up zero-shot experiments (implemented for RecurrentGen and BertGen; the BertGen path has not been tested yet)

This commit is contained in:
andrea 2021-02-02 17:13:31 +01:00
parent 6361a4eba0
commit ab3bacb29c
1 changed file with 16 additions and 0 deletions

View File

@ -363,6 +363,8 @@ class RecurrentGen(ViewGen):
:param lX: dict {lang: indexed documents} :param lX: dict {lang: indexed documents}
:return: documents projected to the common latent space. :return: documents projected to the common latent space.
""" """
if self.zero_shot:
lX = self.zero_shot_experiments(lX)
data = {} data = {}
for lang in lX.keys(): for lang in lX.keys():
indexed = index(data=lX[lang], indexed = index(data=lX[lang],
@ -381,6 +383,12 @@ class RecurrentGen(ViewGen):
def fit_transform(self, lX, ly): def fit_transform(self, lX, ly):
return self.fit(lX, ly).transform(lX) return self.fit(lX, ly).transform(lX)
def zero_shot_experiments(self, lX):
    """
    Keep only the languages listed in self.train_langs (used by transform() when self.zero_shot is set).

    A new dict is built rather than popping entries from lX, so the dict handed in by the caller is left
    untouched and can still be used (e.g. for evaluation) after the call.
    :param lX: dict {lang: indexed documents}
    :return: new dict {lang: indexed documents} restricted to the languages in self.train_langs; keys keep
    the order they have in lX.
    """
    return {lang: docs for lang, docs in lX.items() if lang in self.train_langs}
def set_zero_shot(self, val: bool): def set_zero_shot(self, val: bool):
self.zero_shot = val self.zero_shot = val
return return
@ -458,6 +466,8 @@ class BertGen(ViewGen):
:param lX: dict {lang: indexed documents} :param lX: dict {lang: indexed documents}
:return: documents projected to the common latent space. :return: documents projected to the common latent space.
""" """
if self.zero_shot:
lX = self.zero_shot_experiments(lX)
data = tokenize(lX, max_len=512) data = tokenize(lX, max_len=512)
self.model.to('cuda' if self.gpus else 'cpu') self.model.to('cuda' if self.gpus else 'cpu')
self.model.eval() self.model.eval()
@ -468,6 +478,12 @@ class BertGen(ViewGen):
# we can assume that we have already indexed data for transform() since we are first calling fit() # we can assume that we have already indexed data for transform() since we are first calling fit()
return self.fit(lX, ly).transform(lX) return self.fit(lX, ly).transform(lX)
def zero_shot_experiments(self, lX):
    """
    Keep only the languages listed in self.train_langs (used by transform() when self.zero_shot is set).

    A new dict is built rather than popping entries from lX, so the dict handed in by the caller is left
    untouched and can still be used (e.g. for evaluation) after the call.
    :param lX: dict {lang: indexed documents}
    :return: new dict {lang: indexed documents} restricted to the languages in self.train_langs; keys keep
    the order they have in lX.
    """
    return {lang: docs for lang, docs in lX.items() if lang in self.train_langs}
def set_zero_shot(self, val: bool): def set_zero_shot(self, val: bool):
self.zero_shot = val self.zero_shot = val
return return