Implemented funnelling architecture
This commit is contained in:
parent
ae0ea1e68c
commit
0b54864514
|
|
@ -105,21 +105,25 @@ class RecurrentDataModule(pl.LightningDataModule):
|
|||
if stage == 'fit' or stage is None:
|
||||
l_train_index, l_train_target = self.multilingualIndex.l_train()
|
||||
# Debug settings: reducing number of samples
|
||||
l_train_index = {l: train[:50] for l, train in l_train_index.items()}
|
||||
l_train_target = {l: target[:50] for l, target in l_train_target.items()}
|
||||
l_train_index = {l: train[:5] for l, train in l_train_index.items()}
|
||||
l_train_target = {l: target[:5] for l, target in l_train_target.items()}
|
||||
|
||||
self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
|
||||
lPad_index=self.multilingualIndex.l_pad())
|
||||
|
||||
l_val_index, l_val_target = self.multilingualIndex.l_val()
|
||||
# Debug settings: reducing number of samples
|
||||
l_val_index = {l: train[:50] for l, train in l_val_index.items()}
|
||||
l_val_target = {l: target[:50] for l, target in l_val_target.items()}
|
||||
l_val_index = {l: train[:5] for l, train in l_val_index.items()}
|
||||
l_val_target = {l: target[:5] for l, target in l_val_target.items()}
|
||||
|
||||
self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
|
||||
lPad_index=self.multilingualIndex.l_pad())
|
||||
if stage == 'test' or stage is None:
|
||||
l_test_index, l_test_target = self.multilingualIndex.l_test()
|
||||
# Debug settings: reducing number of samples
|
||||
l_test_index = {l: train[:5] for l, train in l_test_index.items()}
|
||||
l_test_target = {l: target[:5] for l, target in l_test_target.items()}
|
||||
|
||||
self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
|
||||
lPad_index=self.multilingualIndex.l_pad())
|
||||
|
||||
|
|
@ -145,8 +149,8 @@ class BertDataModule(RecurrentDataModule):
|
|||
if stage == 'fit' or stage is None:
|
||||
l_train_raw, l_train_target = self.multilingualIndex.l_train_raw()
|
||||
# Debug settings: reducing number of samples
|
||||
# l_train_raw = {l: train[:50] for l, train in l_train_raw.items()}
|
||||
# l_train_target = {l: target[:50] for l, target in l_train_target.items()}
|
||||
l_train_raw = {l: train[:5] for l, train in l_train_raw.items()}
|
||||
l_train_target = {l: target[:5] for l, target in l_train_target.items()}
|
||||
|
||||
l_train_index = self.tokenize(l_train_raw, max_len=self.max_len)
|
||||
self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
|
||||
|
|
@ -154,8 +158,8 @@ class BertDataModule(RecurrentDataModule):
|
|||
|
||||
l_val_raw, l_val_target = self.multilingualIndex.l_val_raw()
|
||||
# Debug settings: reducing number of samples
|
||||
# l_val_raw = {l: train[:50] for l, train in l_val_raw.items()}
|
||||
# l_val_target = {l: target[:50] for l, target in l_val_target.items()}
|
||||
l_val_raw = {l: train[:5] for l, train in l_val_raw.items()}
|
||||
l_val_target = {l: target[:5] for l, target in l_val_target.items()}
|
||||
|
||||
l_val_index = self.tokenize(l_val_raw, max_len=self.max_len)
|
||||
self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
|
||||
|
|
@ -163,6 +167,10 @@ class BertDataModule(RecurrentDataModule):
|
|||
|
||||
if stage == 'test' or stage is None:
|
||||
l_test_raw, l_test_target = self.multilingualIndex.l_test_raw()
|
||||
# Debug settings: reducing number of samples
|
||||
l_test_raw = {l: train[:5] for l, train in l_test_raw.items()}
|
||||
l_test_target = {l: target[:5] for l, target in l_test_target.items()}
|
||||
|
||||
l_test_index = self.tokenize(l_test_raw, max_len=self.max_len)
|
||||
self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
|
||||
lPad_index=self.multilingualIndex.l_pad())
|
||||
|
|
|
|||
|
|
@ -0,0 +1,95 @@
|
|||
from models.learners import *
|
||||
from view_generators import VanillaFunGen
|
||||
from util.common import _normalize
|
||||
|
||||
|
||||
class DocEmbedderList:
|
||||
def __init__(self, embedder_list, probabilistic=True):
|
||||
"""
|
||||
Class that takes care of calling fit and transform function for every init embedder.
|
||||
:param embedder_list: list of embedders to be deployed
|
||||
:param probabilistic: whether to recast view generators output to vectors of posterior probabilities or not
|
||||
"""
|
||||
assert len(embedder_list) != 0, 'Embedder list cannot be empty!'
|
||||
self.embedders = embedder_list
|
||||
self.probabilistic = probabilistic
|
||||
if probabilistic:
|
||||
_tmp = []
|
||||
for embedder in self.embedders:
|
||||
if isinstance(embedder, VanillaFunGen):
|
||||
_tmp.append(embedder)
|
||||
else:
|
||||
_tmp.append(FeatureSet2Posteriors(embedder))
|
||||
self.embedders = _tmp
|
||||
|
||||
def fit(self, lX, ly):
|
||||
for embedder in self.embedders:
|
||||
embedder.fit(lX, ly)
|
||||
return self
|
||||
|
||||
def transform(self, lX):
|
||||
langs = sorted(lX.keys())
|
||||
lZparts = {lang: None for lang in langs}
|
||||
|
||||
for embedder in self.embedders:
|
||||
lZ = embedder.transform(lX)
|
||||
for lang in langs:
|
||||
Z = lZ[lang]
|
||||
if lZparts[lang] is None:
|
||||
lZparts[lang] = Z
|
||||
else:
|
||||
lZparts[lang] += Z
|
||||
n_embedders = len(self.embedders)
|
||||
return {lang: lZparts[lang]/n_embedders for lang in langs}
|
||||
|
||||
def fit_transform(self, lX, ly):
|
||||
return self.fit(lX, ly).transform(lX)
|
||||
|
||||
|
||||
class FeatureSet2Posteriors:
|
||||
def __init__(self, embedder, l2=True, n_jobs=-1):
|
||||
self.embedder = embedder
|
||||
self.l2 = l2
|
||||
self.n_jobs = n_jobs
|
||||
self.prob_classifier = MetaClassifier(
|
||||
SVC(kernel='rbf', gamma='auto', probability=True, cache_size=1000, random_state=1), n_jobs=n_jobs)
|
||||
|
||||
def fit(self, lX, ly):
|
||||
lZ = self.embedder.fit_transform(lX, ly)
|
||||
self.prob_classifier.fit(lZ, ly)
|
||||
return self
|
||||
|
||||
def transform(self, lX):
|
||||
lP = self.predict_proba(lX)
|
||||
lP = _normalize(lP, self.l2)
|
||||
return lP
|
||||
|
||||
def fit_transform(self, lX, ly):
|
||||
return self.fit(lX, ly).transform(lX)
|
||||
|
||||
def predict(self, lX):
|
||||
lZ = self.embedder.transform(lX)
|
||||
return self.prob_classifier.predict(lZ)
|
||||
|
||||
def predict_proba(self, lX):
|
||||
lZ = self.embedder.transform(lX)
|
||||
return self.prob_classifier.predict_proba(lZ)
|
||||
|
||||
|
||||
class Funnelling:
|
||||
def __init__(self, first_tier: DocEmbedderList, n_jobs=-1):
|
||||
self.first_tier = first_tier
|
||||
self.meta = MetaClassifier(
|
||||
SVC(kernel='rbf', gamma='auto', probability=True, cache_size=1000, random_state=1), n_jobs=n_jobs)
|
||||
self.n_jobs = n_jobs
|
||||
|
||||
def fit(self, lX, ly):
|
||||
print('## Fitting first-tier learners!')
|
||||
lZ = self.first_tier.fit_transform(lX, ly)
|
||||
print('## Fitting meta-learner!')
|
||||
self.meta.fit(lZ, ly)
|
||||
|
||||
def predict(self, lX):
|
||||
lZ = self.first_tier.transform(lX)
|
||||
ly = self.meta.predict(lZ)
|
||||
return ly
|
||||
|
|
@ -1,8 +1,9 @@
|
|||
from argparse import ArgumentParser
|
||||
from util.embeddings_manager import MuseLoader
|
||||
from view_generators import RecurrentGen, BertGen
|
||||
from funnelling import *
|
||||
from view_generators import *
|
||||
from data.dataset_builder import MultilingualDataset
|
||||
from util.common import MultilingualIndex
|
||||
from util.evaluation import evaluate
|
||||
from time import time
|
||||
|
||||
|
||||
|
|
@ -25,21 +26,38 @@ def main(args):
|
|||
lMuse = MuseLoader(langs=sorted(lX.keys()), cache=EMBEDDINGS_PATH)
|
||||
multilingualIndex.index(lX, ly, lXte, lyte, l_pretrained_vocabulary=lMuse.vocabulary())
|
||||
|
||||
# gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
|
||||
# gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS)
|
||||
# gFun = WordClassGen(n_jobs=N_JOBS)
|
||||
# gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256,
|
||||
# nepochs=50, gpus=args.gpus, n_jobs=N_JOBS)
|
||||
gFun = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS)
|
||||
# posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
|
||||
museEmbedder = MuseGen(muse_dir=EMBEDDINGS_PATH, n_jobs=N_JOBS)
|
||||
wceEmbedder = WordClassGen(n_jobs=N_JOBS)
|
||||
# rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256,
|
||||
# nepochs=250, gpus=args.gpus, n_jobs=N_JOBS)
|
||||
# bertEmbedder = BertGen(multilingualIndex, batch_size=4, nepochs=1, gpus=args.gpus, n_jobs=N_JOBS)
|
||||
|
||||
docEmbedders = DocEmbedderList([museEmbedder, wceEmbedder])
|
||||
|
||||
gfun = Funnelling(first_tier=docEmbedders)
|
||||
|
||||
# Training ---------------------------------------
|
||||
print('\n[Training Generalized Funnelling]')
|
||||
time_init = time()
|
||||
gFun.fit(lX, ly)
|
||||
time_tr = time()
|
||||
gfun.fit(lX, ly)
|
||||
time_tr = round(time() - time_tr, 3)
|
||||
print(f'Training completed in {time_tr} seconds!')
|
||||
|
||||
# print('Projecting...')
|
||||
# y_ = gFun.transform(lX)
|
||||
# Testing ----------------------------------------
|
||||
print('\n[Testing Generalized Funnelling]')
|
||||
time_te = time()
|
||||
ly_ = gfun.predict(lXte)
|
||||
|
||||
train_time = round(time() - time_init, 3)
|
||||
exit(f'Executed! Training time: {train_time}!')
|
||||
l_eval = evaluate(ly_true=ly, ly_pred=ly_)
|
||||
print(l_eval)
|
||||
|
||||
time_te = round(time() - time_te, 3)
|
||||
print(f'Testing completed in {time_te} seconds!')
|
||||
|
||||
overall_time = round(time() - time_init, 3)
|
||||
exit(f'\nExecuted in: {overall_time } seconds!')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from sklearn.multiclass import OneVsRestClassifier
|
|||
from sklearn.model_selection import GridSearchCV
|
||||
from sklearn.svm import SVC
|
||||
from joblib import Parallel, delayed
|
||||
from util.standardizer import StandardizeTransformer
|
||||
|
||||
|
||||
def get_learner(calibrate=False, kernel='linear', C=1):
|
||||
|
|
@ -156,7 +157,6 @@ class MonolingualClassifier:
|
|||
self.model = GridSearchCV(self.model, param_grid=self.parameters, refit=True, cv=5, n_jobs=self.n_jobs,
|
||||
error_score=0, verbose=10)
|
||||
|
||||
# print(f'fitting: {self.model} on matrices of shape X={X.shape} Y={y.shape}')
|
||||
print(f'fitting: Mono-lingual Classifier on matrices of shape X={X.shape} Y={y.shape}')
|
||||
self.model.fit(X, y)
|
||||
if isinstance(self.model, GridSearchCV):
|
||||
|
|
@ -183,3 +183,40 @@ class MonolingualClassifier:
|
|||
|
||||
def best_params(self):
|
||||
return self.best_params_
|
||||
|
||||
|
||||
class MetaClassifier:
|
||||
|
||||
def __init__(self, meta_learner, meta_parameters=None, n_jobs=-1, standardize_range=None):
|
||||
self.n_jobs = n_jobs
|
||||
self.model = MonolingualClassifier(base_learner=meta_learner, parameters=meta_parameters, n_jobs=n_jobs)
|
||||
self.standardize_range = standardize_range
|
||||
|
||||
def fit(self, lZ, ly):
|
||||
tinit = time.time()
|
||||
Z, y = self.stack(lZ, ly)
|
||||
|
||||
self.standardizer = StandardizeTransformer(range=self.standardize_range)
|
||||
Z = self.standardizer.fit_transform(Z)
|
||||
|
||||
print('fitting the Z-space of shape={}'.format(Z.shape))
|
||||
self.model.fit(Z, y)
|
||||
self.time = time.time() - tinit
|
||||
|
||||
def stack(self, lZ, ly=None):
|
||||
langs = list(lZ.keys())
|
||||
Z = np.vstack([lZ[lang] for lang in langs])
|
||||
if ly is not None:
|
||||
y = np.vstack([ly[lang] for lang in langs])
|
||||
return Z, y
|
||||
else:
|
||||
return Z
|
||||
|
||||
def predict(self, lZ):
|
||||
lZ = _joblib_transform_multiling(self.standardizer.transform, lZ, n_jobs=self.n_jobs)
|
||||
return _joblib_transform_multiling(self.model.predict, lZ, n_jobs=self.n_jobs)
|
||||
|
||||
def predict_proba(self, lZ):
|
||||
lZ = _joblib_transform_multiling(self.standardizer.transform, lZ, n_jobs=self.n_jobs)
|
||||
return _joblib_transform_multiling(self.model.predict_proba, lZ, n_jobs=self.n_jobs)
|
||||
|
||||
|
|
|
|||
|
|
@ -45,7 +45,6 @@ class BertModel(pl.LightningModule):
|
|||
def training_step(self, train_batch, batch_idx):
|
||||
X, y, _, batch_langs = train_batch
|
||||
X = torch.cat(X).view([X[0].shape[0], len(X)])
|
||||
# y = y.type(torch.cuda.FloatTensor)
|
||||
y = y.type(torch.FloatTensor)
|
||||
y = y.to('cuda' if self.gpus else 'cpu')
|
||||
logits, _ = self.forward(X)
|
||||
|
|
@ -64,18 +63,6 @@ class BertModel(pl.LightningModule):
|
|||
lX, ly = self._reconstruct_dict(predictions, y, batch_langs)
|
||||
return {'loss': loss, 'pred': lX, 'target': ly}
|
||||
|
||||
def _reconstruct_dict(self, predictions, y, batch_langs):
|
||||
reconstructed_x = {lang: [] for lang in set(batch_langs)}
|
||||
reconstructed_y = {lang: [] for lang in set(batch_langs)}
|
||||
for i, pred in enumerate(predictions):
|
||||
reconstructed_x[batch_langs[i]].append(pred)
|
||||
reconstructed_y[batch_langs[i]].append(y[i])
|
||||
for k, v in reconstructed_x.items():
|
||||
reconstructed_x[k] = torch.cat(v).view(-1, predictions.shape[1])
|
||||
for k, v in reconstructed_y.items():
|
||||
reconstructed_y[k] = torch.cat(v).view(-1, predictions.shape[1])
|
||||
return reconstructed_x, reconstructed_y
|
||||
|
||||
def training_epoch_end(self, outputs):
|
||||
langs = []
|
||||
for output in outputs:
|
||||
|
|
@ -114,7 +101,6 @@ class BertModel(pl.LightningModule):
|
|||
def validation_step(self, val_batch, batch_idx):
|
||||
X, y, _, batch_langs = val_batch
|
||||
X = torch.cat(X).view([X[0].shape[0], len(X)])
|
||||
# y = y.type(torch.cuda.FloatTensor)
|
||||
y = y.type(torch.FloatTensor)
|
||||
y = y.to('cuda' if self.gpus else 'cpu')
|
||||
logits, _ = self.forward(X)
|
||||
|
|
@ -134,7 +120,6 @@ class BertModel(pl.LightningModule):
|
|||
def test_step(self, test_batch, batch_idx):
|
||||
X, y, _, batch_langs = test_batch
|
||||
X = torch.cat(X).view([X[0].shape[0], len(X)])
|
||||
# y = y.type(torch.cuda.FloatTensor)
|
||||
y = y.type(torch.FloatTensor)
|
||||
y = y.to('cuda' if self.gpus else 'cpu')
|
||||
logits, _ = self.forward(X)
|
||||
|
|
@ -164,3 +149,16 @@ class BertModel(pl.LightningModule):
|
|||
optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
|
||||
scheduler = StepLR(optimizer, step_size=25, gamma=0.1)
|
||||
return [optimizer], [scheduler]
|
||||
|
||||
@staticmethod
|
||||
def _reconstruct_dict(predictions, y, batch_langs):
|
||||
reconstructed_x = {lang: [] for lang in set(batch_langs)}
|
||||
reconstructed_y = {lang: [] for lang in set(batch_langs)}
|
||||
for i, pred in enumerate(predictions):
|
||||
reconstructed_x[batch_langs[i]].append(pred)
|
||||
reconstructed_y[batch_langs[i]].append(y[i])
|
||||
for k, v in reconstructed_x.items():
|
||||
reconstructed_x[k] = torch.cat(v).view(-1, predictions.shape[1])
|
||||
for k, v in reconstructed_y.items():
|
||||
reconstructed_y[k] = torch.cat(v).view(-1, predictions.shape[1])
|
||||
return reconstructed_x, reconstructed_y
|
||||
|
|
|
|||
|
|
@ -164,15 +164,6 @@ class RecurrentModel(pl.LightningModule):
|
|||
re_lX = self._reconstruct_dict(predictions, ly)
|
||||
return {'loss': loss, 'pred': re_lX, 'target': ly}
|
||||
|
||||
def _reconstruct_dict(self, X, ly):
|
||||
reconstructed = {}
|
||||
_start = 0
|
||||
for lang in sorted(ly.keys()):
|
||||
lang_batchsize = len(ly[lang])
|
||||
reconstructed[lang] = X[_start:_start+lang_batchsize]
|
||||
_start += lang_batchsize
|
||||
return reconstructed
|
||||
|
||||
def training_epoch_end(self, outputs):
|
||||
# outputs is a of n dicts of m elements, where n is equal to the number of epoch steps and m is batchsize.
|
||||
# here we save epoch level metric values and compute them specifically for each language
|
||||
|
|
@ -265,3 +256,13 @@ class RecurrentModel(pl.LightningModule):
|
|||
optimizer = AdamW(self.parameters(), lr=1e-3)
|
||||
scheduler = StepLR(optimizer, step_size=25, gamma=0.5)
|
||||
return [optimizer], [scheduler]
|
||||
|
||||
@staticmethod
|
||||
def _reconstruct_dict(X, ly):
|
||||
reconstructed = {}
|
||||
_start = 0
|
||||
for lang in sorted(ly.keys()):
|
||||
lang_batchsize = len(ly[lang])
|
||||
reconstructed[lang] = X[_start:_start+lang_batchsize]
|
||||
_start += lang_batchsize
|
||||
return reconstructed
|
||||
|
|
|
|||
|
|
@ -311,8 +311,8 @@ def index(data, vocab, known_words, analyzer, unk_index, out_of_vocabulary):
|
|||
unk_count = 0
|
||||
knw_count = 0
|
||||
out_count = 0
|
||||
pbar = tqdm(data, desc=f'indexing')
|
||||
for text in pbar:
|
||||
# pbar = tqdm(data, desc=f'indexing')
|
||||
for text in data:
|
||||
words = analyzer(text)
|
||||
index = []
|
||||
for word in words:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,36 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
class StandardizeTransformer:
|
||||
def __init__(self, axis=0, range=None):
|
||||
"""
|
||||
|
||||
:param axis:
|
||||
:param range:
|
||||
"""
|
||||
assert range is None or isinstance(range, slice), 'wrong format for range, should either be None or a slice'
|
||||
self.axis = axis
|
||||
self.yetfit = False
|
||||
self.range = range
|
||||
|
||||
def fit(self, X):
|
||||
print('Applying z-score standardization...')
|
||||
std=np.std(X, axis=self.axis, ddof=1)
|
||||
self.std = np.clip(std, 1e-5, None)
|
||||
self.mean = np.mean(X, axis=self.axis)
|
||||
if self.range is not None:
|
||||
ones = np.ones_like(self.std)
|
||||
zeros = np.zeros_like(self.mean)
|
||||
ones[self.range] = self.std[self.range]
|
||||
zeros[self.range] = self.mean[self.range]
|
||||
self.std = ones
|
||||
self.mean = zeros
|
||||
self.yetfit=True
|
||||
return self
|
||||
|
||||
def transform(self, X):
|
||||
if not self.yetfit: 'transform called before fit'
|
||||
return (X - self.mean) / self.std
|
||||
|
||||
def fit_transform(self, X):
|
||||
return self.fit(X).transform(X)
|
||||
|
|
@ -55,6 +55,7 @@ class VanillaFunGen(ViewGen):
|
|||
self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
|
||||
|
||||
def fit(self, lX, lY):
|
||||
print('# Fitting VanillaFunGen...')
|
||||
lX = self.vectorizer.fit_transform(lX)
|
||||
self.doc_projector.fit(lX, lY)
|
||||
return self
|
||||
|
|
@ -84,6 +85,7 @@ class MuseGen(ViewGen):
|
|||
self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
|
||||
|
||||
def fit(self, lX, ly):
|
||||
print('# Fitting MuseGen...')
|
||||
self.vectorizer.fit(lX)
|
||||
self.langs = sorted(lX.keys())
|
||||
self.lMuse = MuseLoader(langs=self.langs, cache=self.muse_dir)
|
||||
|
|
@ -105,7 +107,6 @@ class MuseGen(ViewGen):
|
|||
|
||||
|
||||
class WordClassGen(ViewGen):
|
||||
|
||||
def __init__(self, n_jobs=-1):
|
||||
"""
|
||||
generates document representation via Word-Class-Embeddings.
|
||||
|
|
@ -119,6 +120,7 @@ class WordClassGen(ViewGen):
|
|||
self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
|
||||
|
||||
def fit(self, lX, ly):
|
||||
print('# Fitting WordClassGen...')
|
||||
lX = self.vectorizer.fit_transform(lX)
|
||||
self.langs = sorted(lX.keys())
|
||||
wce = Parallel(n_jobs=self.n_jobs)(
|
||||
|
|
@ -171,7 +173,7 @@ class RecurrentGen(ViewGen):
|
|||
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
|
||||
self.multilingualIndex.embedding_matrices(self.pretrained, supervised=self.wce)
|
||||
self.model = self._init_model()
|
||||
self.logger = TensorBoardLogger(save_dir='tb_logs', name='rnn_dev', default_hp_metric=False)
|
||||
self.logger = TensorBoardLogger(save_dir='tb_logs', name='rnn', default_hp_metric=False)
|
||||
# self.logger = CSVLogger(save_dir='csv_logs', name='rnn_dev')
|
||||
|
||||
def _init_model(self):
|
||||
|
|
@ -205,6 +207,7 @@ class RecurrentGen(ViewGen):
|
|||
:param ly:
|
||||
:return:
|
||||
"""
|
||||
print('# Fitting RecurrentGen...')
|
||||
recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size)
|
||||
trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
|
||||
checkpoint_callback=False)
|
||||
|
|
@ -241,7 +244,6 @@ class RecurrentGen(ViewGen):
|
|||
|
||||
|
||||
class BertGen(ViewGen):
|
||||
|
||||
def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, stored_path=None):
|
||||
super().__init__()
|
||||
self.multilingualIndex = multilingualIndex
|
||||
|
|
@ -251,13 +253,14 @@ class BertGen(ViewGen):
|
|||
self.n_jobs = n_jobs
|
||||
self.stored_path = stored_path
|
||||
self.model = self._init_model()
|
||||
self.logger = TensorBoardLogger(save_dir='tb_logs', name='bert_dev', default_hp_metric=False)
|
||||
self.logger = TensorBoardLogger(save_dir='tb_logs', name='bert', default_hp_metric=False)
|
||||
|
||||
def _init_model(self):
|
||||
output_size = self.multilingualIndex.get_target_dim()
|
||||
return BertModel(output_size=output_size, stored_path=self.stored_path, gpus=self.gpus)
|
||||
|
||||
def fit(self, lX, ly):
|
||||
print('# Fitting BertGen...')
|
||||
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
|
||||
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
|
||||
trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,
|
||||
|
|
@ -272,11 +275,14 @@ class BertGen(ViewGen):
|
|||
self.model.to('cuda' if self.gpus else 'cpu')
|
||||
self.model.eval()
|
||||
time_init = time()
|
||||
l_emebds = self.model.encode(data)
|
||||
pass
|
||||
l_emebds = self.model.encode(data) # TODO
|
||||
transform_time = round(time() - time_init, 3)
|
||||
print(f'Executed! Transform took: {transform_time}')
|
||||
exit('BERT VIEWGEN TRANSFORM NOT IMPLEMENTED!')
|
||||
return l_emebds
|
||||
|
||||
def fit_transform(self, lX, ly):
|
||||
# we can assume that we have already indexed data for transform() since we are first calling fit()
|
||||
pass
|
||||
return self.fit(lX, ly).transform(lX)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue