gfun_multimodal/gfun/generalizedFunnelling.py

412 lines
14 KiB
Python

import os
import sys
sys.path.append(os.path.join(os.getcwd(), "gfun"))
import pickle
import numpy as np
from vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
from vgfs.learners.svms import MetaClassifier, get_learner
from vgfs.multilingualGen import MultilingualGen
from gfun.vgfs.textualTransformerGen import TextualTransformerGen
from vgfs.vanillaFun import VanillaFunGen
from vgfs.wceGen import WceGen
class GeneralizedFunnelling:
    """Two-tier "generalized funnelling" classifier.

    The first tier is a list of View Generating Functions (VGFs), enabled by
    the boolean flags ``posterior``, ``wce``, ``multilingual`` and
    ``transformer``. Each VGF's ``transform``/``fit_transform`` output is
    indexed by language (the aggregators below do ``projection[lang]``). The
    per-VGF outputs are combined according to ``aggfunc`` ("mean", "concat" or
    an ``"attn_<stacking>"`` variant) and fed to a second-tier
    ``MetaClassifier``.
    """

    def __init__(
        self,
        posterior,
        wce,
        multilingual,
        transformer,
        langs,
        num_labels,
        embed_dir,
        n_jobs,
        batch_size,
        max_length,
        lr,
        epochs,
        patience,
        evaluate_step,
        transformer_name,
        optimc,
        device,
        load_trained,
        dataset_name,
        probabilistic,
        aggfunc,
        load_meta,
    ):
        """Store the configuration and build (or load) the model via ``_init``.

        Args:
            posterior, wce, multilingual, transformer: flags enabling the
                corresponding first-tier VGFs.
            langs: languages the model handles (used by concat aggregation).
            num_labels: number of output labels.
            embed_dir: directory passed to the multilingual VGF.
            n_jobs: parallelism passed to the learners.
            batch_size, max_length, lr, epochs, patience, evaluate_step,
                transformer_name, device: transformer VGF training settings
                (``lr``, ``patience``, ``epochs`` and ``device`` are also
                reused by the attention aggregator).
            optimc: if truthy, the meta-classifier gets a hyper-parameter grid.
            load_trained: model id to load from ``models/``, or None to build
                a fresh model.
            dataset_name: forwarded to the transformer VGF.
            probabilistic: forwarded to the VGFs; must be True with "mean".
            aggfunc: aggregation function name.
            load_meta: when loading, also load the pickled meta-classifier.
        """
        # Setting VGFs -----------
        self.posteriors_vgf = posterior
        self.wce_vgf = wce
        self.multilingual_vgf = multilingual
        self.trasformer_vgf = transformer
        self.probabilistic = probabilistic
        self.num_labels = num_labels
        # ------------------------
        self.langs = langs
        self.embed_dir = embed_dir
        self.cached = True
        # Transformer VGF params ----------
        self.transformer_name = transformer_name
        self.epochs = epochs
        self.lr_transformer = lr
        self.batch_size_transformer = batch_size
        self.max_length = max_length
        self.early_stopping = True
        self.patience = patience
        self.evaluate_step = evaluate_step
        self.device = device
        # Metaclassifier params ------------
        self.optimc = optimc
        # -------------------
        self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
        self.n_jobs = n_jobs
        self.first_tier_learners = []
        self.metaclassifier = None
        self.aggfunc = aggfunc
        self.load_trained = load_trained
        # TODO: we presumably always load at least the first tier
        self.load_first_tier = True
        self.load_meta = load_meta
        self.dataset_name = dataset_name
        self._init()

    def _init(self):
        """Build or load the first-tier VGFs, the aggregator and the meta-classifier.

        If ``load_trained`` is set, the VGFs, the vectorizer and (if
        ``load_meta``) the meta-classifier are unpickled, and ``load_trained``
        becomes the model id. Otherwise fresh VGFs are created from the flags
        and a new model id is generated.

        Returns:
            self
        """
        print("[Init GeneralizedFunnelling]")
        assert not (
            self.aggfunc == "mean" and self.probabilistic is False
        ), "When using averaging aggreagation function probabilistic must be True"
        if self.load_trained is not None:
            # TODO: clean up this code here
            print(
                "- loading trained VGFs, metaclassifer and vectorizer"
                if self.load_meta
                else "- loading trained VGFs and vectorizer"
            )
            self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load(
                self.load_trained,
                load_first_tier=self.load_first_tier,
                load_meta=self.load_meta,
            )
            # Keep the loaded id so that save() works (it needs _model_id)
            # and writes back under the same file names load() read from.
            self._model_id = self.load_trained
            if self.metaclassifier is None:
                self.metaclassifier = self._build_metaclassifier()
            if "attn" in self.aggfunc:
                # NOTE(review): the attention aggregator is never persisted, so
                # with load_meta=True it stays untrained (fit() skips training).
                self.attn_aggregator = self._build_attn_aggregator()
            return self

        # The append order here must match the load order in load().
        if self.posteriors_vgf:
            fun = VanillaFunGen(
                base_learner=get_learner(calibrate=True),
                n_jobs=self.n_jobs,
            )
            self.first_tier_learners.append(fun)
        if self.multilingual_vgf:
            multilingual_vgf = MultilingualGen(
                embed_dir=self.embed_dir,
                langs=self.langs,
                n_jobs=self.n_jobs,
                cached=self.cached,
                probabilistic=self.probabilistic,
            )
            self.first_tier_learners.append(multilingual_vgf)
        if self.wce_vgf:
            wce_vgf = WceGen(n_jobs=self.n_jobs)
            self.first_tier_learners.append(wce_vgf)
        if self.trasformer_vgf:
            transformer_vgf = TextualTransformerGen(
                dataset_name=self.dataset_name,
                model_name=self.transformer_name,
                lr=self.lr_transformer,
                epochs=self.epochs,
                batch_size=self.batch_size_transformer,
                max_length=self.max_length,
                device=self.device,
                print_steps=50,
                probabilistic=self.probabilistic,
                evaluate_step=self.evaluate_step,
                verbose=True,
                patience=self.patience,
            )
            self.first_tier_learners.append(transformer_vgf)
        if "attn" in self.aggfunc:
            # Built after the VGF list is complete: the input dim can depend on it.
            self.attn_aggregator = self._build_attn_aggregator()
        self.metaclassifier = self._build_metaclassifier()
        self._model_id = get_unique_id(
            self.posteriors_vgf,
            self.multilingual_vgf,
            self.wce_vgf,
            self.trasformer_vgf,
            self.aggfunc,
        )
        print(f"- model id: {self._model_id}")
        return self

    def _build_metaclassifier(self):
        """Return a fresh calibrated RBF MetaClassifier configured from ``optimc``."""
        return MetaClassifier(
            meta_learner=get_learner(calibrate=True, kernel="rbf"),
            meta_parameters=get_params(self.optimc),
            n_jobs=self.n_jobs,
        )

    def _build_attn_aggregator(self):
        """Return an AttentionAggregator for an ``"attn_<stacking>"`` aggfunc.

        The stacking type is the part of ``aggfunc`` after the first "_". It is
        passed both to the aggregator and to ``get_attn_agg_dim`` so that the
        input dimension matches the stacking ("mean" vs "concat").
        """
        attn_stacking = self.aggfunc.split("_")[1]
        return AttentionAggregator(
            embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
            out_dim=self.num_labels,
            lr=self.lr_transformer,
            patience=self.patience,
            num_heads=1,
            device=self.device,
            epochs=self.epochs,
            attn_stacking_type=attn_stacking,
        )

    def init_vgfs_vectorizers(self):
        """Share this model's fitted TF-IDF vectorizer with the VGFs that use one."""
        for vgf in self.first_tier_learners:
            if isinstance(vgf, (VanillaFunGen, MultilingualGen, WceGen)):
                vgf.vectorizer = self.vectorizer

    def _first_tier_transform(self, lX):
        """Return the list of ``vgf.transform(lX)`` outputs, one per first-tier VGF."""
        return [vgf.transform(lX) for vgf in self.first_tier_learners]

    def fit(self, lX, lY):
        """Fit the model on ``lX`` / ``lY``.

        For a freshly built model, fit the vectorizer, all first-tier VGFs, the
        aggregator (if attention-based) and the meta-classifier. For a loaded
        model, the first tier is reused. Only the meta-classifier (and
        aggregator) is fitted when ``load_meta`` is False, and nothing is fitted
        when the meta-classifier was loaded too.

        Returns:
            self
        """
        print("[Fitting GeneralizedFunnelling]")
        if self.load_trained is not None:
            print(
                "- loaded first tier learners!"
                if self.load_meta is False
                else "- loaded trained model!"
            )
            if self.load_first_tier is True and self.load_meta is False:
                # TODO: clean up this code here
                projections = self._first_tier_transform(lX)
                agg = self.aggregate(projections, lY)
                self.metaclassifier.fit(agg, lY)
            return self

        self.vectorizer.fit(lX)
        self.init_vgfs_vectorizers()
        print("- fitting first tier learners")
        projections = [vgf.fit_transform(lX, lY) for vgf in self.first_tier_learners]
        agg = self.aggregate(projections, lY)
        self.metaclassifier.fit(agg, lY)
        return self

    def transform(self, lX):
        """Return the meta-classifier's ``predict_proba`` on the aggregated projections of ``lX``."""
        projections = self._first_tier_transform(lX)
        agg = self.aggregate(projections)
        return self.metaclassifier.predict_proba(agg)

    def fit_transform(self, lX, lY):
        """Fit on ``lX`` / ``lY`` and return ``transform(lX)``."""
        return self.fit(lX, lY).transform(lX)

    def aggregate(self, first_tier_projections, lY=None):
        """Combine the per-VGF projections according to ``self.aggfunc``.

        Args:
            first_tier_projections: list with one language-indexed projection
                per VGF.
            lY: labels; only used (to train the attention layer) by the
                attention aggregators. Pass None at prediction time.

        Raises:
            NotImplementedError: for an unknown ``aggfunc``.
        """
        if self.aggfunc == "mean":
            aggregated = self._aggregate_mean(first_tier_projections)
        elif self.aggfunc == "concat":
            aggregated = self._aggregate_concat(first_tier_projections)
        elif "attn" in self.aggfunc:
            aggregated = self._aggregate_attn(first_tier_projections, lY)
        else:
            raise NotImplementedError
        return aggregated

    def _aggregate_attn(self, first_tier_projections, lY=None):
        """Aggregate with the attention layer, first fitting it when ``lY`` is given."""
        if lY is not None:
            # Training time: the attention layer must be fitted first.
            self.attn_aggregator.fit(first_tier_projections, lY)
        return self.attn_aggregator.transform(first_tier_projections)

    def _aggregate_concat(self, first_tier_projections):
        """Horizontally stack the VGF projections for every language in ``self.langs``."""
        aggregated = {}
        for lang in self.langs:
            aggregated[lang] = np.hstack([v[lang] for v in first_tier_projections])
        return aggregated

    def _aggregate_mean(self, first_tier_projections):
        """Element-wise average of the VGF projections, per language.

        Language keys and shapes are taken from the first projection; the
        result arrays are float64 (``np.zeros`` default dtype).
        """
        aggregated = {
            lang: np.zeros(data.shape)
            for lang, data in first_tier_projections[0].items()
        }
        for lang_projections in first_tier_projections:
            for lang, projection in lang_projections.items():
                aggregated[lang] += projection
        for lang in aggregated:
            aggregated[lang] /= len(first_tier_projections)
        return aggregated

    def get_config(self):
        """Print the languages and the configuration of every first-tier VGF."""
        print("\n")
        print("-" * 50)
        print("[GeneralizedFunnelling config]")
        print(f"- model trained on langs: {self.langs}")
        print("-- View Generating Functions configurations:\n")
        for vgf in self.first_tier_learners:
            print(vgf)
        print("-" * 50)

    def save(self, save_first_tier=True, save_meta=True):
        """Pickle the vectorizer and, optionally, the VGFs and the meta-classifier.

        Files go under ``models/`` (relative to the current working directory)
        and are named after ``self._model_id``. Missing directories for the
        vectorizer and meta-classifier are created.
        """
        print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}")
        os.makedirs(os.path.join("models", "vectorizer"), exist_ok=True)
        with open(
            os.path.join("models", "vectorizer", f"vectorizer_{self._model_id}.pkl"),
            "wb",
        ) as f:
            pickle.dump(self.vectorizer, f)
        if save_first_tier:
            self.save_first_tier_learners(model_id=self._model_id)
        if save_meta:
            # Without this, open() fails on a fresh models/ tree.
            os.makedirs(os.path.join("models", "metaclassifier"), exist_ok=True)
            with open(
                os.path.join("models", "metaclassifier", f"meta_{self._model_id}.pkl"),
                "wb",
            ) as f:
                pickle.dump(self.metaclassifier, f)
        return

    def save_first_tier_learners(self, model_id):
        """Ask every first-tier VGF to save itself under ``model_id``."""
        for vgf in self.first_tier_learners:
            vgf.save_vgf(model_id=model_id)
        return self

    @staticmethod
    def _unpickle(*path_parts):
        """Unpickle and return the object at ``os.path.join(*path_parts)``.

        NOTE: pickle can execute arbitrary code; only load trusted model files.
        """
        with open(os.path.join(*path_parts), "rb") as f:
            return pickle.load(f)

    def load(self, model_id, load_first_tier=True, load_meta=True):
        """Load a pickled model from ``models/``.

        The VGFs loaded are those enabled by this instance's flags, in the same
        order ``_init`` builds them.

        Args:
            model_id: id used in the pickled file names.
            load_first_tier: currently not consulted; the enabled VGFs are
                always loaded.
            load_meta: also load the meta-classifier (otherwise None is
                returned in its place).

        Returns:
            ``(first_tier_learners, metaclassifier, vectorizer)``
        """
        print(f"- loading model id: {model_id}")
        vectorizer = self._unpickle(
            "models", "vectorizer", f"vectorizer_{model_id}.pkl"
        )
        first_tier_learners = []
        if self.posteriors_vgf:
            first_tier_learners.append(
                self._unpickle(
                    "models", "vgfs", "posterior", f"vanillaFunGen_{model_id}.pkl"
                )
            )
        if self.multilingual_vgf:
            first_tier_learners.append(
                self._unpickle(
                    "models", "vgfs", "multilingual", f"multilingualGen_{model_id}.pkl"
                )
            )
        if self.wce_vgf:
            first_tier_learners.append(
                self._unpickle(
                    "models", "vgfs", "wordclass", f"wordClassGen_{model_id}.pkl"
                )
            )
        if self.trasformer_vgf:
            first_tier_learners.append(
                self._unpickle(
                    "models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"
                )
            )
        if load_meta:
            metaclassifier = self._unpickle(
                "models", "metaclassifier", f"meta_{model_id}.pkl"
            )
        else:
            metaclassifier = None
        return first_tier_learners, metaclassifier, vectorizer

    def _load_meta(self):
        raise NotImplementedError

    def _load_posterior(self):
        raise NotImplementedError

    def _load_multilingual(self):
        raise NotImplementedError

    def _load_wce(self):
        raise NotImplementedError

    def _load_transformer(self):
        raise NotImplementedError

    def get_attn_agg_dim(self, attn_stacking_type="concat"):
        """Return the input dimension for the attention aggregator.

        The dimension is ``len(first_tier_learners) * num_labels`` for
        non-attention aggfuncs or "concat" stacking, and ``num_labels`` for
        "mean" stacking.

        Raises:
            NotImplementedError: for any other stacking type, or when
                ``probabilistic`` is falsy.
        """
        if self.probabilistic and "attn" not in self.aggfunc:
            return len(self.first_tier_learners) * self.num_labels
        elif self.probabilistic and "attn" in self.aggfunc:
            if attn_stacking_type == "concat":
                return len(self.first_tier_learners) * self.num_labels
            elif attn_stacking_type == "mean":
                return self.num_labels
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError
def get_params(optimc=False):
    """Return the meta-classifier hyper-parameter grid, or None.

    Args:
        optimc: when truthy, request a grid search over an RBF kernel with
            ``C`` in {1e4, 1e3, 1e2, 1e1, 1, 1e-1} and ``gamma="auto"``.

    Returns:
        A one-element list holding the parameter dict, or None when no
        optimisation is requested.
    """
    if optimc:
        grid = {
            "kernel": ["rbf"],
            "C": [1e4, 1e3, 1e2, 1e1, 1, 1e-1],
            "gamma": ["auto"],
        }
        return [grid]
    return None
def get_unique_id(posterior, multilingual, wce, transformer, aggfunc):
    """Build a model id of the form ``<flags>_<aggfunc>_<yymmdd>``.

    ``<flags>`` has one letter per truthy VGF flag, always in the order
    p (posterior), m (multilingual), w (wce), t (transformer). The date suffix
    is today's local date, so two ids built on the same day with the same
    configuration are identical.
    """
    from datetime import datetime

    flags = (("p", posterior), ("m", multilingual), ("w", wce), ("t", transformer))
    prefix = "".join(letter for letter, enabled in flags if enabled)
    stamp = datetime.now().strftime("%y%m%d")
    return f"{prefix}_{aggfunc}_{stamp}"