sketched out documentation
This commit is contained in: parent 2a8075bbc2, commit 30d2be245c
@@ -88,14 +88,21 @@ class RecurrentDataset(Dataset):
 class RecurrentDataModule(pl.LightningDataModule):
-    def __init__(self, multilingualIndex, batchsize=64):
-        """
-        Pytorch-lightning DataModule: https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
-        :param multilingualIndex:
-        :param batchsize:
-        """
+    """
+    PyTorch Lightning DataModule to be deployed with RecurrentGen.
+    https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
+    """
+    def __init__(self, multilingualIndex, batchsize=64, n_jobs=-1):
+        """
+        Init RecurrentDataModule.
+        :param multilingualIndex: MultilingualIndex, a dictionary of training and test documents
+        indexed by language code.
+        :param batchsize: int, number of samples per batch.
+        :param n_jobs: int, number of concurrent workers to be deployed (i.e., parallelizing data loading).
+        """
         self.multilingualIndex = multilingualIndex
         self.batchsize = batchsize
+        self.n_jobs = n_jobs
         super().__init__()

     def prepare_data(self, *args, **kwargs):
@@ -128,15 +135,15 @@ class RecurrentDataModule(pl.LightningDataModule):
                                                 lPad_index=self.multilingualIndex.l_pad())

     def train_dataloader(self):
-        return DataLoader(self.training_dataset, batch_size=self.batchsize, num_workers=N_WORKERS,
+        return DataLoader(self.training_dataset, batch_size=self.batchsize, num_workers=self.n_jobs,
                           collate_fn=self.training_dataset.collate_fn)

     def val_dataloader(self):
-        return DataLoader(self.val_dataset, batch_size=self.batchsize, num_workers=N_WORKERS,
+        return DataLoader(self.val_dataset, batch_size=self.batchsize, num_workers=self.n_jobs,
                           collate_fn=self.val_dataset.collate_fn)

     def test_dataloader(self):
-        return DataLoader(self.test_dataset, batch_size=self.batchsize, num_workers=N_WORKERS,
+        return DataLoader(self.test_dataset, batch_size=self.batchsize, num_workers=self.n_jobs,
                           collate_fn=self.test_dataset.collate_fn)
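
The new n_jobs argument now reaches every DataLoader. Below is a minimal usage sketch, not part of this commit: the import path is hypothetical and multilingual_index stands for an already-indexed MultilingualIndex.

# Sketch only: hypothetical module path and a pre-built MultilingualIndex named `multilingual_index`.
from data.datamodule import RecurrentDataModule

dm = RecurrentDataModule(multilingual_index, batchsize=64, n_jobs=4)
dm.prepare_data()
train_loader = dm.train_dataloader()    # DataLoader created with num_workers=4
first_batch = next(iter(train_loader))  # batches collated by RecurrentDataset.collate_fn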
@@ -156,7 +163,18 @@ def tokenize(l_raw, max_len):


 class BertDataModule(RecurrentDataModule):
+    """
+    PyTorch Lightning DataModule to be deployed with BertGen.
+    https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
+    """
     def __init__(self, multilingualIndex, batchsize=64, max_len=512):
+        """
+        Init BertDataModule.
+        :param multilingualIndex: MultilingualIndex, a dictionary of training and test documents
+        indexed by language code.
+        :param batchsize: int, number of samples per batch.
+        :param max_len: int, max number of tokens per document. The absolute cap is 512.
+        """
         super().__init__(multilingualIndex, batchsize)
         self.max_len = max_len
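
BertDataModule only adds the max_len cap on top of its parent class. A hedged sketch under the same assumptions (hypothetical import path, pre-built index):

# Sketch only: documents fed to BERT are truncated to max_len tokens; 512 is the absolute cap.
from data.datamodule import BertDataModule

bert_dm = BertDataModule(multilingual_index, batchsize=32, max_len=256)
assert bert_dm.max_len <= 512  # BERT position embeddings impose the 512-token limit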
@@ -4,9 +4,13 @@ from view_generators import VanillaFunGen


 class DocEmbedderList:
+    """
+    Class that takes care of calling the fit and transform functions of every init embedder. Every ViewGenerator
+    should be contained by this class in order to seamlessly train the overall architecture.
+    """
     def __init__(self, embedder_list, probabilistic=True):
         """
-        Class that takes care of calling fit and transform function for every init embedder.
+        Init the DocEmbedderList.
         :param embedder_list: list of embedders to be deployed
         :param probabilistic: whether to recast view generators output to vectors of posterior probabilities or not
         """
@@ -23,11 +27,22 @@ class DocEmbedderList:
         self.embedders = _tmp

     def fit(self, lX, ly):
+        """
+        Fit all the ViewGenerators contained by DocEmbedderList.
+        :param lX:
+        :param ly:
+        :return: self
+        """
         for embedder in self.embedders:
             embedder.fit(lX, ly)
         return self

     def transform(self, lX):
+        """
+        Project documents by means of every ViewGenerator. Projections are then averaged together and returned.
+        :param lX:
+        :return: common latent space (averaged).
+        """
         langs = sorted(lX.keys())
         lZparts = {lang: None for lang in langs}
@@ -40,14 +55,24 @@ class DocEmbedderList:
             else:
                 lZparts[lang] += Z
         n_embedders = len(self.embedders)
-        return {lang: lZparts[lang]/n_embedders for lang in langs}
+        return {lang: lZparts[lang]/n_embedders for lang in langs}  # averaging feature spaces

     def fit_transform(self, lX, ly):
         return self.fit(lX, ly).transform(lX)


 class FeatureSet2Posteriors:
+    """
+    Takes care of recasting the features output by the embedders to vectors of posterior probabilities by means of
+    a multiclass SVM.
+    """
     def __init__(self, embedder, l2=True, n_jobs=-1):
+        """
+        Init the class.
+        :param embedder: ViewGen, a view generator that does not natively output posterior probabilities.
+        :param l2: bool, whether or not to apply L2 normalization to the projection.
+        :param n_jobs: int, number of concurrent workers.
+        """
         self.embedder = embedder
         self.l2 = l2
         self.n_jobs = n_jobs
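
Taken together, DocEmbedderList averages the per-language projections of its view generators, and FeatureSet2Posteriors recasts a non-probabilistic view into posterior probabilities. A sketch of how the two might be combined, not part of this commit; the module path is hypothetical and get_learner, lX and ly are assumed to be available as in main():

# Sketch only: combine a natively probabilistic view with one recast via FeatureSet2Posteriors.
from funnelling import DocEmbedderList, FeatureSet2Posteriors  # hypothetical module path
from view_generators import VanillaFunGen, MuseGen

embedders = DocEmbedderList(
    embedder_list=[
        VanillaFunGen(base_learner=get_learner(calibrate=True)),   # already outputs posteriors
        FeatureSet2Posteriors(MuseGen(muse_dir='../embeddings')),  # recast to posteriors via SVM
    ],
    probabilistic=True)
lZ = embedders.fit_transform(lX, ly)  # {lang: average of the view-generator projections}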
@@ -77,6 +102,11 @@ class FeatureSet2Posteriors:


 class Funnelling:
+    """
+    Funnelling architecture. It is composed of two tiers. The first tier is a set of heterogeneous document embedders.
+    The second tier (i.e., the meta-classifier) operates the classification of the common latent space computed by
+    the first-tier learners.
+    """
     def __init__(self, first_tier: DocEmbedderList, meta_classifier: MetaClassifier, n_jobs=-1):
         self.first_tier = first_tier
         self.meta = meta_classifier
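
A sketch of wiring the two tiers together, assuming Funnelling exposes fit/predict in the spirit of the original funnelling formulation; the MetaClassifier constructor arguments shown here are hypothetical:

# Sketch only: the meta-classifier learns on the averaged posterior space produced by the first tier.
meta = MetaClassifier(meta_learner=get_learner(calibrate=False))  # hypothetical constructor arguments
gfun = Funnelling(first_tier=embedders, meta_classifier=meta, n_jobs=-1)
gfun.fit(lX, ly)              # assumed: fit first-tier view generators, then the meta-classifier
ly_pred = gfun.predict(lXte)  # assumed predict() on held-out documents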
@@ -26,6 +26,7 @@ def main(args):
     lMuse = MuseLoader(langs=sorted(lX.keys()), cache=args.muse_dir)
     multilingualIndex.index(lX, ly, lXte, lyte, l_pretrained_vocabulary=lMuse.vocabulary())

+    # Init ViewGenerators and append them to embedder_list
     embedder_list = []
     if args.post_embedder:
         posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=args.n_jobs)
@@ -30,6 +30,10 @@ from util.embeddings_manager import MuseLoader, XdotM, wce_matrix


 class ViewGen(ABC):
+    """
+    Abstract class for ViewGenerator implementations. Every ViewGen should implement these three methods in order to
+    be seamlessly integrated in the overall architecture.
+    """
     @abstractmethod
     def fit(self, lX, ly):
         pass
@@ -44,9 +48,13 @@ class ViewGen(ABC):


 class VanillaFunGen(ViewGen):
+    """
+    View Generator (x): original funnelling architecture proposed by Moreo, Esuli and
+    Sebastiani in DOI: https://doi.org/10.1145/3326065
+    """
     def __init__(self, base_learner, first_tier_parameters=None, n_jobs=-1):
         """
-        Original funnelling architecture proposed by Moreo, Esuli and Sebastiani in DOI: https://doi.org/10.1145/3326065
+        Init the Posterior Probabilities embedder (i.e., VanillaFunGen).
         :param base_learner: naive monolingual learners to be deployed as first-tier learners. Should be able to
         return posterior probabilities.
         :param base_learner:
@@ -68,11 +76,10 @@ class VanillaFunGen(ViewGen):

     def transform(self, lX):
         """
-        (1) Vectorize documents
-        (2) Project them according to the learners SVMs
-        (3) Apply L2 normalization to the projection
-        :param lX:
-        :return:
+        (1) Vectorize documents; (2) project them according to the learners' SVMs; finally, (3) apply L2
+        normalization to the projection and return it.
+        :param lX: dict {lang: indexed documents}
+        :return: document projection to the common latent space.
         """
         lX = self.vectorizer.transform(lX)
         lZ = self.doc_projector.predict_proba(lX)
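
The three steps of transform() can be reproduced with plain scikit-learn stand-ins. A self-contained toy sketch on a single language, not the project's actual classes:

# Sketch only: (1) tf-idf vectorization, (2) projection to posterior probabilities, (3) L2 normalization.
from sklearn.calibration import CalibratedClassifierCV
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize
from sklearn.svm import SVC

docs = ['stocks rallied today', 'bonds slipped again', 'oil prices rose', 'markets fell sharply']
y = [0, 1, 0, 1]
X = TfidfVectorizer(sublinear_tf=True, use_idf=True).fit_transform(docs)  # (1) vectorize
clf = CalibratedClassifierCV(SVC(), cv=2).fit(X, y)                       # calibrated first-tier SVM
Z = clf.predict_proba(X)                                                  # (2) posterior projection
Z = normalize(Z, norm='l2', axis=1)                                       # (3) L2 normalization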
@@ -84,10 +91,13 @@ class VanillaFunGen(ViewGen):


 class MuseGen(ViewGen):
+    """
+    View Generator (m): generates document representations via MUSE embeddings (fastText multilingual word
+    embeddings). Document embeddings are obtained via a weighted sum of the document's constituent word embeddings.
+    """
     def __init__(self, muse_dir='../embeddings', n_jobs=-1):
         """
-        generates document representation via MUSE embeddings (Fasttext multilingual word
-        embeddings). Document embeddings are obtained via weighted sum of document's constituent embeddings.
+        Init the MuseGen.
         :param muse_dir: string, path to folder containing muse embeddings
         :param n_jobs: int, number of concurrent workers
         """
@@ -99,6 +109,12 @@ class MuseGen(ViewGen):
         self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)

     def fit(self, lX, ly):
+        """
+        (1) Vectorize documents; (2) load MUSE embeddings for the words encountered while vectorizing.
+        :param lX: dict {lang: indexed documents}
+        :param ly: dict {lang: target vectors}
+        :return: self.
+        """
         print('# Fitting MuseGen (M)...')
         self.vectorizer.fit(lX)
         self.langs = sorted(lX.keys())
@@ -109,6 +125,12 @@ class MuseGen(ViewGen):
         return self

     def transform(self, lX):
+        """
+        (1) Vectorize documents; (2) compute the weighted sum of the MUSE embeddings found at document level;
+        finally, (3) apply L2 normalization to the embedding and return it.
+        :param lX: dict {lang: indexed documents}
+        :return: document projection to the common latent space.
+        """
         lX = self.vectorizer.transform(lX)
         XdotMUSE = Parallel(n_jobs=self.n_jobs)(
             delayed(XdotM)(lX[lang], self.lMuse[lang], sif=True) for lang in self.langs)
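
The weighted-sum step amounts to a matrix product between the tf-idf document-term matrix and the word-embedding matrix (the real code adds SIF weighting inside XdotM). A self-contained numpy sketch with random stand-ins:

# Sketch only: X is a stand-in tf-idf matrix, M a stand-in MUSE embedding matrix.
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((5, 100))    # 5 documents over a 100-word vocabulary (tf-idf weights)
M = rng.random((100, 300))  # 300-dimensional word embeddings for the same vocabulary
doc_emb = X @ M             # weighted sum of each document's constituent word embeddings
doc_emb /= np.linalg.norm(doc_emb, axis=1, keepdims=True)  # L2 normalization
print(doc_emb.shape)        # (5, 300)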
@@ -121,10 +143,13 @@ class MuseGen(ViewGen):


 class WordClassGen(ViewGen):
+    """
+    View Generator (w): generates document representations via Word-Class Embeddings.
+    Document embeddings are obtained via a weighted sum of the document's constituent word embeddings.
+    """
     def __init__(self, n_jobs=-1):
         """
-        generates document representation via Word-Class-Embeddings.
-        Document embeddings are obtained via weighted sum of document's constituent embeddings.
+        Init WordClassGen.
         :param n_jobs: int, number of concurrent workers
         """
         super().__init__()
@@ -134,6 +159,12 @@ class WordClassGen(ViewGen):
         self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)

     def fit(self, lX, ly):
+        """
+        (1) Vectorize documents; (2) compute Word-Class Embeddings for the words encountered while vectorizing.
+        :param lX: dict {lang: indexed documents}
+        :param ly: dict {lang: target vectors}
+        :return: self.
+        """
         print('# Fitting WordClassGen (W)...')
         lX = self.vectorizer.fit_transform(lX)
         self.langs = sorted(lX.keys())
@@ -144,6 +175,12 @@ class WordClassGen(ViewGen):
         return self

     def transform(self, lX):
+        """
+        (1) Vectorize documents; (2) compute the weighted sum of the Word-Class Embeddings found at document level;
+        finally, (3) apply L2 normalization to the embedding and return it.
+        :param lX: dict {lang: indexed documents}
+        :return: document projection to the common latent space.
+        """
         lX = self.vectorizer.transform(lX)
         XdotWce = Parallel(n_jobs=self.n_jobs)(
             delayed(XdotM)(lX[lang], self.lWce[lang], sif=True) for lang in self.langs)
@@ -156,17 +193,28 @@ class WordClassGen(ViewGen):


 class RecurrentGen(ViewGen):
+    """
+    View Generator (G): generates document embeddings by means of Gated Recurrent Units. The model can be
+    initialized with different (multilingual/aligned) word representations (e.g., MUSE, WCE, etc.).
+    Output dimension is (n_docs, 512). Training happens end-to-end. At inference time, the model returns
+    the network's internal state at the second feed-forward layer. Training metrics are logged via TensorBoard.
+    """
     def __init__(self, multilingualIndex, pretrained_embeddings, wce, batch_size=512, nepochs=50,
                  gpus=0, n_jobs=-1, stored_path=None):
         """
-        generates document embedding by means of a Gated Recurrent Units. The model can be
-        initialized with different (multilingual/aligned) word representations (e.g., MUSE, WCE, ecc.,).
-        Output dimension is (n_docs, 512).
-        :param multilingualIndex:
-        :param pretrained_embeddings:
-        :param wce:
-        :param gpus:
-        :param n_jobs:
+        Init RecurrentGen.
+        :param multilingualIndex: MultilingualIndex, a dictionary of training and test documents
+        indexed by language code.
+        :param pretrained_embeddings: dict {lang: tensor of embeddings}, the pretrained embeddings to use
+        as the embedding layer.
+        :param wce: bool, whether to deploy Word-Class Embeddings (as proposed by A. Moreo). If True, supervised
+        embeddings are concatenated to the deployed pretrained embeddings. WCE dimensionality equals
+        the number of target classes.
+        :param batch_size: int, number of samples in a batch.
+        :param nepochs: int, max number of epochs to train the model.
+        :param gpus: int, how many GPUs to use per node. If False, computation will take place on the CPU.
+        :param n_jobs: int, number of concurrent workers (i.e., parallelizing data loading).
+        :param stored_path: str, path to a pretrained model. If None, the model will be trained from scratch.
+        """
         super().__init__()
         self.multilingualIndex = multilingualIndex
@@ -212,14 +260,15 @@ class RecurrentGen(ViewGen):

     def fit(self, lX, ly):
         """
+        Train the Neural Network end-to-end.
         lX and ly are not directly used. We rather get them from the multilingual index used in the instantiation
         of the Dataset object (RecurrentDataset) in the GfunDataModule class.
-        :param lX:
-        :param ly:
-        :return:
+        :param lX: dict {lang: indexed documents}
+        :param ly: dict {lang: target vectors}
+        :return: self.
         """
         print('# Fitting RecurrentGen (G)...')
-        recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size)
+        recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs)
         trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
                           checkpoint_callback=False)
@@ -236,9 +285,9 @@ class RecurrentGen(ViewGen):

     def transform(self, lX):
         """
-        Project documents to the common latent space
-        :param lX:
-        :return:
+        Project documents to the common latent space. Output dimensionality is 512.
+        :param lX: dict {lang: indexed documents}
+        :return: documents projected to the common latent space.
         """
         l_pad = self.multilingualIndex.l_pad()
         data = self.multilingualIndex.l_devel_index()
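
A hedged end-to-end sketch of RecurrentGen as documented above, not part of this commit; the pretrained-embedding dict, multilingualIndex, lX and ly are assumed to be prepared beforehand as in main():

# Sketch only: `l_muse_embeddings` is an assumed dict {lang: tensor} of pretrained vectors.
from view_generators import RecurrentGen  # assumed to sit next to VanillaFunGen

recurrent_gen = RecurrentGen(multilingualIndex,
                             pretrained_embeddings=l_muse_embeddings,
                             wce=True, batch_size=512, nepochs=50, gpus=1, n_jobs=4)
lZ = recurrent_gen.fit(lX, ly).transform(lX)  # {lang: array of shape (n_docs, 512)}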
@@ -255,7 +304,22 @@ class RecurrentGen(ViewGen):


 class BertGen(ViewGen):
+    """
+    View Generator (b): generates document embeddings via a BERT model. The training happens end-to-end.
+    At inference time, the model returns the network's internal state at the last original layer (i.e., the 12th).
+    Document embeddings are the state associated with the "start" token. Training metrics are logged via TensorBoard.
+    """
     def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, stored_path=None):
+        """
+        Init the BERT model.
+        :param multilingualIndex: MultilingualIndex, a dictionary of training and test documents
+        indexed by language code.
+        :param batch_size: int, number of samples per batch.
+        :param nepochs: int, max number of epochs to train the model.
+        :param gpus: int, how many GPUs to use per node. If False, computation will take place on the CPU.
+        :param n_jobs: int, number of concurrent workers.
+        :param stored_path: str, path to a pretrained model. If None, the model will be trained from scratch.
+        """
         super().__init__()
         self.multilingualIndex = multilingualIndex
         self.nepochs = nepochs
@@ -271,6 +335,14 @@ class BertGen(ViewGen):
         return BertModel(output_size=output_size, stored_path=self.stored_path, gpus=self.gpus)

     def fit(self, lX, ly):
+        """
+        Train the Neural Network end-to-end.
+        lX and ly are not directly used. We rather get them from the multilingual index used in the instantiation
+        of the Dataset object (RecurrentDataset) in the GfunDataModule class.
+        :param lX: dict {lang: indexed documents}
+        :param ly: dict {lang: target vectors}
+        :return: self.
+        """
         print('# Fitting BertGen (M)...')
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
         bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
@@ -281,7 +353,11 @@ class BertGen(ViewGen):
         return self

     def transform(self, lX):
-        # lX is raw text data. It has to be first indexed via Bert Tokenizer.
+        """
+        Project documents to the common latent space. Output dimensionality is 768.
+        :param lX: dict {lang: indexed documents}
+        :return: documents projected to the common latent space.
+        """
         data = self.multilingualIndex.l_devel_raw_index()
         data = tokenize(data, max_len=512)
         self.model.to('cuda' if self.gpus else 'cpu')
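
BertGen mirrors RecurrentGen but starts from raw text, which transform() re-indexes with the BERT tokenizer capped at 512 tokens. A hedged usage sketch under the same assumptions as above:

# Sketch only: multilingualIndex, lX and ly are assumed to be built as in main().
from view_generators import BertGen  # assumed to sit next to the other view generators

bert_gen = BertGen(multilingualIndex, batch_size=128, nepochs=10, gpus=1, n_jobs=4)
lZ = bert_gen.fit(lX, ly).transform(lX)  # {lang: array of shape (n_docs, 768)}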