import os
import sys

# Make the local "gfun" package directory importable so that the vgfs
# modules below can all be resolved with a bare "vgfs." prefix.
sys.path.append(os.path.join(os.getcwd(), "gfun"))

import pickle

import numpy as np

from vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
from vgfs.learners.svms import MetaClassifier, get_learner
from vgfs.multilingualGen import MultilingualGen
from vgfs.textualTransformerGen import TextualTransformerGen
from vgfs.vanillaFun import VanillaFunGen
from vgfs.wceGen import WceGen


class GeneralizedFunnelling:
    def __init__(
        self,
        posterior,
        wce,
        multilingual,
        transformer,
        langs,
        num_labels,
        embed_dir,
        n_jobs,
        batch_size,
        max_length,
        lr,
        epochs,
        patience,
        evaluate_step,
        transformer_name,
        optimc,
        device,
        load_trained,
        dataset_name,
        probabilistic,
        aggfunc,
        load_meta,
    ):
        # Setting VGFs -----------
        self.posteriors_vgf = posterior
        self.wce_vgf = wce
        self.multilingual_vgf = multilingual
        self.trasformer_vgf = transformer
        self.probabilistic = probabilistic
        self.num_labels = num_labels
        # ------------------------
        self.langs = langs
        self.embed_dir = embed_dir
        self.cached = True
        # Transformer VGF params ----------
        self.transformer_name = transformer_name
        self.epochs = epochs
        self.lr_transformer = lr
        self.batch_size_transformer = batch_size
        self.max_length = max_length
        self.early_stopping = True
        self.patience = patience
        self.evaluate_step = evaluate_step
        self.device = device
        # Metaclassifier params ------------
        self.optimc = optimc
        # -------------------
        self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
        self.n_jobs = n_jobs
        self.first_tier_learners = []
        self.metaclassifier = None
        self.aggfunc = aggfunc
        self.load_trained = load_trained
        self.load_first_tier = (
            True  # TODO: we are probably always going to load at least the first tier
        )
        self.load_meta = load_meta
        self.dataset_name = dataset_name
        self._init()

    def _init(self):
        print("[Init GeneralizedFunnelling]")
        assert not (
            self.aggfunc == "mean" and self.probabilistic is False
        ), "When using the mean aggregation function, probabilistic must be True"
        if self.load_trained is not None:
            # TODO: clean up this code here
            print(
                "- loading trained VGFs, metaclassifier and vectorizer"
                if self.load_meta
                else "- loading trained VGFs and vectorizer"
            )
            self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load(
                self.load_trained,
                load_first_tier=self.load_first_tier,
                load_meta=self.load_meta,
            )
            if self.metaclassifier is None:
                self.metaclassifier = MetaClassifier(
                    meta_learner=get_learner(calibrate=True, kernel="rbf"),
                    meta_parameters=get_params(self.optimc),
                    n_jobs=self.n_jobs,
                )

            if "attn" in self.aggfunc:
                attn_stacking = self.aggfunc.split("_")[1]
                self.attn_aggregator = AttentionAggregator(
                    embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                    out_dim=self.num_labels,
                    lr=self.lr_transformer,
                    patience=self.patience,
                    num_heads=1,
                    device=self.device,
                    epochs=self.epochs,
                    attn_stacking_type=attn_stacking,
                )
            return self

        if self.posteriors_vgf:
            fun = VanillaFunGen(
                base_learner=get_learner(calibrate=True),
                n_jobs=self.n_jobs,
            )
            self.first_tier_learners.append(fun)

        if self.multilingual_vgf:
            multilingual_vgf = MultilingualGen(
                embed_dir=self.embed_dir,
                langs=self.langs,
                n_jobs=self.n_jobs,
                cached=self.cached,
                probabilistic=self.probabilistic,
            )
            self.first_tier_learners.append(multilingual_vgf)

        if self.wce_vgf:
            wce_vgf = WceGen(n_jobs=self.n_jobs)
            self.first_tier_learners.append(wce_vgf)

        if self.trasformer_vgf:
            transformer_vgf = TextualTransformerGen(
                dataset_name=self.dataset_name,
                model_name=self.transformer_name,
                lr=self.lr_transformer,
                epochs=self.epochs,
                batch_size=self.batch_size_transformer,
                max_length=self.max_length,
                device=self.device,
                print_steps=50,
                probabilistic=self.probabilistic,
                evaluate_step=self.evaluate_step,
                verbose=True,
                patience=self.patience,
            )
            self.first_tier_learners.append(transformer_vgf)

        if "attn" in self.aggfunc:
            attn_stacking = self.aggfunc.split("_")[1]
            self.attn_aggregator = AttentionAggregator(
                # keep the embedding dimension consistent with the chosen stacking type
                embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                out_dim=self.num_labels,
                lr=self.lr_transformer,
                patience=self.patience,
                num_heads=1,
                device=self.device,
                epochs=self.epochs,
                attn_stacking_type=attn_stacking,
            )

        self.metaclassifier = MetaClassifier(
            meta_learner=get_learner(calibrate=True, kernel="rbf"),
            meta_parameters=get_params(self.optimc),
            n_jobs=self.n_jobs,
        )

        self._model_id = get_unique_id(
            self.posteriors_vgf,
            self.multilingual_vgf,
            self.wce_vgf,
            self.trasformer_vgf,
            self.aggfunc,
        )
        print(f"- model id: {self._model_id}")
        return self
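
    # Quick reference for the `aggfunc` values handled by `aggregate` below,
    # derived from the dispatch code and from `get_attn_agg_dim`; the "attn"
    # variants are expected to have the form "attn_<stacking>" with stacking
    # being either "concat" or "mean":
    #
    #   "mean"          -> element-wise average of the VGF posteriors
    #                      (requires probabilistic=True, see the assert in _init)
    #   "concat"        -> horizontal stacking of the VGF outputs
    #   "attn_concat"   -> learned attention over the concatenated outputs
    #   "attn_mean"     -> learned attention over the averaged outputs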

    def init_vgfs_vectorizers(self):
        for vgf in self.first_tier_learners:
            if isinstance(vgf, (VanillaFunGen, MultilingualGen, WceGen)):
                vgf.vectorizer = self.vectorizer

    def fit(self, lX, lY):
        print("[Fitting GeneralizedFunnelling]")
        if self.load_trained is not None:
            print(
                "- loaded first tier learners!"
                if self.load_meta is False
                else "- loaded trained model!"
            )
            if self.load_first_tier is True and self.load_meta is False:
                # TODO: clean up this code here
                projections = []
                for vgf in self.first_tier_learners:
                    l_posteriors = vgf.transform(lX)
                    projections.append(l_posteriors)
                agg = self.aggregate(projections, lY)
                self.metaclassifier.fit(agg, lY)
            return self

        self.vectorizer.fit(lX)
        self.init_vgfs_vectorizers()

        projections = []
        print("- fitting first tier learners")
        for vgf in self.first_tier_learners:
            l_posteriors = vgf.fit_transform(lX, lY)
            projections.append(l_posteriors)

        agg = self.aggregate(projections, lY)
        self.metaclassifier.fit(agg, lY)

        return self

    def transform(self, lX):
        projections = []
        for vgf in self.first_tier_learners:
            l_posteriors = vgf.transform(lX)
            projections.append(l_posteriors)
        agg = self.aggregate(projections)
        l_out = self.metaclassifier.predict_proba(agg)
        return l_out

    def fit_transform(self, lX, lY):
        return self.fit(lX, lY).transform(lX)
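
    # A minimal usage sketch for the fit/transform API above. It assumes the
    # language-indexed format used throughout this module: `lX` maps a
    # language code to its list of raw documents, and `lY` maps the same
    # language codes to the corresponding label matrices. The toy values
    # below are illustrative only:
    #
    #   lX = {"en": ["a doc", "another doc"], "it": ["un documento", "un altro"]}
    #   lY = {"en": np.array([[1, 0], [0, 1]]), "it": np.array([[1, 0], [0, 1]])}
    #   gfun = GeneralizedFunnelling(...)  # see __init__ for the full signature
    #   l_probs = gfun.fit(lX, lY).transform(lX)
    #   # l_probs: lang -> posterior matrix, mirroring the language-keyed input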

    def aggregate(self, first_tier_projections, lY=None):
        if self.aggfunc == "mean":
            aggregated = self._aggregate_mean(first_tier_projections)
        elif self.aggfunc == "concat":
            aggregated = self._aggregate_concat(first_tier_projections)
        elif "attn" in self.aggfunc:
            aggregated = self._aggregate_attn(first_tier_projections, lY)
        else:
            raise NotImplementedError
        return aggregated

    def _aggregate_attn(self, first_tier_projections, lY=None):
        if lY is None:
            # at prediction time
            aggregated = self.attn_aggregator.transform(first_tier_projections)
        else:
            # at training time we must fit the attention layer
            self.attn_aggregator.fit(first_tier_projections, lY)
            aggregated = self.attn_aggregator.transform(first_tier_projections)
        return aggregated

    def _aggregate_concat(self, first_tier_projections):
        aggregated = {}
        for lang in self.langs:
            aggregated[lang] = np.hstack([v[lang] for v in first_tier_projections])
        return aggregated

    def _aggregate_mean(self, first_tier_projections):
        aggregated = {
            lang: np.zeros(data.shape)
            for lang, data in first_tier_projections[0].items()
        }
        for lang_projections in first_tier_projections:
            for lang, projection in lang_projections.items():
                aggregated[lang] += projection

        for lang, projection in aggregated.items():
            aggregated[lang] /= len(first_tier_projections)

        return aggregated
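
    # A worked toy example of the two non-attention aggregations, assuming
    # two VGFs and a single language with two documents and two labels:
    #
    #   p1 = {"en": np.array([[0.2, 0.8], [0.6, 0.4]])}
    #   p2 = {"en": np.array([[0.4, 0.6], [0.2, 0.8]])}
    #   _aggregate_mean([p1, p2])   -> {"en": [[0.3, 0.7], [0.4, 0.6]]}  # (n_docs, n_labels)
    #   _aggregate_concat([p1, p2]) -> {"en": [[0.2, 0.8, 0.4, 0.6],
    #                                          [0.6, 0.4, 0.2, 0.8]]}    # (n_docs, 2 * n_labels)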

    def get_config(self):
        print("\n")
        print("-" * 50)
        print("[GeneralizedFunnelling config]")
        print(f"- model trained on langs: {self.langs}")
        print("-- View Generating Functions configurations:\n")

        for vgf in self.first_tier_learners:
            print(vgf)
        print("-" * 50)

    def save(self, save_first_tier=True, save_meta=True):
        print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}")

        os.makedirs(os.path.join("models", "vectorizer"), exist_ok=True)
        with open(
            os.path.join("models", "vectorizer", f"vectorizer_{self._model_id}.pkl"),
            "wb",
        ) as f:
            pickle.dump(self.vectorizer, f)

        if save_first_tier:
            self.save_first_tier_learners(model_id=self._model_id)

        if save_meta:
            # make sure the target directory exists before dumping the metaclassifier
            os.makedirs(os.path.join("models", "metaclassifier"), exist_ok=True)
            with open(
                os.path.join("models", "metaclassifier", f"meta_{self._model_id}.pkl"),
                "wb",
            ) as f:
                pickle.dump(self.metaclassifier, f)
        return

    def save_first_tier_learners(self, model_id):
        for vgf in self.first_tier_learners:
            vgf.save_vgf(model_id=model_id)
        return self

    def load(self, model_id, load_first_tier=True, load_meta=True):
        print(f"- loading model id: {model_id}")
        first_tier_learners = []

        with open(
            os.path.join("models", "vectorizer", f"vectorizer_{model_id}.pkl"), "rb"
        ) as f:
            vectorizer = pickle.load(f)

        if self.posteriors_vgf:
            with open(
                os.path.join(
                    "models", "vgfs", "posterior", f"vanillaFunGen_{model_id}.pkl"
                ),
                "rb",
            ) as vgf:
                first_tier_learners.append(pickle.load(vgf))
        if self.multilingual_vgf:
            with open(
                os.path.join(
                    "models", "vgfs", "multilingual", f"multilingualGen_{model_id}.pkl"
                ),
                "rb",
            ) as vgf:
                first_tier_learners.append(pickle.load(vgf))
        if self.wce_vgf:
            with open(
                os.path.join(
                    "models", "vgfs", "wordclass", f"wordClassGen_{model_id}.pkl"
                ),
                "rb",
            ) as vgf:
                first_tier_learners.append(pickle.load(vgf))
        if self.trasformer_vgf:
            with open(
                os.path.join(
                    "models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"
                ),
                "rb",
            ) as vgf:
                first_tier_learners.append(pickle.load(vgf))

        if load_meta:
            with open(
                os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb"
            ) as f:
                metaclassifier = pickle.load(f)
        else:
            metaclassifier = None
        return first_tier_learners, metaclassifier, vectorizer

    def _load_meta(self):
        raise NotImplementedError

    def _load_posterior(self):
        raise NotImplementedError

    def _load_multilingual(self):
        raise NotImplementedError

    def _load_wce(self):
        raise NotImplementedError

    def _load_transformer(self):
        raise NotImplementedError
def get_attn_agg_dim(self, attn_stacking_type="concat"):
|
|
|
|
if self.probabilistic and "attn" not in self.aggfunc:
|
|
|
|
return len(self.first_tier_learners) * self.num_labels
|
|
|
|
elif self.probabilistic and "attn" in self.aggfunc:
|
|
|
|
if attn_stacking_type == "concat":
|
|
|
|
return len(self.first_tier_learners) * self.num_labels
|
|
|
|
elif attn_stacking_type == "mean":
|
|
|
|
return self.num_labels
|
|
|
|
else:
|
|
|
|
raise NotImplementedError
|
|
|
|
else:
|
|
|
|
raise NotImplementedError
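
    # A worked example of the dimension bookkeeping above, assuming three
    # probabilistic VGFs (e.g. posteriors + multilingual + wce) and a
    # hypothetical 28-label dataset:
    #
    #   aggfunc="attn_concat" -> embed_dim = 3 * 28 = 84
    #   aggfunc="attn_mean"   -> embed_dim = 28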


def get_params(optimc=False):
    if not optimc:
        return None
    c_range = [1e4, 1e3, 1e2, 1e1, 1, 1e-1]
    kernel = "rbf"
    return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}]
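
# With optimc=True the returned grid is handed to the MetaClassifier as its
# `meta_parameters`, presumably for model selection over the SVM
# hyperparameters; concretely:
#
#   get_params(optimc=True)
#   # -> [{"kernel": ["rbf"], "C": [10000.0, 1000.0, 100.0, 10.0, 1, 0.1],
#   #      "gamma": ["auto"]}]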


def get_unique_id(posterior, multilingual, wce, transformer, aggfunc):
    from datetime import datetime

    now = datetime.now().strftime("%y%m%d")
    model_id = ""
    model_id += "p" if posterior else ""
    model_id += "m" if multilingual else ""
    model_id += "w" if wce else ""
    model_id += "t" if transformer else ""
    model_id += f"_{aggfunc}"
    return f"{model_id}_{now}"