fixed batcher

This commit is contained in:
andrea 2020-10-27 15:08:39 +01:00
parent 5906f85f33
commit 94bfe6a036
2 changed files with 38 additions and 23 deletions

View File

@@ -301,15 +301,16 @@ class RecurrentEmbedder:
self.test_each = test_each
self.options = options
self.seed = options.seed
self.model_path = model_path
self.is_trained = False
## INIT MODEL for training
self.lXtr, self.lytr = self.multilingual_dataset.training(target_as_csr=True)
self.lXte, self.lyte = self.multilingual_dataset.test(target_as_csr=True)
self.nC = self.lyte[self.langs[0]].shape[1]
- lpretrained, lpretrained_vocabulary = self._load_pretrained_embeddings(self.we_path, self.langs)
+ lpretrained, self.lpretrained_vocabulary = self._load_pretrained_embeddings(self.we_path, self.langs)
self.multilingual_index = MultilingualIndex()
- self.multilingual_index.index(self.lXtr, self.lytr, self.lXte, lpretrained_vocabulary)
+ self.multilingual_index.index(self.lXtr, self.lytr, self.lXte, self.lpretrained_vocabulary)
self.multilingual_index.train_val_split(val_prop=0.2, max_val=2000, seed=self.seed)
self.multilingual_index.embedding_matrices(lpretrained, self.supervised)
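embedding_matrices(lpretrained, self.supervised) is where the pretrained vectors meet the vocabulary built by the index. As a minimal sketch of the pretrained half of that step (build_embedding_matrix, word2index, dim and the random fallback are illustrative names, not the repository's code):

# --- illustrative sketch, not part of this commit ---
# Build a |V| x dim matrix whose i-th row is the pretrained vector of the word
# with index i; words without a pretrained vector keep a small random row.
import numpy as np

def build_embedding_matrix(word2index, pretrained, dim=300, seed=42):
    rng = np.random.default_rng(seed)
    matrix = rng.normal(scale=0.01, size=(len(word2index), dim)).astype(np.float32)
    for word, idx in word2index.items():
        if word in pretrained:            # pretrained: dict word -> np.ndarray of length dim
            matrix[idx] = pretrained[word]
    return matrix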
@@ -324,12 +325,15 @@ class RecurrentEmbedder:
self.lr_scheduler = StepLR(self.optim, step_size=25, gamma=0.5)
self.early_stop = EarlyStopping(self.model, optimizer=self.optim, patience=self.patience,
checkpoint=f'{self.checkpoint_dir}/gru_viewgen_-{get_file_name(self.options.dataset)}')
# Init the SVM used to recast (vstacked) document embeddings into vectors of posterior probabilities
self.posteriorEmbedder = MetaClassifier(
SVC(kernel='rbf', gamma='auto', probability=True, cache_size=1000, random_state=1), n_jobs=options.n_jobs)
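MetaClassifier itself is not shown in this diff; the comment above describes recasting document embeddings into posterior-probability vectors with a probabilistic SVM. A minimal single-label sketch of that idea (the real class presumably handles the multilabel case, e.g. one binary SVM per class):

# --- illustrative sketch, not part of this commit ---
# Recast dense document embeddings into vectors of posterior probabilities
# (one probability per class) with a probabilistic SVM, single-label case.
from sklearn.svm import SVC

def fit_posterior_svm(doc_embeddings, labels):
    clf = SVC(kernel='rbf', gamma='auto', probability=True, random_state=1)
    clf.fit(doc_embeddings, labels)            # doc_embeddings: (n_docs, emb_dim), labels: (n_docs,)
    return clf

def to_posteriors(clf, doc_embeddings):
    return clf.predict_proba(doc_embeddings)   # (n_docs, n_classes)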
def fit(self, lX, ly, lV=None, batch_size=64, nepochs=200, val_epochs=1):
print('### Gated Recurrent Unit View Generator (G)')
# self.multilingual_index.get_indexed(lX, self.lpretrained_vocabulary)
# it might be better to init the model here, at the first .fit() call
if self.model is None:
print('TODO: Init model!')
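A small sketch of the lazy initialization the comment above suggests (building the network on the first fit() call); LazyEmbedder and _init_model are hypothetical names:

# --- illustrative sketch, not part of this commit ---
class LazyEmbedder:
    def __init__(self):
        self.model = None                       # network is not built yet

    def _init_model(self):
        return object()                         # placeholder for the real network construction

    def fit(self, lX, ly):
        if self.model is None:                  # build the network lazily, on the first fit() call
            self.model = self._init_model()
        return self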
@@ -381,12 +385,14 @@ class RecurrentEmbedder:
self.is_trained = True
# Generate document embeddings in order to fit an SVM that recasts them as vectors of posterior probabilities
- lX = self._get_doc_embeddings(lX)
+ # lX = self._get_doc_embeddings(lX)
+ lX = self._get_doc_embeddings(self.multilingual_index.l_devel_index())
# Fit a "multi-lingual" SVM on the generated doc embeddings
self.posteriorEmbedder.fit(lX, ly)
return self
def transform(self, lX, batch_size=64):
lX = self.multilingual_index.get_indexed(lX, self.lpretrained_vocabulary)
lX = self._get_doc_embeddings(lX)
return self.posteriorEmbedder.predict_proba(lX)
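For orientation, a hypothetical usage of this view generator, assuming lX maps language codes to raw documents and ly to csr label matrices, as the surrounding code suggests:

# --- illustrative usage sketch, not part of this commit ---
# lX maps language codes to lists of raw documents, ly to csr label matrices:
#   embedder = RecurrentEmbedder(...)        # constructor arguments omitted
#   embedder.fit(lXtr, lytr)                 # trains the GRU, then the posterior SVM
#   lP = embedder.transform(lXte)            # dict: lang -> (n_docs, n_classes) posteriors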
@@ -397,28 +403,22 @@ class RecurrentEmbedder:
def _get_doc_embeddings(self, lX, batch_size=64):
assert self.is_trained, 'Model is not trained, cannot call transform before fitting the model!'
print('Generating document embeddings via GRU')
- lX = {}
- ly = {}
- batcher_transform = BatchGRU(batch_size, batches_per_epoch=batch_size, languages=self.langs,
- lpad=self.multilingual_index.l_pad())
+ _lX = {}
l_devel_index = self.multilingual_index.l_devel_index()
l_devel_target = self.multilingual_index.l_devel_target()
- for idx, (batch, post, bert_emb, target, lang) in enumerate(
- batcher_transform.batchify(l_devel_index, None, None, l_devel_target)):
- if lang not in lX.keys():
- lX[lang] = self.model.get_embeddings(batch, lang)
- ly[lang] = target.cpu().detach().numpy()
+ for idx, (batch, post, target, lang) in enumerate(batchify(lX, None, l_devel_target,
+ batch_size, self.multilingual_index.l_pad())):
+ if lang not in _lX.keys():
+ _lX[lang] = self.model.get_embeddings(batch, lang)
else:
- lX[lang] = np.concatenate((lX[lang], self.model.get_embeddings(batch, lang)), axis=0)
- ly[lang] = np.concatenate((ly[lang], target.cpu().detach().numpy()), axis=0)
+ _lX[lang] = np.concatenate((_lX[lang], self.model.get_embeddings(batch, lang)), axis=0)
- return lX
+ return _lX
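The module-level batchify(...) used by the new loop is not part of this hunk. A minimal sketch of what such a per-language padding batcher could look like; the argument order follows the call site above, while pad_batch and batchify_sketch are assumed names, not the repository's implementation:

# --- illustrative sketch, not part of this commit ---
# Yield (padded_batch, post, target, lang) tuples language by language, padding
# every batch of index lists to the length of its longest document.
import numpy as np
import torch

def pad_batch(index_list, pad_index):
    maxlen = max(len(doc) for doc in index_list)
    out = np.full((len(index_list), maxlen), pad_index, dtype=np.int64)
    for i, doc in enumerate(index_list):
        out[i, :len(doc)] = doc
    return torch.from_numpy(out)

def batchify_sketch(l_index, l_post, l_labels, batch_size, l_pad):
    for lang, docs in l_index.items():
        for start in range(0, len(docs), batch_size):
            batch = pad_batch(docs[start:start + batch_size], l_pad[lang])
            post = None if l_post is None else l_post[lang][start:start + batch_size]
            target = l_labels[lang][start:start + batch_size]
            yield batch, post, target, lang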
# loads the MUSE embeddings if requested, or returns empty dictionaries otherwise
def _load_pretrained_embeddings(self, we_path, langs):
lpretrained = lpretrained_vocabulary = self._none_dict(langs)  # TODO ?
lpretrained = load_muse_embeddings(we_path, langs, n_jobs=-1)
lpretrained_vocabulary = {l: lpretrained[l].vocabulary() for l in langs}
return lpretrained, lpretrained_vocabulary
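load_muse_embeddings is defined elsewhere; MUSE vectors ship as fastText-style .vec text files (a "count dim" header, then one "word v1 ... vd" line per word). A minimal sketch of reading one such file; load_vec_file and limit are illustrative:

# --- illustrative sketch, not part of this commit ---
import numpy as np

def load_vec_file(path, limit=None):
    vectors = {}
    with open(path, encoding='utf-8', errors='ignore') as f:
        n_words, dim = map(int, f.readline().split())        # header: "<count> <dim>", not used further here
        for i, line in enumerate(f):
            if limit is not None and i >= limit:
                break
            word, *values = line.rstrip().split(' ')
            vectors[word] = np.asarray(values, dtype=np.float32)
    return vectors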
@@ -703,14 +703,14 @@ class BatchGRU:
self.batchsize = batchsize
self.batches_per_epoch = batches_per_epoch
self.languages = languages
- self.lpad=lpad
- self.max_pad_length=max_pad_length
+ self.lpad = lpad
+ self.max_pad_length = max_pad_length
self.init_offset()
def init_offset(self):
self.offset = {lang: 0 for lang in self.languages}
- def batchify(self, l_index, l_post, l_bert, llabels):
+ def batchify(self, l_index, l_post, l_bert, llabels, extractor=False):
langs = self.languages
l_num_samples = {l:len(l_index[l]) for l in langs}

View File

@@ -180,12 +180,27 @@ class MultilingualIndex:
self.l_index[l] = Index(l_devel_raw[l], l_devel_target[l], l_test_raw[l], l)
self.l_index[l].index(l_pretrained_vocabulary[l], l_analyzer[l], l_vocabulary[l])
def get_indexed(self, l_texts, pretrained_vocabulary=None):
assert len(self.l_index) != 0, 'Cannot index data without first calling index() on the multilingual index!'
l_indexed = {}
for l, texts in l_texts.items():
if l in self.langs:
word2index = self.l_index[l].word2index
known_words = set(word2index.keys())
if pretrained_vocabulary[l] is not None:
known_words.update(pretrained_vocabulary[l])
l_indexed[l] = index(texts,
vocab=word2index,
known_words=known_words,
analyzer=self.l_vectorizer.get_analyzer(l),
unk_index=word2index['UNKTOKEN'],
out_of_vocabulary=dict())
return l_indexed
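The index(...) helper called here is not shown in the diff. A sketch of the indexing policy the keyword arguments imply: words in the training vocabulary keep their id, words known only to the pretrained vocabulary get a fresh out-of-vocabulary id, everything else maps to the UNK id. Whether the real helper assigns OOV ids exactly this way is an assumption:

# --- illustrative sketch, not part of this commit ---
def index_sketch(texts, vocab, known_words, analyzer, unk_index, out_of_vocabulary):
    indexed = []
    for text in texts:
        ids = []
        for word in analyzer(text):
            if word in vocab:                   # word seen at training time: keep its id
                ids.append(vocab[word])
            elif word in known_words:           # word covered by the pretrained vectors: new OOV id
                oov_id = out_of_vocabulary.setdefault(word, len(vocab) + len(out_of_vocabulary))
                ids.append(oov_id)
            else:                               # completely unknown word
                ids.append(unk_index)
        indexed.append(ids)
    return indexed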
def train_val_split(self, val_prop=0.2, max_val=2000, seed=42):
for l,index in self.l_index.items():
index.train_val_split(val_prop, max_val, seed=seed)
def embedding_matrices(self, lpretrained, supervised):
lXtr = self.get_lXtr() if supervised else none_dict(self.langs)
lYtr = self.l_train_target() if supervised else none_dict(self.langs)
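The supervised branch only gathers lXtr and lYtr here. Assuming "supervised" refers to word-class embeddings (word-by-class statistics computed from the training data), a rough sketch of the usual construction; word_class_embeddings and the L2 normalisation are assumptions, not necessarily what this repository computes:

# --- illustrative sketch, not part of this commit ---
# Word-class correlations: project the (docs x words) training matrix onto the
# (docs x classes) label matrix, then L2-normalise each word row.
import numpy as np

def word_class_embeddings(Xtr, Ytr):
    # Xtr: csr (n_docs, n_words) tf-idf matrix; Ytr: csr (n_docs, n_classes) binary labels
    WC = (Xtr.T @ Ytr).toarray().astype(np.float64)          # (n_words, n_classes)
    norms = np.linalg.norm(WC, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    return WC / norms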
@@ -385,7 +400,7 @@ class Batch:
def init_offset(self):
self.offset = {lang: 0 for lang in self.languages}
- def batchify(self, l_index, l_post, l_bert, llabels): # TODO: add bert embedding here...
+ def batchify(self, l_index, l_post, l_bert, llabels):
langs = self.languages
l_num_samples = {l:len(l_index[l]) for l in langs}
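The body of Batch.batchify is cut off by the hunk. As a rough sketch of how a batcher built around a per-language offset typically behaves (one batch per language per step, resuming where the previous epoch stopped), with every name beyond those visible above being an assumption:

# --- illustrative sketch, not part of this commit ---
# A per-language offset lets each epoch serve one batch per language per step
# and resume, on the next epoch, from where the previous one stopped.
class OffsetBatcherSketch:
    def __init__(self, batchsize, batches_per_epoch, languages):
        self.batchsize = batchsize
        self.batches_per_epoch = batches_per_epoch
        self.languages = languages
        self.init_offset()

    def init_offset(self):
        self.offset = {lang: 0 for lang in self.languages}

    def batchify(self, l_index, llabels):
        for _ in range(self.batches_per_epoch):
            for lang in self.languages:
                start = self.offset[lang]
                end = start + self.batchsize
                yield l_index[lang][start:end], llabels[lang][start:end], lang
                # wrap around once this language's data is exhausted
                self.offset[lang] = end if end < len(l_index[lang]) else 0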