Implemented funnelling architecture
This commit is contained in:
parent
94866e5ad8
commit
a5af2134bf
|
|
@ -9,7 +9,7 @@ from time import time
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
OPTIMC = True # TODO
|
OPTIMC = False # TODO
|
||||||
N_JOBS = 8
|
N_JOBS = 8
|
||||||
print('Running refactored...')
|
print('Running refactored...')
|
||||||
|
|
||||||
|
|
@ -20,6 +20,7 @@ def main(args):
|
||||||
EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings'
|
EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings'
|
||||||
data = MultilingualDataset.load(_DATASET)
|
data = MultilingualDataset.load(_DATASET)
|
||||||
data.set_view(languages=['it', 'fr'])
|
data.set_view(languages=['it', 'fr'])
|
||||||
|
data.show_dimensions()
|
||||||
lX, ly = data.training()
|
lX, ly = data.training()
|
||||||
lXte, lyte = data.test()
|
lXte, lyte = data.test()
|
||||||
|
|
||||||
|
|
@ -53,8 +54,8 @@ def main(args):
|
||||||
# Init DocEmbedderList
|
# Init DocEmbedderList
|
||||||
docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True)
|
docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True)
|
||||||
meta_parameters = None if not OPTIMC else [{'C': [1, 1e3, 1e2, 1e1, 1e-1]}]
|
meta_parameters = None if not OPTIMC else [{'C': [1, 1e3, 1e2, 1e1, 1e-1]}]
|
||||||
meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf', C=meta_parameters),
|
meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'),
|
||||||
meta_parameters=get_params(optimc=True))
|
meta_parameters=get_params(optimc=OPTIMC))
|
||||||
|
|
||||||
# Init Funnelling Architecture
|
# Init Funnelling Architecture
|
||||||
gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta)
|
gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta)
|
||||||
|
|
@ -71,7 +72,7 @@ def main(args):
|
||||||
print('\n[Testing Generalized Funnelling]')
|
print('\n[Testing Generalized Funnelling]')
|
||||||
time_te = time()
|
time_te = time()
|
||||||
ly_ = gfun.predict(lXte)
|
ly_ = gfun.predict(lXte)
|
||||||
l_eval = evaluate(ly_true=ly, ly_pred=ly_)
|
l_eval = evaluate(ly_true=lyte, ly_pred=ly_)
|
||||||
time_te = round(time() - time_te, 3)
|
time_te = round(time() - time_te, 3)
|
||||||
print(f'Testing completed in {time_te} seconds!')
|
print(f'Testing completed in {time_te} seconds!')
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,7 @@ class VanillaFunGen(ViewGen):
|
||||||
self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
|
self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
|
||||||
|
|
||||||
def fit(self, lX, lY):
|
def fit(self, lX, lY):
|
||||||
print('# Fitting VanillaFunGen...')
|
print('# Fitting VanillaFunGen (X)...')
|
||||||
lX = self.vectorizer.fit_transform(lX)
|
lX = self.vectorizer.fit_transform(lX)
|
||||||
self.doc_projector.fit(lX, lY)
|
self.doc_projector.fit(lX, lY)
|
||||||
return self
|
return self
|
||||||
|
|
@ -85,7 +85,7 @@ class MuseGen(ViewGen):
|
||||||
self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
|
self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
|
||||||
|
|
||||||
def fit(self, lX, ly):
|
def fit(self, lX, ly):
|
||||||
print('# Fitting MuseGen...')
|
print('# Fitting MuseGen (M)...')
|
||||||
self.vectorizer.fit(lX)
|
self.vectorizer.fit(lX)
|
||||||
self.langs = sorted(lX.keys())
|
self.langs = sorted(lX.keys())
|
||||||
self.lMuse = MuseLoader(langs=self.langs, cache=self.muse_dir)
|
self.lMuse = MuseLoader(langs=self.langs, cache=self.muse_dir)
|
||||||
|
|
@ -120,7 +120,7 @@ class WordClassGen(ViewGen):
|
||||||
self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
|
self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
|
||||||
|
|
||||||
def fit(self, lX, ly):
|
def fit(self, lX, ly):
|
||||||
print('# Fitting WordClassGen...')
|
print('# Fitting WordClassGen (W)...')
|
||||||
lX = self.vectorizer.fit_transform(lX)
|
lX = self.vectorizer.fit_transform(lX)
|
||||||
self.langs = sorted(lX.keys())
|
self.langs = sorted(lX.keys())
|
||||||
wce = Parallel(n_jobs=self.n_jobs)(
|
wce = Parallel(n_jobs=self.n_jobs)(
|
||||||
|
|
@ -207,7 +207,7 @@ class RecurrentGen(ViewGen):
|
||||||
:param ly:
|
:param ly:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
print('# Fitting RecurrentGen...')
|
print('# Fitting RecurrentGen (G)...')
|
||||||
recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size)
|
recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size)
|
||||||
trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
|
trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
|
||||||
checkpoint_callback=False)
|
checkpoint_callback=False)
|
||||||
|
|
@ -260,7 +260,7 @@ class BertGen(ViewGen):
|
||||||
return BertModel(output_size=output_size, stored_path=self.stored_path, gpus=self.gpus)
|
return BertModel(output_size=output_size, stored_path=self.stored_path, gpus=self.gpus)
|
||||||
|
|
||||||
def fit(self, lX, ly):
|
def fit(self, lX, ly):
|
||||||
print('# Fitting BertGen...')
|
print('# Fitting BertGen (M)...')
|
||||||
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
|
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
|
||||||
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
|
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
|
||||||
trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,
|
trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue