From 94bfe6a0369c272a968204780054a99fbb795516 Mon Sep 17 00:00:00 2001
From: andrea
Date: Tue, 27 Oct 2020 15:08:39 +0100
Subject: [PATCH] fixed batcher

---
 src/learning/transformers.py | 40 ++++++++++++++++++------------------
 src/util/common.py           | 21 ++++++++++++++++---
 2 files changed, 38 insertions(+), 23 deletions(-)

diff --git a/src/learning/transformers.py b/src/learning/transformers.py
index 0032460..75d9888 100644
--- a/src/learning/transformers.py
+++ b/src/learning/transformers.py
@@ -301,15 +301,16 @@ class RecurrentEmbedder:
         self.test_each = test_each
         self.options = options
         self.seed = options.seed
+        self.model_path = model_path
         self.is_trained = False
 
         ## INIT MODEL for training
         self.lXtr, self.lytr = self.multilingual_dataset.training(target_as_csr=True)
         self.lXte, self.lyte = self.multilingual_dataset.test(target_as_csr=True)
         self.nC = self.lyte[self.langs[0]].shape[1]
-        lpretrained, lpretrained_vocabulary = self._load_pretrained_embeddings(self.we_path, self.langs)
+        lpretrained, self.lpretrained_vocabulary = self._load_pretrained_embeddings(self.we_path, self.langs)
 
         self.multilingual_index = MultilingualIndex()
-        self.multilingual_index.index(self.lXtr, self.lytr, self.lXte, lpretrained_vocabulary)
+        self.multilingual_index.index(self.lXtr, self.lytr, self.lXte, self.lpretrained_vocabulary)
         self.multilingual_index.train_val_split(val_prop=0.2, max_val=2000, seed=self.seed)
         self.multilingual_index.embedding_matrices(lpretrained, self.supervised)
@@ -324,12 +325,15 @@
         self.lr_scheduler = StepLR(self.optim, step_size=25, gamma=0.5)
         self.early_stop = EarlyStopping(self.model, optimizer=self.optim, patience=self.patience,
                                         checkpoint=f'{self.checkpoint_dir}/gru_viewgen_-{get_file_name(self.options.dataset)}')
+
+        # Init SVM in order to recast (vstacked) document embeddings to vectors of Posterior Probabilities
         self.posteriorEmbedder = MetaClassifier(
             SVC(kernel='rbf', gamma='auto', probability=True, cache_size=1000, random_state=1),
             n_jobs=options.n_jobs)
 
     def fit(self, lX, ly, lV=None, batch_size=64, nepochs=200, val_epochs=1):
         print('### Gated Recurrent Unit View Generator (G)')
+        # self.multilingual_index.get_indexed(lX, self.lpretrained_vocabulary)
         # could be better to init model here at first .fit() call!
         if self.model is None:
             print('TODO: Init model!')
@@ -381,12 +385,14 @@
         self.is_trained = True
 
         # Generate document embeddings in order to fit an SVM to recast them as vector for Posterior Probabilities
-        lX = self._get_doc_embeddings(lX)
+        # lX = self._get_doc_embeddings(lX)
+        lX = self._get_doc_embeddings(self.multilingual_index.l_devel_index())
         # Fit a ''multi-lingual'' SVM on the generated doc embeddings
         self.posteriorEmbedder.fit(lX, ly)
         return self
 
     def transform(self, lX, batch_size=64):
+        lX = self.multilingual_index.get_indexed(lX, self.lpretrained_vocabulary)
         lX = self._get_doc_embeddings(lX)
         return self.posteriorEmbedder.predict_proba(lX)
 
@@ -397,28 +403,22 @@
     def _get_doc_embeddings(self, lX, batch_size=64):
         assert self.is_trained, 'Model is not trained, cannot call transform before fitting the model!'
         print('Generating document embeddings via GRU')
-        lX = {}
-        ly = {}
-        batcher_transform = BatchGRU(batch_size, batches_per_epoch=batch_size, languages=self.langs,
-                                     lpad=self.multilingual_index.l_pad())
+        _lX = {}
 
-        l_devel_index = self.multilingual_index.l_devel_index()
         l_devel_target = self.multilingual_index.l_devel_target()
 
-        for idx, (batch, post, bert_emb, target, lang) in enumerate(
-                batcher_transform.batchify(l_devel_index, None, None, l_devel_target)):
-            if lang not in lX.keys():
-                lX[lang] = self.model.get_embeddings(batch, lang)
-                ly[lang] = target.cpu().detach().numpy()
+        for idx, (batch, post, target, lang) in enumerate(batchify(lX, None, l_devel_target,
+                                                                   batch_size, self.multilingual_index.l_pad())):
+            if lang not in _lX.keys():
+                _lX[lang] = self.model.get_embeddings(batch, lang)
             else:
-                lX[lang] = np.concatenate((lX[lang], self.model.get_embeddings(batch, lang)), axis=0)
-                ly[lang] = np.concatenate((ly[lang], target.cpu().detach().numpy()), axis=0)
+                _lX[lang] = np.concatenate((_lX[lang], self.model.get_embeddings(batch, lang)), axis=0)
 
-        return lX
+        return _lX
 
     # loads the MUSE embeddings if requested, or returns empty dictionaries otherwise
     def _load_pretrained_embeddings(self, we_path, langs):
-        lpretrained = lpretrained_vocabulary = self._none_dict(langs)     # TODO ?
+        lpretrained = lpretrained_vocabulary = self._none_dict(langs)  # TODO ?
         lpretrained = load_muse_embeddings(we_path, langs, n_jobs=-1)
         lpretrained_vocabulary = {l: lpretrained[l].vocabulary() for l in langs}
         return lpretrained, lpretrained_vocabulary
@@ -703,14 +703,14 @@ class BatchGRU:
         self.batchsize = batchsize
         self.batches_per_epoch = batches_per_epoch
         self.languages = languages
-        self.lpad=lpad
-        self.max_pad_length=max_pad_length
+        self.lpad = lpad
+        self.max_pad_length = max_pad_length
         self.init_offset()
 
     def init_offset(self):
         self.offset = {lang: 0 for lang in self.languages}
 
-    def batchify(self, l_index, l_post, l_bert, llabels):
+    def batchify(self, l_index, l_post, l_bert, llabels, extractor=False):
         langs = self.languages
         l_num_samples = {l:len(l_index[l]) for l in langs}
 
diff --git a/src/util/common.py b/src/util/common.py
index 88134d3..219931a 100755
--- a/src/util/common.py
+++ b/src/util/common.py
@@ -180,12 +180,27 @@ class MultilingualIndex:
             self.l_index[l] = Index(l_devel_raw[l], l_devel_target[l], l_test_raw[l], l)
             self.l_index[l].index(l_pretrained_vocabulary[l], l_analyzer[l], l_vocabulary[l])
 
+    def get_indexed(self, l_texts, pretrained_vocabulary=None):
+        assert len(self.l_index) != 0, 'Cannot index data without first index call to multilingual index!'
+        l_indexed = {}
+        for l, texts in l_texts.items():
+            if l in self.langs:
+                word2index = self.l_index[l].word2index
+                known_words = set(word2index.keys())
+                if pretrained_vocabulary[l] is not None:
+                    known_words.update(pretrained_vocabulary[l])
+                l_indexed[l] = index(texts,
+                                     vocab=word2index,
+                                     known_words=known_words,
+                                     analyzer=self.l_vectorizer.get_analyzer(l),
+                                     unk_index=word2index['UNKTOKEN'],
+                                     out_of_vocabulary=dict())
+        return l_indexed
+
     def train_val_split(self, val_prop=0.2, max_val=2000, seed=42):
         for l,index in self.l_index.items():
             index.train_val_split(val_prop, max_val, seed=seed)
 
-
-
     def embedding_matrices(self, lpretrained, supervised):
         lXtr = self.get_lXtr() if supervised else none_dict(self.langs)
         lYtr = self.l_train_target() if supervised else none_dict(self.langs)
@@ -385,7 +400,7 @@ class Batch:
     def init_offset(self):
         self.offset = {lang: 0 for lang in self.languages}
 
-    def batchify(self, l_index, l_post, l_bert, llabels): # TODO: add bert embedding here...
+    def batchify(self, l_index, l_post, l_bert, llabels):
         langs = self.languages
         l_num_samples = {l:len(l_index[l]) for l in langs}
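
Note for readers of this patch: the added MultilingualIndex.get_indexed delegates the per-language work to the index(...) helper from src/util/common.py. The self-contained sketch below is illustrative only and is not part of the patch; it mimics the keyword arguments used in the call above and the presumable behaviour (the real helper may differ): words in the fitted vocabulary keep their index, words known only through the pretrained vocabulary receive a fresh out-of-vocabulary index, and everything else falls back to the UNKTOKEN index.

def index_sketch(texts, vocab, known_words, analyzer, unk_index, out_of_vocabulary):
    # vocab: word -> index learned at training time (word2index)
    # known_words: vocab keys plus, optionally, the pretrained (MUSE) vocabulary
    # out_of_vocabulary: dict filled with fresh indices for known-but-unseen words
    indexed = []
    for doc in texts:
        ids = []
        for word in analyzer(doc):
            if word in vocab:
                ids.append(vocab[word])
            elif word in known_words:
                if word not in out_of_vocabulary:
                    out_of_vocabulary[word] = len(vocab) + len(out_of_vocabulary)
                ids.append(out_of_vocabulary[word])
            else:
                ids.append(unk_index)
        indexed.append(ids)
    return indexed

# tiny usage example with a hypothetical vocabulary
word2index = {'UNKTOKEN': 0, 'hello': 1, 'world': 2}
oov = {}
print(index_sketch(['hello brave new world'], vocab=word2index,
                   known_words=set(word2index) | {'brave'}, analyzer=str.split,
                   unk_index=word2index['UNKTOKEN'], out_of_vocabulary=oov))
# [[1, 3, 0, 2]]  ('brave' gets the new index 3, 'new' maps to UNKTOKEN)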