# gfun_multimodal/gfun/generalizedFunnelling.py

# 287 lines
# 9.4 KiB
# Python

import os
import sys
sys.path.append(os.path.join(os.getcwd(), "gfun"))
import pickle
import numpy as np
from vgfs.commons import TfidfVectorizerMultilingual
from vgfs.learners.svms import MetaClassifier, get_learner
from vgfs.multilingualGen import MultilingualGen
from gfun.vgfs.textualTransformerGen import TextualTransformerGen
from vgfs.vanillaFun import VanillaFunGen
from vgfs.wceGen import WceGen
class GeneralizedFunnelling:
    """Two-tier ensemble of view-generating functions (VGFs) and a meta-classifier.

    Each enabled first-tier learner produces, per language, a projection of the
    input; the projections are averaged across learners and the result is fed
    to a ``MetaClassifier``.

    NOTE(review): ``lX`` / ``lY`` are presumably dicts keyed by language code;
    this class only forwards them, so confirm against the VGF implementations.
    """

    def __init__(
        self,
        posterior,
        wce,
        multilingual,
        transformer,
        langs,
        embed_dir,
        n_jobs,
        batch_size,
        max_length,
        lr,
        epochs,
        patience,
        evaluate_step,
        transformer_name,
        optimc,
        device,
        load_trained,
    ):
        """Store the configuration and build (or load) the two tiers.

        Args:
            posterior: truthy to enable the ``VanillaFunGen`` VGF.
            wce: truthy to enable the ``WceGen`` VGF.
            multilingual: truthy to enable the ``MultilingualGen`` VGF.
            transformer: truthy to enable the ``TextualTransformerGen`` VGF.
            langs: languages, forwarded to ``MultilingualGen``.
            embed_dir: embedding directory, forwarded to ``MultilingualGen``.
            n_jobs: parallelism, forwarded to the VGFs and the meta-classifier.
            batch_size: forwarded to the transformer VGF.
            max_length: forwarded to the transformer VGF.
            lr: learning rate, forwarded to the transformer VGF.
            epochs: forwarded to the transformer VGF.
            patience: forwarded to the transformer VGF.
            evaluate_step: forwarded to the transformer VGF.
            transformer_name: model name, forwarded to the transformer VGF.
            optimc: truthy to give the meta-classifier the grid from
                ``get_params``.
            device: device on which the transformer VGF runs.
            load_trained: id of a previously saved model to load, or ``None``
                to build untrained components.
        """
        # Setting VFGs -----------
        self.posteriors_vgf = posterior
        self.wce_vgf = wce
        self.multilingual_vgf = multilingual
        # NOTE: the attribute name is misspelled ("trasformer"); it is kept
        # as-is because it is part of the instance's visible state.
        self.trasformer_vgf = transformer
        # ------------------------
        self.langs = langs
        self.embed_dir = embed_dir
        self.cached = True
        # Transformer VGF params ----------
        self.transformer_name = transformer_name
        self.epochs = epochs
        self.lr_transformer = lr
        self.batch_size_transformer = batch_size
        self.max_length = max_length
        self.early_stopping = True
        self.patience = patience
        self.evaluate_step = evaluate_step
        self.device = device
        # Metaclassifier params ------------
        self.optimc = optimc
        # -------------------
        self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
        self.n_jobs = n_jobs
        self.first_tier_learners = []
        self.metaclassifier = None
        self.aggfunc = "mean"
        self.load_trained = load_trained
        self._init()

    def _init(self):
        """Instantiate the enabled VGFs and the meta-classifier, or load them.

        When ``load_trained`` is not ``None`` the components are unpickled via
        ``load`` and nothing new is built.

        Returns:
            ``self``.
        """
        print("[Init GeneralizedFunnelling]")
        if self.load_trained is not None:
            print("- loading trained VGFs, metaclassifer and vectorizer")
            (
                self.first_tier_learners,
                self.metaclassifier,
                self.vectorizer,
            ) = self.load(self.load_trained)
            # Keep the id of the loaded model so that save() and any other
            # reader of _model_id does not hit an AttributeError.
            self._model_id = self.load_trained
            # TODO: config like aggfunc, device, n_jobs, etc
            return self

        if self.posteriors_vgf:
            fun = VanillaFunGen(
                base_learner=get_learner(calibrate=True),
                first_tier_parameters=None,
                n_jobs=self.n_jobs,
            )
            self.first_tier_learners.append(fun)

        if self.multilingual_vgf:
            multilingual_vgf = MultilingualGen(
                embed_dir=self.embed_dir,
                langs=self.langs,
                n_jobs=self.n_jobs,
                cached=self.cached,
                probabilistic=True,
            )
            self.first_tier_learners.append(multilingual_vgf)

        if self.wce_vgf:
            wce_vgf = WceGen(n_jobs=self.n_jobs)
            self.first_tier_learners.append(wce_vgf)

        if self.trasformer_vgf:
            transformer_vgf = TextualTransformerGen(
                model_name=self.transformer_name,
                lr=self.lr_transformer,
                epochs=self.epochs,
                batch_size=self.batch_size_transformer,
                max_length=self.max_length,
                # Honour the device requested by the caller; "cuda" (the value
                # that used to be hard-coded here) remains the fallback.
                device=self.device if self.device is not None else "cuda",
                print_steps=50,
                probabilistic=True,
                evaluate_step=self.evaluate_step,
                verbose=True,
                patience=self.patience,
            )
            self.first_tier_learners.append(transformer_vgf)

        self.metaclassifier = MetaClassifier(
            meta_learner=get_learner(calibrate=True, kernel="rbf"),
            meta_parameters=get_params(self.optimc),
            n_jobs=self.n_jobs,
        )

        self._model_id = get_unique_id(
            self.posteriors_vgf,
            self.multilingual_vgf,
            self.wce_vgf,
            self.trasformer_vgf,
        )
        print(f"- model id: {self._model_id}")
        return self

    def init_vgfs_vectorizers(self):
        """Share this model's vectorizer with the VGFs that expose one.

        Only ``VanillaFunGen``, ``MultilingualGen`` and ``WceGen`` instances
        get their ``vectorizer`` attribute assigned.
        """
        for vgf in self.first_tier_learners:
            if isinstance(vgf, (VanillaFunGen, MultilingualGen, WceGen)):
                vgf.vectorizer = self.vectorizer

    def fit(self, lX, lY):
        """Fit the vectorizer, every first-tier learner and the meta-classifier.

        If the model was loaded from disk (``load_trained`` is not ``None``)
        training is skipped entirely.

        Returns:
            ``self``.
        """
        print("[Fitting GeneralizedFunnelling]")
        if self.load_trained is not None:
            print("- loaded trained model! Skipping training...")
            # TODO: add support to load only the first tier learners while re-training the metaclassifier
            return self

        self.vectorizer.fit(lX)
        self.init_vgfs_vectorizers()

        print("- fitting first tier learners")
        projections = [vgf.fit_transform(lX, lY) for vgf in self.first_tier_learners]

        agg = self.aggregate(projections)
        self.metaclassifier.fit(agg, lY)
        return self

    def transform(self, lX):
        """Project ``lX`` through every first-tier learner and classify.

        Returns:
            Whatever ``metaclassifier.predict_proba`` returns for the
            aggregated first-tier projections.
        """
        projections = [vgf.transform(lX) for vgf in self.first_tier_learners]
        agg = self.aggregate(projections)
        return self.metaclassifier.predict_proba(agg)

    def fit_transform(self, lX, lY):
        """Fit on ``(lX, lY)`` and return ``transform(lX)``."""
        return self.fit(lX, lY).transform(lX)

    def aggregate(self, first_tier_projections):
        """Combine the first-tier projections according to ``self.aggfunc``.

        Raises:
            NotImplementedError: for any ``aggfunc`` other than ``"mean"``.
        """
        if self.aggfunc == "mean":
            return self._aggregate_mean(first_tier_projections)
        raise NotImplementedError

    def _aggregate_mean(self, first_tier_projections):
        """Average, language by language, the projections of all learners.

        Args:
            first_tier_projections: non-empty list of ``{lang: array}`` dicts.
                The languages and array shapes of the first dict define the
                accumulator, so later dicts must use the same keys and
                compatible shapes.

        Returns:
            A ``{lang: array}`` dict holding the element-wise mean.

        Raises:
            IndexError: if ``first_tier_projections`` is empty.
        """
        n_views = len(first_tier_projections)
        aggregated = {
            lang: np.zeros(data.shape)
            for lang, data in first_tier_projections[0].items()
        }
        for lang_projections in first_tier_projections:
            for lang, projection in lang_projections.items():
                aggregated[lang] += projection

        # Computing mean
        for lang in aggregated:
            aggregated[lang] /= n_views
        return aggregated

    def get_config(self):
        """Print the training languages and the repr of every first-tier learner."""
        print("\n")
        print("-" * 50)
        print("[GeneralizedFunnelling config]")
        print(f"- model trained on langs: {self.langs}")
        print("-- View Generating Functions configurations:\n")
        for vgf in self.first_tier_learners:
            print(vgf)
        print("-" * 50)

    def save(self):
        """Persist the VGFs, the meta-classifier and the vectorizer.

        Each VGF saves itself through ``save_vgf``; the meta-classifier and
        the vectorizer are pickled under ``models/metaclassifier`` and
        ``models/vectorizer`` (relative to the current working directory).
        """
        print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}")
        # TODO: save only the first tier learners? what about some model config + sanity checks before loading?
        for vgf in self.first_tier_learners:
            vgf.save_vgf(model_id=self._model_id)

        artifacts = (
            ("metaclassifier", f"meta_{self._model_id}.pkl", self.metaclassifier),
            ("vectorizer", f"vectorizer_{self._model_id}.pkl", self.vectorizer),
        )
        for subdir, filename, obj in artifacts:
            directory = os.path.join("models", subdir)
            os.makedirs(directory, exist_ok=True)
            with open(os.path.join(directory, filename), "wb") as f:
                pickle.dump(obj, f)
        return

    @staticmethod
    def _unpickle(*path_parts):
        """Unpickle and return the object stored at ``os.path.join(*path_parts)``."""
        # SECURITY: pickle.load can execute arbitrary code; only load model
        # files that come from a trusted source.
        with open(os.path.join(*path_parts), "rb") as f:
            return pickle.load(f)

    def load(self, model_id):
        """Load the components previously written by ``save`` for ``model_id``.

        Only the VGFs whose flag is enabled on this instance are loaded, in
        the order posterior, multilingual, wce, transformer.

        Returns:
            Tuple ``(first_tier_learners, metaclassifier, vectorizer)``.
        """
        print(f"- loading model id: {model_id}")
        # (enabled flag, sub-directory under models/vgfs, file-name prefix)
        vgf_files = (
            (self.posteriors_vgf, "posterior", "vanillaFunGen"),
            (self.multilingual_vgf, "multilingual", "multilingualGen"),
            (self.wce_vgf, "wordclass", "wordClassGen"),
            (self.trasformer_vgf, "transformer", "transformerGen"),
        )
        first_tier_learners = [
            self._unpickle("models", "vgfs", subdir, f"{prefix}_{model_id}.pkl")
            for enabled, subdir, prefix in vgf_files
            if enabled
        ]
        metaclassifier = self._unpickle(
            "models", "metaclassifier", f"meta_{model_id}.pkl"
        )
        vectorizer = self._unpickle(
            "models", "vectorizer", f"vectorizer_{model_id}.pkl"
        )
        return first_tier_learners, metaclassifier, vectorizer
def get_params(optimc=False):
    """Return the meta-classifier hyper-parameter grid, or ``None``.

    Args:
        optimc: when falsy no grid is produced and ``None`` is returned.

    Returns:
        ``None``, or a one-element list holding a dict with the ``kernel``,
        ``C`` and ``gamma`` candidate values.
    """
    if optimc:
        search_space = {
            "kernel": ["rbf"],
            "C": [1e4, 1e3, 1e2, 1e1, 1, 1e-1],
            "gamma": ["auto"],
        }
        return [search_space]
    return None
def get_unique_id(posterior, multilingual, wce, transformer):
    """Build a model id from the enabled VGFs and today's date.

    The id is one letter per enabled VGF (``p``, ``m``, ``w``, ``t``, in that
    order), an underscore, and the current local date formatted as ``%y%m%d``.
    """
    from datetime import datetime

    tagged_flags = (
        ("p", posterior),
        ("m", multilingual),
        ("w", wce),
        ("t", transformer),
    )
    prefix = "".join(tag for tag, enabled in tagged_flags if enabled)
    stamp = datetime.now().strftime("%y%m%d")
    return f"{prefix}_{stamp}"