import os
import sys

sys.path.append(os.path.join(os.getcwd(), "gfun"))

import pickle

import numpy as np

from vgfs.commons import TfidfVectorizerMultilingual
from vgfs.learners.svms import MetaClassifier, get_learner
from vgfs.multilingualGen import MultilingualGen
from vgfs.transformerGen import TransformerGen
from vgfs.vanillaFun import VanillaFunGen
from vgfs.wceGen import WceGen

# TODO: save and load gfun model

class GeneralizedFunnelling:
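    """Generalized funnelling: a two-tier ensemble where each enabled
    view-generating function (VGF) maps language-indexed inputs to
    per-language posterior projections; the projections are aggregated
    (currently by averaging) and a calibrated SVM meta-classifier is
    fitted on the aggregate.

    Input convention (inferred from the per-language loops below, not
    documented by the original code): lX and lY are dicts mapping
    language codes to documents and to label matrices, respectively.
    """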

    def __init__(
        self,
        posterior,
        wce,
        multilingual,
        transformer,
        langs,
        embed_dir,
        n_jobs,
        batch_size,
        max_length,
        lr,
        epochs,
        patience,
        evaluate_step,
        transformer_name,
        optimc,
        device,
        load_trained,
    ):
        # Forcing VGFs -----------
        self.posteriors_vgf = posterior
        self.wce_vgf = wce
        self.multilingual_vgf = multilingual
        self.transformer_vgf = transformer
        # ------------------------
        self.langs = langs
        self.embed_dir = embed_dir
        self.cached = True
        # Transformer VGF params ----------
        self.transformer_name = transformer_name
        self.epochs = epochs
        self.lr_transformer = lr
        self.batch_size_transformer = batch_size
        self.max_length = max_length
        self.early_stopping = True
        self.patience = patience
        self.evaluate_step = evaluate_step
        self.device = device
        # Metaclassifier params ------------
        self.optimc = optimc
        # -------------------
        self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
        self.n_jobs = n_jobs
        self.first_tier_learners = []
        self.metaclassifier = None
        self.aggfunc = "mean"
        self.load_trained = load_trained
        self._init()

    def _init(self):
        print("[Init GeneralizedFunnelling]")
        if self.load_trained:
            print("- loading trained VGFs, metaclassifier and vectorizer")
            self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load()
            # TODO: config like aggfunc, device, n_jobs, etc
            return self

        if self.posteriors_vgf:
            fun = VanillaFunGen(
                base_learner=get_learner(calibrate=True),
                first_tier_parameters=None,
                n_jobs=self.n_jobs,
            )
            self.first_tier_learners.append(fun)

        if self.multilingual_vgf:
            multilingual_vgf = MultilingualGen(
                embed_dir=self.embed_dir,
                langs=self.langs,
                n_jobs=self.n_jobs,
                cached=self.cached,
                probabilistic=True,
            )
            self.first_tier_learners.append(multilingual_vgf)

        if self.wce_vgf:
            wce_vgf = WceGen(n_jobs=self.n_jobs)
            self.first_tier_learners.append(wce_vgf)

        if self.transformer_vgf:
            transformer_vgf = TransformerGen(
                model_name=self.transformer_name,
                lr=self.lr_transformer,
                epochs=self.epochs,
                batch_size=self.batch_size_transformer,
                max_length=self.max_length,
                device=self.device,
                print_steps=50,
                probabilistic=True,
                evaluate_step=self.evaluate_step,
                verbose=True,
                patience=self.patience,
            )
            self.first_tier_learners.append(transformer_vgf)

        self.metaclassifier = MetaClassifier(
            meta_learner=get_learner(calibrate=True, kernel="rbf"),
            meta_parameters=get_params(self.optimc),
            n_jobs=self.n_jobs,
        )
        return self

    def init_vgfs_vectorizers(self):
        # Share the already-fitted multilingual TF-IDF vectorizer with the
        # VGFs that operate on the bag-of-words representation
        for vgf in self.first_tier_learners:
            if isinstance(vgf, (VanillaFunGen, MultilingualGen, WceGen)):
                vgf.vectorizer = self.vectorizer

    def fit(self, lX, lY):
        print("[Fitting GeneralizedFunnelling]")
        if self.load_trained:
            print("- loaded trained model! Skipping training...")
            load_only_first_tier = False  # TODO
            if load_only_first_tier:
                projections = []
                # TODO: project, aggregate and fit the metaclassifier
            return self

        self.vectorizer.fit(lX)
        self.init_vgfs_vectorizers()

        projections = []
        print("- fitting first tier learners")
        for vgf in self.first_tier_learners:
            l_posteriors = vgf.fit_transform(lX, lY)
            projections.append(l_posteriors)

        agg = self.aggregate(projections)
        self.metaclassifier.fit(agg, lY)

        return self

    def transform(self, lX):
        projections = []
        for vgf in self.first_tier_learners:
            l_posteriors = vgf.transform(lX)
            projections.append(l_posteriors)
        agg = self.aggregate(projections)
        l_out = self.metaclassifier.predict_proba(agg)
        return l_out

    def fit_transform(self, lX, lY):
        return self.fit(lX, lY).transform(lX)

    def aggregate(self, first_tier_projections):
        if self.aggfunc == "mean":
            aggregated = self._aggregate_mean(first_tier_projections)
        else:
            raise NotImplementedError
        return aggregated

    def _aggregate_mean(self, first_tier_projections):
        # TODO: defaultdict for a one-liner?
        aggregated = {
            lang: np.zeros(data.shape)
            for lang, data in first_tier_projections[0].items()
        }
        for lang_projections in first_tier_projections:
            for lang, projection in lang_projections.items():
                aggregated[lang] += projection

        # Computing the mean over the first-tier projections
        for lang, projection in aggregated.items():
            aggregated[lang] /= len(first_tier_projections)

        return aggregated
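
    # A possible one-liner for the TODO above (untested sketch, not the
    # original implementation): average each language's projections directly
    # with np.mean instead of accumulating into a zero-initialized dict.
    #
    #     def _aggregate_mean(self, first_tier_projections):
    #         return {
    #             lang: np.mean([p[lang] for p in first_tier_projections], axis=0)
    #             for lang in first_tier_projections[0]
    #         }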

    def get_config(self):
        from pprint import pprint

        # TODO
        print("[GeneralizedFunnelling config]")
        print(f"- langs: {self.langs}")
        print("-- vgfs:")
        for vgf in self.first_tier_learners:
            pprint(vgf.get_config())

    def save(self):
        for vgf in self.first_tier_learners:
            vgf.save_vgf()
        # Saving metaclassifier
        with open(os.path.join("models", "metaclassifier", "meta_todo.pkl"), "wb") as f:
            pickle.dump(self.metaclassifier, f)
        # Saving vectorizer
        with open(
            os.path.join("models", "vectorizer", "vectorizer_todo.pkl"), "wb"
        ) as f:
            pickle.dump(self.vectorizer, f)
        # TODO: save some config and perform sanity checks?
        return

    def load(self):
        first_tier_learners = []
        if self.posteriors_vgf:
            # FIXME: hardcoded
            with open(
                os.path.join("models", "vgfs", "posteriors", "vanillaFunGen_todo.pkl"),
                "rb",
            ) as vgf:
                first_tier_learners.append(pickle.load(vgf))
        if self.multilingual_vgf:
            # FIXME: hardcoded
            with open("models/vgfs/multilingual/vanillaFunGen_todo.pkl", "rb") as vgf:
                first_tier_learners.append(pickle.load(vgf))
        if self.wce_vgf:
            # FIXME: hardcoded
            with open("models/vgfs/wordclass/vanillaFunGen_todo.pkl", "rb") as vgf:
                first_tier_learners.append(pickle.load(vgf))
        if self.transformer_vgf:
            # FIXME: hardcoded
            with open("models/vgfs/transformers/vanillaFunGen_todo.pkl", "rb") as vgf:
                first_tier_learners.append(pickle.load(vgf))
        with open(os.path.join("models", "metaclassifier", "meta_todo.pkl"), "rb") as f:
            metaclassifier = pickle.load(f)
        with open(
            os.path.join("models", "vectorizer", "vectorizer_todo.pkl"), "rb"
        ) as f:
            vectorizer = pickle.load(f)
        return first_tier_learners, metaclassifier, vectorizer


def get_params(optimc=False):
    # Grid-search space for the meta-classifier; None disables optimization
    if not optimc:
        return None
    c_range = [1e4, 1e3, 1e2, 1e1, 1, 1e-1]
    kernel = "rbf"
    return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}]
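

# Hypothetical usage sketch (illustrative only; every parameter value below is
# an assumption, not a project default, and the dataset-loading step is left
# out). It shows the intended call sequence: construct, fit, then transform.
#
#     gfun = GeneralizedFunnelling(
#         posterior=True,
#         wce=False,
#         multilingual=False,
#         transformer=False,
#         langs=["en", "it"],
#         embed_dir="embeddings/",
#         n_jobs=1,
#         batch_size=32,
#         max_length=128,
#         lr=1e-5,
#         epochs=10,
#         patience=3,
#         evaluate_step=25,
#         transformer_name="bert-base-multilingual-cased",
#         optimc=False,
#         device="cpu",
#         load_trained=False,
#     )
#     lX, lY = ...  # dicts: {lang: documents}, {lang: label matrix}
#     gfun.fit(lX, lY)
#     l_posteriors = gfun.transform(lX)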