fixed batcher
parent 5906f85f33
commit 94bfe6a036
@@ -301,15 +301,16 @@ class RecurrentEmbedder:
         self.test_each = test_each
         self.options = options
         self.seed = options.seed
+        self.model_path = model_path
         self.is_trained = False

         ## INIT MODEL for training
         self.lXtr, self.lytr = self.multilingual_dataset.training(target_as_csr=True)
         self.lXte, self.lyte = self.multilingual_dataset.test(target_as_csr=True)
         self.nC = self.lyte[self.langs[0]].shape[1]
-        lpretrained, lpretrained_vocabulary = self._load_pretrained_embeddings(self.we_path, self.langs)
+        lpretrained, self.lpretrained_vocabulary = self._load_pretrained_embeddings(self.we_path, self.langs)
         self.multilingual_index = MultilingualIndex()
-        self.multilingual_index.index(self.lXtr, self.lytr, self.lXte, lpretrained_vocabulary)
+        self.multilingual_index.index(self.lXtr, self.lytr, self.lXte, self.lpretrained_vocabulary)
         self.multilingual_index.train_val_split(val_prop=0.2, max_val=2000, seed=self.seed)
         self.multilingual_index.embedding_matrices(lpretrained, self.supervised)

@@ -324,12 +325,15 @@ class RecurrentEmbedder:
         self.lr_scheduler = StepLR(self.optim, step_size=25, gamma=0.5)
         self.early_stop = EarlyStopping(self.model, optimizer=self.optim, patience=self.patience,
                                         checkpoint=f'{self.checkpoint_dir}/gru_viewgen_-{get_file_name(self.options.dataset)}')

         # Init SVM in order to recast (vstacked) document embeddings to vectors of Posterior Probabilities
         self.posteriorEmbedder = MetaClassifier(
             SVC(kernel='rbf', gamma='auto', probability=True, cache_size=1000, random_state=1), n_jobs=options.n_jobs)

     def fit(self, lX, ly, lV=None, batch_size=64, nepochs=200, val_epochs=1):
         print('### Gated Recurrent Unit View Generator (G)')
+        # self.multilingual_index.get_indexed(lX, self.lpretrained_vocabulary)
         # could be better to init model here at first .fit() call!
         if self.model is None:
             print('TODO: Init model!')

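The posteriorEmbedder above recasts the (vstacked) GRU document embeddings into vectors of posterior probabilities. Below is a minimal single-language sketch of that idea using scikit-learn directly rather than the project's MetaClassifier wrapper; the array shapes, data, and variable names are illustrative assumptions, not the repository's actual inputs.

import numpy as np
from sklearn.svm import SVC

rng = np.random.default_rng(0)
doc_embeddings = rng.standard_normal((100, 512))   # hypothetical GRU document embeddings
labels = rng.integers(0, 2, size=100)              # hypothetical targets for one class

# same estimator configuration as in the diff, fitted on the document embeddings
svm = SVC(kernel='rbf', gamma='auto', probability=True, cache_size=1000, random_state=1)
svm.fit(doc_embeddings, labels)

# each document is now represented by its class-posterior probabilities
posteriors = svm.predict_proba(doc_embeddings)     # shape: (n_docs, n_classes)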
@@ -381,12 +385,14 @@ class RecurrentEmbedder:
         self.is_trained = True

         # Generate document embeddings in order to fit an SVM that recasts them as vectors of Posterior Probabilities
-        lX = self._get_doc_embeddings(lX)
+        # lX = self._get_doc_embeddings(lX)
+        lX = self._get_doc_embeddings(self.multilingual_index.l_devel_index())
         # Fit a ''multi-lingual'' SVM on the generated doc embeddings
         self.posteriorEmbedder.fit(lX, ly)
         return self

     def transform(self, lX, batch_size=64):
+        lX = self.multilingual_index.get_indexed(lX, self.lpretrained_vocabulary)
         lX = self._get_doc_embeddings(lX)
         return self.posteriorEmbedder.predict_proba(lX)

@@ -397,28 +403,22 @@ class RecurrentEmbedder:
     def _get_doc_embeddings(self, lX, batch_size=64):
         assert self.is_trained, 'Model is not trained, cannot call transform before fitting the model!'
         print('Generating document embeddings via GRU')
-        lX = {}
-        ly = {}
-        batcher_transform = BatchGRU(batch_size, batches_per_epoch=batch_size, languages=self.langs,
-                                     lpad=self.multilingual_index.l_pad())
-
-        l_devel_index = self.multilingual_index.l_devel_index()
+        _lX = {}
         l_devel_target = self.multilingual_index.l_devel_target()

-        for idx, (batch, post, bert_emb, target, lang) in enumerate(
-                batcher_transform.batchify(l_devel_index, None, None, l_devel_target)):
-            if lang not in lX.keys():
-                lX[lang] = self.model.get_embeddings(batch, lang)
-                ly[lang] = target.cpu().detach().numpy()
+        for idx, (batch, post, target, lang) in enumerate(batchify(lX, None, l_devel_target,
+                                                                   batch_size, self.multilingual_index.l_pad())):
+            if lang not in _lX.keys():
+                _lX[lang] = self.model.get_embeddings(batch, lang)
             else:
-                lX[lang] = np.concatenate((lX[lang], self.model.get_embeddings(batch, lang)), axis=0)
-                ly[lang] = np.concatenate((ly[lang], target.cpu().detach().numpy()), axis=0)
+                _lX[lang] = np.concatenate((_lX[lang], self.model.get_embeddings(batch, lang)), axis=0)

-        return lX
+        return _lX

     # loads the MUSE embeddings if requested, or returns empty dictionaries otherwise
     def _load_pretrained_embeddings(self, we_path, langs):
         lpretrained = lpretrained_vocabulary = self._none_dict(langs)  # TODO ?
         lpretrained = load_muse_embeddings(we_path, langs, n_jobs=-1)
         lpretrained_vocabulary = {l: lpretrained[l].vocabulary() for l in langs}
         return lpretrained, lpretrained_vocabulary

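The rewritten _get_doc_embeddings no longer builds its own BatchGRU; it calls a module-level batchify(...) helper that is defined elsewhere in the repository. The sketch below only illustrates the behaviour the new loop appears to assume (per-language chunking, padding each chunk to a common length with that language's pad index, yielding (batch, post, target, lang) tuples); it is an assumption for illustration, not the actual implementation.

import torch

def batchify_sketch(l_index, l_post, l_target, batch_size, l_pad):
    # assumed behaviour: walk each language's indexed documents in order
    for lang, docs in l_index.items():
        for i in range(0, len(docs), batch_size):
            chunk = docs[i:i + batch_size]
            max_len = max(len(d) for d in chunk)
            # pad every document in the chunk to the chunk's longest sequence
            padded = [list(d) + [l_pad[lang]] * (max_len - len(d)) for d in chunk]
            batch = torch.tensor(padded, dtype=torch.long)
            post = None if l_post is None else l_post[lang][i:i + batch_size]
            target = None if l_target is None else l_target[lang][i:i + batch_size]
            yield batch, post, target, lang

# toy usage with made-up index sequences
l_index = {'en': [[3, 5, 7], [2, 4]], 'it': [[1, 9, 9, 6]]}
for batch, post, target, lang in batchify_sketch(l_index, None, None, batch_size=2, l_pad={'en': 0, 'it': 0}):
    print(lang, batch.shape)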
@@ -703,14 +703,14 @@ class BatchGRU:
         self.batchsize = batchsize
         self.batches_per_epoch = batches_per_epoch
         self.languages = languages
-        self.lpad=lpad
-        self.max_pad_length=max_pad_length
+        self.lpad = lpad
+        self.max_pad_length = max_pad_length
         self.init_offset()

     def init_offset(self):
         self.offset = {lang: 0 for lang in self.languages}

-    def batchify(self, l_index, l_post, l_bert, llabels):
+    def batchify(self, l_index, l_post, l_bert, llabels, extractor=False):
         langs = self.languages
         l_num_samples = {l:len(l_index[l]) for l in langs}

@@ -180,12 +180,27 @@ class MultilingualIndex:
             self.l_index[l] = Index(l_devel_raw[l], l_devel_target[l], l_test_raw[l], l)
             self.l_index[l].index(l_pretrained_vocabulary[l], l_analyzer[l], l_vocabulary[l])

+    def get_indexed(self, l_texts, pretrained_vocabulary=None):
+        assert len(self.l_index) != 0, 'Cannot index data without first index call to multilingual index!'
+        l_indexed = {}
+        for l, texts in l_texts.items():
+            if l in self.langs:
+                word2index = self.l_index[l].word2index
+                known_words = set(word2index.keys())
+                if pretrained_vocabulary[l] is not None:
+                    known_words.update(pretrained_vocabulary[l])
+                l_indexed[l] = index(texts,
+                                     vocab=word2index,
+                                     known_words=known_words,
+                                     analyzer=self.l_vectorizer.get_analyzer(l),
+                                     unk_index=word2index['UNKTOKEN'],
+                                     out_of_vocabulary=dict())
+        return l_indexed
+
     def train_val_split(self, val_prop=0.2, max_val=2000, seed=42):
         for l,index in self.l_index.items():
             index.train_val_split(val_prop, max_val, seed=seed)

     def embedding_matrices(self, lpretrained, supervised):
         lXtr = self.get_lXtr() if supervised else none_dict(self.langs)
         lYtr = self.l_train_target() if supervised else none_dict(self.langs)

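The new get_indexed method delegates the per-document work to a module-level index(...) helper with the keyword signature shown in the added lines. Below is a self-contained sketch of what such a helper plausibly does: tokens in the training vocabulary keep their ids, tokens known only through the pretrained embeddings receive fresh out-of-vocabulary ids, and everything else maps to the UNK index. The OOV handling and the helper name are assumptions for illustration, not the repository's implementation.

def index_sketch(texts, vocab, known_words, analyzer, unk_index, out_of_vocabulary):
    indexed = []
    for text in texts:
        doc = []
        for token in analyzer(text):
            if token in vocab:
                doc.append(vocab[token])
            elif token in known_words:
                # out of the training vocabulary but covered by the pretrained embeddings:
                # assign a fresh id beyond the training vocabulary (assumed behaviour)
                oov_id = out_of_vocabulary.setdefault(token, len(vocab) + len(out_of_vocabulary))
                doc.append(oov_id)
            else:
                doc.append(unk_index)
        indexed.append(doc)
    return indexed

# toy usage with a made-up vocabulary and a whitespace analyzer
vocab = {'UNKTOKEN': 0, 'hello': 1, 'world': 2}
print(index_sketch(['hello brave new world'], vocab, set(vocab) | {'new'},
                   str.split, vocab['UNKTOKEN'], dict()))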
@@ -385,7 +400,7 @@ class Batch:
     def init_offset(self):
         self.offset = {lang: 0 for lang in self.languages}

-    def batchify(self, l_index, l_post, l_bert, llabels): # TODO: add bert embedding here...
+    def batchify(self, l_index, l_post, l_bert, llabels):
         langs = self.languages
         l_num_samples = {l:len(l_index[l]) for l in langs}