fixed batcher
This commit is contained in:
parent
5906f85f33
commit
94bfe6a036
|
|
@ -301,15 +301,16 @@ class RecurrentEmbedder:
|
|||
self.test_each = test_each
|
||||
self.options = options
|
||||
self.seed = options.seed
|
||||
self.model_path = model_path
|
||||
self.is_trained = False
|
||||
|
||||
## INIT MODEL for training
|
||||
self.lXtr, self.lytr = self.multilingual_dataset.training(target_as_csr=True)
|
||||
self.lXte, self.lyte = self.multilingual_dataset.test(target_as_csr=True)
|
||||
self.nC = self.lyte[self.langs[0]].shape[1]
|
||||
lpretrained, lpretrained_vocabulary = self._load_pretrained_embeddings(self.we_path, self.langs)
|
||||
lpretrained, self.lpretrained_vocabulary = self._load_pretrained_embeddings(self.we_path, self.langs)
|
||||
self.multilingual_index = MultilingualIndex()
|
||||
self.multilingual_index.index(self.lXtr, self.lytr, self.lXte, lpretrained_vocabulary)
|
||||
self.multilingual_index.index(self.lXtr, self.lytr, self.lXte, self.lpretrained_vocabulary)
|
||||
self.multilingual_index.train_val_split(val_prop=0.2, max_val=2000, seed=self.seed)
|
||||
self.multilingual_index.embedding_matrices(lpretrained, self.supervised)
|
||||
|
||||
|
|
@ -324,12 +325,15 @@ class RecurrentEmbedder:
|
|||
self.lr_scheduler = StepLR(self.optim, step_size=25, gamma=0.5)
|
||||
self.early_stop = EarlyStopping(self.model, optimizer=self.optim, patience=self.patience,
|
||||
checkpoint=f'{self.checkpoint_dir}/gru_viewgen_-{get_file_name(self.options.dataset)}')
|
||||
|
||||
|
||||
# Init SVM in order to recast (vstacked) document embeddings to vectors of Posterior Probabilities
|
||||
self.posteriorEmbedder = MetaClassifier(
|
||||
SVC(kernel='rbf', gamma='auto', probability=True, cache_size=1000, random_state=1), n_jobs=options.n_jobs)
|
||||
|
||||
def fit(self, lX, ly, lV=None, batch_size=64, nepochs=200, val_epochs=1):
|
||||
print('### Gated Recurrent Unit View Generator (G)')
|
||||
# self.multilingual_index.get_indexed(lX, self.lpretrained_vocabulary)
|
||||
# could be better to init model here at first .fit() call!
|
||||
if self.model is None:
|
||||
print('TODO: Init model!')
|
||||
|
|
@ -381,12 +385,14 @@ class RecurrentEmbedder:
|
|||
self.is_trained = True
|
||||
|
||||
# Generate document embeddings in order to fit an SVM to recast them as vector for Posterior Probabilities
|
||||
lX = self._get_doc_embeddings(lX)
|
||||
# lX = self._get_doc_embeddings(lX)
|
||||
lX = self._get_doc_embeddings(self.multilingual_index.l_devel_index())
|
||||
# Fit a ''multi-lingual'' SVM on the generated doc embeddings
|
||||
self.posteriorEmbedder.fit(lX, ly)
|
||||
return self
|
||||
|
||||
def transform(self, lX, batch_size=64):
|
||||
lX = self.multilingual_index.get_indexed(lX, self.lpretrained_vocabulary)
|
||||
lX = self._get_doc_embeddings(lX)
|
||||
return self.posteriorEmbedder.predict_proba(lX)
|
||||
|
||||
|
|
@ -397,28 +403,22 @@ class RecurrentEmbedder:
|
|||
def _get_doc_embeddings(self, lX, batch_size=64):
|
||||
assert self.is_trained, 'Model is not trained, cannot call transform before fitting the model!'
|
||||
print('Generating document embeddings via GRU')
|
||||
lX = {}
|
||||
ly = {}
|
||||
batcher_transform = BatchGRU(batch_size, batches_per_epoch=batch_size, languages=self.langs,
|
||||
lpad=self.multilingual_index.l_pad())
|
||||
_lX = {}
|
||||
|
||||
l_devel_index = self.multilingual_index.l_devel_index()
|
||||
l_devel_target = self.multilingual_index.l_devel_target()
|
||||
|
||||
for idx, (batch, post, bert_emb, target, lang) in enumerate(
|
||||
batcher_transform.batchify(l_devel_index, None, None, l_devel_target)):
|
||||
if lang not in lX.keys():
|
||||
lX[lang] = self.model.get_embeddings(batch, lang)
|
||||
ly[lang] = target.cpu().detach().numpy()
|
||||
for idx, (batch, post, target, lang) in enumerate(batchify(lX, None, l_devel_target,
|
||||
batch_size, self.multilingual_index.l_pad())):
|
||||
if lang not in _lX.keys():
|
||||
_lX[lang] = self.model.get_embeddings(batch, lang)
|
||||
else:
|
||||
lX[lang] = np.concatenate((lX[lang], self.model.get_embeddings(batch, lang)), axis=0)
|
||||
ly[lang] = np.concatenate((ly[lang], target.cpu().detach().numpy()), axis=0)
|
||||
_lX[lang] = np.concatenate((_lX[lang], self.model.get_embeddings(batch, lang)), axis=0)
|
||||
|
||||
return lX
|
||||
return _lX
|
||||
|
||||
# loads the MUSE embeddings if requested, or returns empty dictionaries otherwise
|
||||
def _load_pretrained_embeddings(self, we_path, langs):
|
||||
lpretrained = lpretrained_vocabulary = self._none_dict(langs) # TODO ?
|
||||
lpretrained = lpretrained_vocabulary = self._none_dict(langs) # TODO ?
|
||||
lpretrained = load_muse_embeddings(we_path, langs, n_jobs=-1)
|
||||
lpretrained_vocabulary = {l: lpretrained[l].vocabulary() for l in langs}
|
||||
return lpretrained, lpretrained_vocabulary
|
||||
|
|
@ -703,14 +703,14 @@ class BatchGRU:
|
|||
self.batchsize = batchsize
|
||||
self.batches_per_epoch = batches_per_epoch
|
||||
self.languages = languages
|
||||
self.lpad=lpad
|
||||
self.max_pad_length=max_pad_length
|
||||
self.lpad = lpad
|
||||
self.max_pad_length = max_pad_length
|
||||
self.init_offset()
|
||||
|
||||
def init_offset(self):
|
||||
self.offset = {lang: 0 for lang in self.languages}
|
||||
|
||||
def batchify(self, l_index, l_post, l_bert, llabels):
|
||||
def batchify(self, l_index, l_post, l_bert, llabels, extractor=False):
|
||||
langs = self.languages
|
||||
l_num_samples = {l:len(l_index[l]) for l in langs}
|
||||
|
||||
|
|
|
|||
|
|
@ -180,12 +180,27 @@ class MultilingualIndex:
|
|||
self.l_index[l] = Index(l_devel_raw[l], l_devel_target[l], l_test_raw[l], l)
|
||||
self.l_index[l].index(l_pretrained_vocabulary[l], l_analyzer[l], l_vocabulary[l])
|
||||
|
||||
def get_indexed(self, l_texts, pretrained_vocabulary=None):
|
||||
assert len(self.l_index) != 0, 'Cannot index data without first index call to multilingual index!'
|
||||
l_indexed = {}
|
||||
for l, texts in l_texts.items():
|
||||
if l in self.langs:
|
||||
word2index = self.l_index[l].word2index
|
||||
known_words = set(word2index.keys())
|
||||
if pretrained_vocabulary[l] is not None:
|
||||
known_words.update(pretrained_vocabulary[l])
|
||||
l_indexed[l] = index(texts,
|
||||
vocab=word2index,
|
||||
known_words=known_words,
|
||||
analyzer=self.l_vectorizer.get_analyzer(l),
|
||||
unk_index=word2index['UNKTOKEN'],
|
||||
out_of_vocabulary=dict())
|
||||
return l_indexed
|
||||
|
||||
def train_val_split(self, val_prop=0.2, max_val=2000, seed=42):
|
||||
for l,index in self.l_index.items():
|
||||
index.train_val_split(val_prop, max_val, seed=seed)
|
||||
|
||||
|
||||
|
||||
def embedding_matrices(self, lpretrained, supervised):
|
||||
lXtr = self.get_lXtr() if supervised else none_dict(self.langs)
|
||||
lYtr = self.l_train_target() if supervised else none_dict(self.langs)
|
||||
|
|
@ -385,7 +400,7 @@ class Batch:
|
|||
def init_offset(self):
|
||||
self.offset = {lang: 0 for lang in self.languages}
|
||||
|
||||
def batchify(self, l_index, l_post, l_bert, llabels): # TODO: add bert embedding here...
|
||||
def batchify(self, l_index, l_post, l_bert, llabels):
|
||||
langs = self.languages
|
||||
l_num_samples = {l:len(l_index[l]) for l in langs}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue