gfun_multimodal/gfun/generalizedFunnelling.py

183 lines
5.5 KiB
Python

import os
import sys
sys.path.append(os.path.join(os.getcwd(), "gfun"))
import pickle
import numpy as np
from vgfs.commons import TfidfVectorizerMultilingual
from vgfs.learners.svms import MetaClassifier, get_learner
from vgfs.multilingualGen import MultilingualGen
from vgfs.transformerGen import TransformerGen
from vgfs.vanillaFun import VanillaFunGen
from vgfs.wceGen import WceGen
# TODO: save and load gfun model
class GeneralizedFunnelling:
    """Two-tier ensemble of view-generating functions (VGFs) and a meta-classifier.

    Every enabled first-tier learner turns the input into a per-language
    mapping of arrays. These mappings are averaged element-wise and the
    average is what the ``MetaClassifier`` is fitted on / queried with.
    """

    def __init__(
        self,
        posterior,
        wce,
        multilingual,
        transformer,
        langs,
        embed_dir,
        n_jobs,
        batch_size,
        max_length,
        lr,
        epochs,
        patience,
        evaluate_step,
        transformer_name,
    ):
        """Store the configuration and build the learners via ``init()``.

        Args:
            posterior: truthy to enable the ``VanillaFunGen`` learner.
            wce: truthy to enable the ``WceGen`` learner.
            multilingual: truthy to enable the ``MultilingualGen`` learner.
            transformer: truthy to enable the ``TransformerGen`` learner.
            langs: forwarded to ``MultilingualGen`` and shown by ``get_config``.
            embed_dir: forwarded to ``MultilingualGen``.
            n_jobs: forwarded to the first-tier learners and the meta-classifier.
            batch_size: forwarded to ``TransformerGen``.
            max_length: forwarded to ``TransformerGen``.
            lr: forwarded to ``TransformerGen``.
            epochs: forwarded to ``TransformerGen``.
            patience: forwarded to ``TransformerGen``.
            evaluate_step: forwarded to ``TransformerGen``.
            transformer_name: forwarded to ``TransformerGen`` as ``model_name``.
        """
        # Forcing VFGs -----------
        self.posteriors_vgf = posterior
        self.wce_vgf = wce
        self.multilingual_vgf = multilingual
        # NOTE: the attribute spelling ("trasformer") is kept unchanged so
        # that external code reading it keeps working.
        self.trasformer_vgf = transformer
        # ------------------------
        self.langs = langs
        self.embed_dir = embed_dir
        self.cached = True
        # Transformer VGF params
        self.transformer_name = transformer_name
        self.epochs = epochs
        self.lr_transformer = lr
        self.batch_size_transformer = batch_size
        self.max_length = max_length
        self.early_stopping = True
        self.patience = patience
        self.evaluate_step = evaluate_step
        # -------------------
        # A single vectorizer instance is shared by the learners selected
        # in init_vgfs_vectorizers().
        self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
        self.n_jobs = n_jobs
        self.first_tier_learners = []
        self.metaclassifier = None
        self.aggfunc = "mean"
        self.init()

    def init(self):
        """Instantiate the enabled first-tier learners and the meta-classifier.

        The learner list is rebuilt from scratch, so calling this method
        more than once does not register any learner twice.
        """
        print("[Init GeneralizedFunnelling]")
        # Reset first: appending onto a previously populated list would
        # duplicate learners and skew the mean aggregation.
        self.first_tier_learners = []

        if self.posteriors_vgf:
            fun = VanillaFunGen(
                base_learner=get_learner(calibrate=True),
                first_tier_parameters=None,
                n_jobs=self.n_jobs,
            )
            self.first_tier_learners.append(fun)

        if self.multilingual_vgf:
            multilingual_vgf = MultilingualGen(
                embed_dir=self.embed_dir,
                langs=self.langs,
                n_jobs=self.n_jobs,
                cached=self.cached,
                probabilistic=True,
            )
            self.first_tier_learners.append(multilingual_vgf)

        if self.wce_vgf:
            wce_vgf = WceGen(n_jobs=self.n_jobs)
            self.first_tier_learners.append(wce_vgf)

        if self.trasformer_vgf:
            # NOTE(review): the device is hard-coded to "cuda"; confirm how
            # TransformerGen behaves on machines without a GPU.
            transformer_vgf = TransformerGen(
                model_name=self.transformer_name,
                lr=self.lr_transformer,
                epochs=self.epochs,
                batch_size=self.batch_size_transformer,
                max_length=self.max_length,
                device="cuda",
                print_steps=50,
                probabilistic=True,
                evaluate_step=self.evaluate_step,
                verbose=True,
                patience=self.patience,
            )
            self.first_tier_learners.append(transformer_vgf)

        self.metaclassifier = MetaClassifier(
            meta_learner=get_learner(calibrate=True, kernel="rbf"),
            meta_parameters=get_params(),
            n_jobs=self.n_jobs,
        )

    def init_vgfs_vectorizers(self):
        """Hand the shared vectorizer to the learners that take one.

        Only ``VanillaFunGen``, ``MultilingualGen`` and ``WceGen`` instances
        receive it; any other learner is left untouched.
        """
        for vgf in self.first_tier_learners:
            if isinstance(vgf, (VanillaFunGen, MultilingualGen, WceGen)):
                vgf.vectorizer = self.vectorizer

    def fit(self, lX, lY):
        """Fit the vectorizer, every first-tier learner and the meta-classifier.

        Args:
            lX: training inputs (presumably a dict keyed by language;
                confirm against the VGF implementations).
            lY: training targets, passed to each learner and to the
                meta-classifier.

        Returns:
            ``self``, so that calls can be chained.

        Raises:
            ValueError: if no first-tier learner is enabled.
        """
        print("[Fitting GeneralizedFunnelling]")
        self.vectorizer.fit(lX)
        self.init_vgfs_vectorizers()

        projections = []
        print("- fitting first tier learners")
        for vgf in self.first_tier_learners:
            l_posteriors = vgf.fit_transform(lX, lY)
            projections.append(l_posteriors)

        agg = self.aggregate(projections)
        self.metaclassifier.fit(agg, lY)
        return self

    def transform(self, lX):
        """Project ``lX`` with every first-tier learner and query the meta-classifier.

        Returns:
            Whatever ``self.metaclassifier.predict_proba`` returns for the
            aggregated projections.

        Raises:
            ValueError: if no first-tier learner is enabled.
        """
        projections = [vgf.transform(lX) for vgf in self.first_tier_learners]
        agg = self.aggregate(projections)
        l_out = self.metaclassifier.predict_proba(agg)
        return l_out

    def fit_transform(self, lX, lY):
        """Fit on ``(lX, lY)`` and return ``transform(lX)``."""
        return self.fit(lX, lY).transform(lX)

    def aggregate(self, first_tier_projections):
        """Combine the first-tier projections according to ``self.aggfunc``.

        Args:
            first_tier_projections: list with one ``{lang: array}`` mapping
                per first-tier learner.

        Raises:
            NotImplementedError: if ``self.aggfunc`` is not ``"mean"``.
            ValueError: if ``first_tier_projections`` is empty.
        """
        if self.aggfunc == "mean":
            return self._aggregate_mean(first_tier_projections)
        raise NotImplementedError(
            f"aggregation function {self.aggfunc!r} is not supported"
        )

    def _aggregate_mean(self, first_tier_projections):
        """Return the element-wise mean of the projections, language by language.

        The set of languages and the array shapes are taken from the first
        projection. The input arrays are not modified.

        Raises:
            ValueError: if ``first_tier_projections`` is empty.
        """
        if not first_tier_projections:
            raise ValueError(
                "no first-tier projections to aggregate: "
                "at least one view-generating function must be enabled"
            )

        # TODO: default dict for one-liner?
        # Fresh zero accumulators, so the "+=" below never touches the inputs.
        aggregated = {
            lang: np.zeros(data.shape)
            for lang, data in first_tier_projections[0].items()
        }
        for lang_projections in first_tier_projections:
            for lang, projection in lang_projections.items():
                aggregated[lang] += projection

        # Computing mean
        n_views = len(first_tier_projections)
        for lang in aggregated:
            aggregated[lang] /= n_views
        return aggregated

    def get_config(self):
        """Print the languages and the configuration of every first-tier learner."""
        from pprint import pprint

        # TODO
        print("[GeneralizedFunnelling config]")
        print(f"- langs: {self.langs}")
        print("-- vgfs:")
        for vgf in self.first_tier_learners:
            pprint(vgf.get_config())
def get_params(optimc=False):
    """Return the parameter grid handed to ``MetaClassifier`` as ``meta_parameters``.

    Args:
        optimc: when falsy (the default) no grid is produced.

    Returns:
        ``None`` when ``optimc`` is falsy; otherwise a one-element list
        holding a dict with the ``"kernel"``, ``"C"`` and ``"gamma"``
        candidate values.
    """
    if optimc:
        return [
            {
                "kernel": ["rbf"],
                "C": [1e4, 1e3, 1e2, 1e1, 1, 1e-1],
                "gamma": ["auto"],
            }
        ]
    return None