# gfun_multimodal/gfun/generalizedFunnelling.py
import os
import sys
sys.path.append(os.path.join(os.getcwd(), "gfun"))
import pickle
import numpy as np
from vgfs.commons import TfidfVectorizerMultilingual
from vgfs.learners.svms import MetaClassifier, get_learner
from vgfs.multilingualGen import MultilingualGen
from vgfs.transformerGen import TransformerGen
from vgfs.vanillaFun import VanillaFunGen
from vgfs.wceGen import WceGen
# TODO: save and load gfun model
class GeneralizedFunnelling:
    def __init__(
        self,
        posterior,
        wce,
        multilingual,
        transformer,
        langs,
        embed_dir,
        n_jobs,
        batch_size,
        max_length,
        lr,
        epochs,
        patience,
        evaluate_step,
        transformer_name,
        optimc,
        device,
        load_trained,
    ):
        """Configure a Generalized Funnelling ensemble.

        Args:
            posterior: if truthy, enable the VanillaFunGen (posterior-probability) VGF.
            wce: if truthy, enable the word-class-embedding (WceGen) VGF.
            multilingual: if truthy, enable the multilingual-embedding (MultilingualGen) VGF.
            transformer: if truthy, enable the transformer-based (TransformerGen) VGF.
            langs: languages handled by the pipeline (passed to MultilingualGen).
            embed_dir: directory holding the pre-trained multilingual embeddings.
            n_jobs: parallelism degree forwarded to every VGF and the meta-classifier.
            batch_size: transformer training batch size.
            max_length: transformer maximum sequence length.
            lr: transformer learning rate.
            epochs: transformer training epochs.
            patience: early-stopping patience for the transformer VGF.
            evaluate_step: evaluation frequency (in steps) for the transformer VGF.
            transformer_name: HuggingFace-style model identifier for the transformer VGF.
            optimc: if True, grid-search the meta-classifier hyper-parameters (see get_params).
            device: compute device for the transformer VGF.
            load_trained: if True, _init() loads pickled components instead of building them.
        """
        # Flags selecting which View-Generating Functions (VGFs) to build ---
        self.posteriors_vgf = posterior
        self.wce_vgf = wce
        self.multilingual_vgf = multilingual
        # NOTE(review): attribute name carries the original "trasformer" typo;
        # load() reads it too, so renaming needs a coordinated change.
        self.trasformer_vgf = transformer
        # ------------------------
        self.langs = langs
        self.embed_dir = embed_dir
        # Always reuse cached embeddings in MultilingualGen.
        self.cached = True
        # Transformer VGF params ----------
        self.transformer_name = transformer_name
        self.epochs = epochs
        self.lr_transformer = lr
        self.batch_size_transformer = batch_size
        self.max_length = max_length
        # Early stopping is always on; only the patience is configurable.
        self.early_stopping = True
        self.patience = patience
        self.evaluate_step = evaluate_step
        self.device = device
        # Metaclassifier params ------------
        self.optimc = optimc
        # -------------------
        # Shared tf-idf vectorizer, fitted once in fit() and injected into the
        # tf-idf-based VGFs by init_vgfs_vectorizers().
        self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
        self.n_jobs = n_jobs
        self.first_tier_learners = []
        self.metaclassifier = None
        # Strategy used by aggregate(); only "mean" is implemented.
        self.aggfunc = "mean"
        self.load_trained = load_trained
        self._init()
def _init(self):
print("[Init GeneralizedFunnelling]")
if self.load_trained:
print("- loading trained VGFs, metaclassifer and vectorizer")
self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load()
# TODO: config like aggfunc, device, n_jobs, etc
return self
if self.posteriors_vgf:
fun = VanillaFunGen(
base_learner=get_learner(calibrate=True),
first_tier_parameters=None,
n_jobs=self.n_jobs,
)
self.first_tier_learners.append(fun)
if self.multilingual_vgf:
multilingual_vgf = MultilingualGen(
embed_dir=self.embed_dir,
langs=self.langs,
n_jobs=self.n_jobs,
cached=self.cached,
probabilistic=True,
)
self.first_tier_learners.append(multilingual_vgf)
if self.wce_vgf:
wce_vgf = WceGen(n_jobs=self.n_jobs)
self.first_tier_learners.append(wce_vgf)
if self.trasformer_vgf:
transformer_vgf = TransformerGen(
model_name=self.transformer_name,
lr=self.lr_transformer,
epochs=self.epochs,
batch_size=self.batch_size_transformer,
max_length=self.max_length,
device="cuda",
print_steps=50,
probabilistic=True,
evaluate_step=self.evaluate_step,
verbose=True,
patience=self.patience,
)
self.first_tier_learners.append(transformer_vgf)
self.metaclassifier = MetaClassifier(
meta_learner=get_learner(calibrate=True, kernel="rbf"),
meta_parameters=get_params(self.optimc),
n_jobs=self.n_jobs,
)
return self
def init_vgfs_vectorizers(self):
for vgf in self.first_tier_learners:
if isinstance(vgf, (VanillaFunGen, MultilingualGen, WceGen)):
vgf.vectorizer = self.vectorizer
def fit(self, lX, lY):
print("[Fitting GeneralizedFunnelling]")
if self.load_trained:
print(f"- loaded trained model! Skipping training...")
load_only_first_tier = False # TODO
if load_only_first_tier:
projections = []
# TODO project, aggregate and fit the metaclassifier
return self
self.vectorizer.fit(lX)
self.init_vgfs_vectorizers()
projections = []
print("- fitting first tier learners")
for vgf in self.first_tier_learners:
l_posteriors = vgf.fit_transform(lX, lY)
projections.append(l_posteriors)
agg = self.aggregate(projections)
self.metaclassifier.fit(agg, lY)
return self
def transform(self, lX):
projections = []
for vgf in self.first_tier_learners:
l_posteriors = vgf.transform(lX)
projections.append(l_posteriors)
agg = self.aggregate(projections)
l_out = self.metaclassifier.predict_proba(agg)
return l_out
def fit_transform(self, lX, lY):
return self.fit(lX, lY).transform(lX)
def aggregate(self, first_tier_projections):
if self.aggfunc == "mean":
aggregated = self._aggregate_mean(first_tier_projections)
else:
raise NotImplementedError
return aggregated
def _aggregate_mean(self, first_tier_projections):
# TODO: deafult dict for one-liner?
aggregated = {
lang: np.zeros(data.shape)
for lang, data in first_tier_projections[0].items()
}
for lang_projections in first_tier_projections:
for lang, projection in lang_projections.items():
aggregated[lang] += projection
# Computing mean
for lang, projection in aggregated.items():
aggregated[lang] /= len(first_tier_projections)
return aggregated
def get_config(self):
from pprint import pprint
# TODO
print("[GeneralizedFunnelling config]")
print(f"- langs: {self.langs}")
print("-- vgfs:")
for vgf in self.first_tier_learners:
pprint(vgf.get_config())
def save(self):
for vgf in self.first_tier_learners:
vgf.save_vgf()
# Saving metaclassifier
with open(os.path.join("models", "metaclassifier", "meta_todo.pkl"), "wb") as f:
pickle.dump(self.metaclassifier, f)
# Saving vectorizer
with open(
os.path.join("models", "vectorizer", "vectorizer_todo.pkl"), "wb"
) as f:
pickle.dump(self.vectorizer, f)
# TODO: save some config and perform sanity checks?
return
def load(self):
first_tier_learners = []
if self.posteriors_vgf:
# FIXME: hardcoded
with open(
os.path.join("models", "vgfs", "posteriors", "vanillaFunGen_todo.pkl"),
"rb",
) as vgf:
first_tier_learners.append(pickle.load(vgf))
if self.multilingual_vgf:
# FIXME: hardcoded
with open("models/vgfs/multilingual/vanillaFunGen_todo.pkl") as vgf:
first_tier_learners.append(pickle.load(vgf))
if self.wce_vgf:
# FIXME: hardcoded
with open("models/vgfs/wordclass/vanillaFunGen_todo.pkl") as vgf:
first_tier_learners.append(pickle.load(vgf))
if self.trasformer_vgf:
# FIXME: hardcoded
with open("models/vgfs/transformers/vanillaFunGen_todo.pkl") as vgf:
first_tier_learners.append(pickle.load(vgf))
with open(os.path.join("models", "metaclassifier", "meta_todo.pkl"), "rb") as f:
metaclassifier = pickle.load(f)
with open(
os.path.join("models", "vectorizer", "vectorizer_todo.pkl"), "rb"
) as f:
vectorizer = pickle.load(f)
return first_tier_learners, metaclassifier, vectorizer
def get_params(optimc=False):
    """Return the meta-classifier grid-search space, or None when tuning is disabled."""
    if not optimc:
        return None
    return [
        {
            "kernel": ["rbf"],
            "C": [1e4, 1e3, 1e2, 1e1, 1, 1e-1],
            "gamma": ["auto"],
        }
    ]