import os
import sys

sys.path.append(os.path.join(os.getcwd(), "gfun"))

import pickle

import numpy as np

from vgfs.commons import TfidfVectorizerMultilingual
from vgfs.learners.svms import MetaClassifier, get_learner
from vgfs.multilingualGen import MultilingualGen
from gfun.vgfs.textualTransformerGen import TextualTransformerGen
from vgfs.vanillaFun import VanillaFunGen
from vgfs.wceGen import WceGen

class GeneralizedFunnelling:
    """Two-tier (funnelling) ensemble for multilingual classification: a set of
    first-tier View Generating Functions (VGFs) projects the documents of each
    language into a common space, and a meta-classifier is trained on the
    aggregated (averaged) projections."""

    def __init__(
        self,
        posterior,
        wce,
        multilingual,
        transformer,
        langs,
        embed_dir,
        n_jobs,
        batch_size,
        max_length,
        lr,
        epochs,
        patience,
        evaluate_step,
        transformer_name,
        optimc,
        device,
        load_trained,
    ):
        # Setting VGFs -----------
        self.posteriors_vgf = posterior
        self.wce_vgf = wce
        self.multilingual_vgf = multilingual
        self.transformer_vgf = transformer
        # ------------------------
        self.langs = langs
        self.embed_dir = embed_dir
        self.cached = True
        # Transformer VGF params ----------
        self.transformer_name = transformer_name
        self.epochs = epochs
        self.lr_transformer = lr
        self.batch_size_transformer = batch_size
        self.max_length = max_length
        self.early_stopping = True
        self.patience = patience
        self.evaluate_step = evaluate_step
        self.device = device
        # Metaclassifier params ------------
        self.optimc = optimc
        # -------------------
        self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
        self.n_jobs = n_jobs
        self.first_tier_learners = []
        self.metaclassifier = None
        self.aggfunc = "mean"
        self.load_trained = load_trained
        self._init()

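    # Constructor flags and the first-tier VGFs they activate in _init()
    # (the parenthetical descriptions are inferred from the VGF names and
    # the arguments passed to them):
    #   posterior    -> VanillaFunGen         (posterior probabilities of calibrated SVMs)
    #   multilingual -> MultilingualGen       (word embeddings loaded from embed_dir)
    #   wce          -> WceGen                (word-class embeddings)
    #   transformer  -> TextualTransformerGen (fine-tuned transformer encoder)
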
    def _init(self):
        print("[Init GeneralizedFunnelling]")
        if self.load_trained is not None:
            print("- loading trained VGFs, metaclassifier and vectorizer")
            self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load(
                self.load_trained
            )
            # TODO: also restore config such as aggfunc, device, n_jobs, etc.
            return self

        if self.posteriors_vgf:
            fun = VanillaFunGen(
                base_learner=get_learner(calibrate=True),
                first_tier_parameters=None,
                n_jobs=self.n_jobs,
            )
            self.first_tier_learners.append(fun)

        if self.multilingual_vgf:
            multilingual_vgf = MultilingualGen(
                embed_dir=self.embed_dir,
                langs=self.langs,
                n_jobs=self.n_jobs,
                cached=self.cached,
                probabilistic=True,
            )
            self.first_tier_learners.append(multilingual_vgf)

        if self.wce_vgf:
            wce_vgf = WceGen(n_jobs=self.n_jobs)
            self.first_tier_learners.append(wce_vgf)

        if self.transformer_vgf:
            transformer_vgf = TextualTransformerGen(
                model_name=self.transformer_name,
                lr=self.lr_transformer,
                epochs=self.epochs,
                batch_size=self.batch_size_transformer,
                max_length=self.max_length,
                device=self.device,
                print_steps=50,
                probabilistic=True,
                evaluate_step=self.evaluate_step,
                verbose=True,
                patience=self.patience,
            )
            self.first_tier_learners.append(transformer_vgf)

        self.metaclassifier = MetaClassifier(
            meta_learner=get_learner(calibrate=True, kernel="rbf"),
            meta_parameters=get_params(self.optimc),
            n_jobs=self.n_jobs,
        )

        self._model_id = get_unique_id(
            self.posteriors_vgf,
            self.multilingual_vgf,
            self.wce_vgf,
            self.transformer_vgf,
        )
        print(f"- model id: {self._model_id}")
        return self

    def init_vgfs_vectorizers(self):
        # Share the fitted multilingual TF-IDF vectorizer with the VGFs that require it
        for vgf in self.first_tier_learners:
            if isinstance(vgf, (VanillaFunGen, MultilingualGen, WceGen)):
                vgf.vectorizer = self.vectorizer

    def fit(self, lX, lY):
        print("[Fitting GeneralizedFunnelling]")
        if self.load_trained is not None:
            print("- loaded trained model! Skipping training...")
            # TODO: add support to load only the first tier learners while re-training the metaclassifier
            load_only_first_tier = False
            if load_only_first_tier:
                raise NotImplementedError
            return self

        self.vectorizer.fit(lX)
        self.init_vgfs_vectorizers()

        projections = []
        print("- fitting first tier learners")
        for vgf in self.first_tier_learners:
            l_posteriors = vgf.fit_transform(lX, lY)
            projections.append(l_posteriors)

        agg = self.aggregate(projections)
        self.metaclassifier.fit(agg, lY)

        return self

    def transform(self, lX):
        projections = []
        for vgf in self.first_tier_learners:
            l_posteriors = vgf.transform(lX)
            projections.append(l_posteriors)
        agg = self.aggregate(projections)
        l_out = self.metaclassifier.predict_proba(agg)
        return l_out

    def fit_transform(self, lX, lY):
        return self.fit(lX, lY).transform(lX)

    def aggregate(self, first_tier_projections):
        if self.aggfunc == "mean":
            aggregated = self._aggregate_mean(first_tier_projections)
        else:
            raise NotImplementedError
        return aggregated

    def _aggregate_mean(self, first_tier_projections):
        aggregated = {
            lang: np.zeros(data.shape)
            for lang, data in first_tier_projections[0].items()
        }
        for lang_projections in first_tier_projections:
            for lang, projection in lang_projections.items():
                aggregated[lang] += projection

        # Computing mean
        for lang, projection in aggregated.items():
            aggregated[lang] /= len(first_tier_projections)

        return aggregated

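    # Illustrative example of the mean aggregation (comment only, not executed):
    # given two VGF projections for language "en",
    #   [{"en": np.array([[0.2, 0.8]])}, {"en": np.array([[0.4, 0.6]])}]
    # _aggregate_mean returns {"en": np.array([[0.3, 0.7]])}, i.e. the
    # element-wise average of the per-language projection matrices.
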
    def get_config(self):
        print("\n")
        print("-" * 50)
        print("[GeneralizedFunnelling config]")
        print(f"- model trained on langs: {self.langs}")
        print("-- View Generating Functions configurations:\n")

        for vgf in self.first_tier_learners:
            print(vgf)
        print("-" * 50)

    def save(self):
        print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}")
        # TODO: save only the first tier learners? what about some model config + sanity checks before loading?
        for vgf in self.first_tier_learners:
            vgf.save_vgf(model_id=self._model_id)
        os.makedirs(os.path.join("models", "metaclassifier"), exist_ok=True)
        with open(
            os.path.join("models", "metaclassifier", f"meta_{self._model_id}.pkl"), "wb"
        ) as f:
            pickle.dump(self.metaclassifier, f)
        os.makedirs(os.path.join("models", "vectorizer"), exist_ok=True)
        with open(
            os.path.join("models", "vectorizer", f"vectorizer_{self._model_id}.pkl"),
            "wb",
        ) as f:
            pickle.dump(self.vectorizer, f)
        return

    def load(self, model_id):
        print(f"- loading model id: {model_id}")
        first_tier_learners = []
        if self.posteriors_vgf:
            with open(
                os.path.join(
                    "models", "vgfs", "posterior", f"vanillaFunGen_{model_id}.pkl"
                ),
                "rb",
            ) as vgf:
                first_tier_learners.append(pickle.load(vgf))
        if self.multilingual_vgf:
            with open(
                os.path.join(
                    "models", "vgfs", "multilingual", f"multilingualGen_{model_id}.pkl"
                ),
                "rb",
            ) as vgf:
                first_tier_learners.append(pickle.load(vgf))
        if self.wce_vgf:
            with open(
                os.path.join(
                    "models", "vgfs", "wordclass", f"wordClassGen_{model_id}.pkl"
                ),
                "rb",
            ) as vgf:
                first_tier_learners.append(pickle.load(vgf))
        if self.transformer_vgf:
            with open(
                os.path.join(
                    "models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"
                ),
                "rb",
            ) as vgf:
                first_tier_learners.append(pickle.load(vgf))
        with open(
            os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb"
        ) as f:
            metaclassifier = pickle.load(f)
        with open(
            os.path.join("models", "vectorizer", f"vectorizer_{model_id}.pkl"), "rb"
        ) as f:
            vectorizer = pickle.load(f)
        return first_tier_learners, metaclassifier, vectorizer


def get_params(optimc=False):
    if not optimc:
        return None
    c_range = [1e4, 1e3, 1e2, 1e1, 1, 1e-1]
    kernel = "rbf"
    return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}]


def get_unique_id(posterior, multilingual, wce, transformer):
    from datetime import datetime

    now = datetime.now().strftime("%y%m%d")
    model_id = ""
    model_id += "p" if posterior else ""
    model_id += "m" if multilingual else ""
    model_id += "w" if wce else ""
    model_id += "t" if transformer else ""
    return f"{model_id}_{now}"
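

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module). It assumes
# that lX maps each language code to a list of raw documents and lY maps the
# same codes to (n_docs, n_classes) label matrices, which is the format the
# multilingual vectorizer and the VGFs appear to expect. The embed_dir path
# and transformer_name below are hypothetical placeholders.
# ---------------------------------------------------------------------------
# gfun = GeneralizedFunnelling(
#     posterior=True,
#     wce=False,
#     multilingual=False,
#     transformer=False,
#     langs=["en", "it"],
#     embed_dir="embeddings",   # only used by the multilingual VGF
#     n_jobs=1,
#     batch_size=32,
#     max_length=128,
#     lr=1e-5,
#     epochs=5,
#     patience=3,
#     evaluate_step=10,
#     transformer_name="bert",  # hypothetical model name
#     optimc=False,
#     device="cpu",
#     load_trained=None,
# )
# gfun.fit(lX, lY)
# l_posteriors = gfun.transform(lX_test)  # per-language posterior probabilities
# gfun.save()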