implemented fn to save/load trained gfun

This commit is contained in:
Andrea Pedrotti 2023-02-08 14:51:56 +01:00
parent 6b75483b55
commit 31fb436cf0
5 changed files with 135 additions and 55 deletions

1
.gitignore vendored
View File

@ -4,6 +4,7 @@
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets
.vscode/
# Local History for Visual Studio Code
.history/

View File

@ -224,7 +224,7 @@ class MultilingualDataset:
self.labels = labels
def reduce_data(self, langs=["it", "en"], maxn=50):
print(f"- Reducing data: {langs} with max {maxn} documents...")
print(f"- Reducing data: {langs} with max {maxn} documents...\n")
self.set_view(languages=langs)
data = {

View File

@ -33,6 +33,9 @@ class GeneralizedFunnelling:
patience,
evaluate_step,
transformer_name,
optimc,
device,
load_trained,
):
# Forcing VFGs -----------
self.posteriors_vgf = posterior
@ -43,7 +46,7 @@ class GeneralizedFunnelling:
self.langs = langs
self.embed_dir = embed_dir
self.cached = True
# Transformer VGF params
# Transformer VGF params ----------
self.transformer_name = transformer_name
self.epochs = epochs
self.lr_transformer = lr
@ -52,16 +55,26 @@ class GeneralizedFunnelling:
self.early_stopping = True
self.patience = patience
self.evaluate_step = evaluate_step
self.device = device
# Metaclassifier params ------------
self.optimc = optimc
# -------------------
self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
self.n_jobs = n_jobs
self.first_tier_learners = []
self.metaclassifier = None
self.aggfunc = "mean"
self.init()
self.load_trained = load_trained
self._init()
def init(self):
def _init(self):
print("[Init GeneralizedFunnelling]")
if self.load_trained:
print("- loading trained VGFs, metaclassifer and vectorizer")
self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load()
# TODO: config like aggfunc, device, n_jobs, etc
return self
if self.posteriors_vgf:
fun = VanillaFunGen(
base_learner=get_learner(calibrate=True),
@ -102,9 +115,10 @@ class GeneralizedFunnelling:
self.metaclassifier = MetaClassifier(
meta_learner=get_learner(calibrate=True, kernel="rbf"),
meta_parameters=get_params(),
meta_parameters=get_params(self.optimc),
n_jobs=self.n_jobs,
)
return self
def init_vgfs_vectorizers(self):
for vgf in self.first_tier_learners:
@ -113,6 +127,14 @@ class GeneralizedFunnelling:
def fit(self, lX, lY):
print("[Fitting GeneralizedFunnelling]")
if self.load_trained:
print(f"- loaded trained model! Skipping training...")
load_only_first_tier = False # TODO
if load_only_first_tier:
projections = []
# TODO project, aggregate and fit the metaclassifier
return self
self.vectorizer.fit(lX)
self.init_vgfs_vectorizers()
@ -173,6 +195,49 @@ class GeneralizedFunnelling:
for vgf in self.first_tier_learners:
pprint(vgf.get_config())
def save(self):
    """Persist the trained first-tier VGFs, the metaclassifier and the vectorizer.

    Every view-generating function in ``self.first_tier_learners`` is asked to
    store itself through its own ``save_vgf()`` method. The metaclassifier and
    the vectorizer are then pickled to ``models/metaclassifier/meta_todo.pkl``
    and ``models/vectorizer/vectorizer_todo.pkl`` (paths relative to the
    current working directory).

    The target directories are created when missing, so saving also works on
    a fresh checkout where ``models/`` does not exist yet.

    Returns:
        None
    """
    for vgf in self.first_tier_learners:
        vgf.save_vgf()

    # (object to pickle, sub-directory of "models", file name).
    # File names are still placeholders, so each save overwrites the last one.
    targets = (
        (self.metaclassifier, "metaclassifier", "meta_todo.pkl"),
        (self.vectorizer, "vectorizer", "vectorizer_todo.pkl"),
    )
    for obj, subdir, filename in targets:
        basedir = os.path.join("models", subdir)
        # Without this, open(..., "wb") raises FileNotFoundError on first use.
        os.makedirs(basedir, exist_ok=True)
        with open(os.path.join(basedir, filename), "wb") as f:
            pickle.dump(obj, f)
    # TODO: save some config and perform sanity checks?
    return
def load(self):
    """Load the pickled first-tier VGFs, the metaclassifier and the vectorizer.

    One VGF pickle is read for every enabled view flag, in this fixed order:
    posteriors, multilingual, word-class, transformer. The metaclassifier and
    the vectorizer are then read from ``models/metaclassifier/meta_todo.pkl``
    and ``models/vectorizer/vectorizer_todo.pkl``. All paths are relative to
    the current working directory.

    Returns:
        tuple: ``(first_tier_learners, metaclassifier, vectorizer)`` where
        ``first_tier_learners`` is a list with one entry per enabled VGF.

    Raises:
        FileNotFoundError: if one of the expected pickle files is missing.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only load model
    # files that were produced by a trusted run of this project.

    # FIXME: hardcoded - every VGF pickle currently shares this placeholder name.
    vgf_filename = "vanillaFunGen_todo.pkl"
    # (enabled flag, sub-directory of models/vgfs), in loading order.
    # NOTE(review): `trasformer_vgf` is kept exactly as the original spells it;
    # confirm against the attribute assigned in __init__.
    vgf_sources = (
        (self.posteriors_vgf, "posteriors"),
        (self.multilingual_vgf, "multilingual"),
        (self.wce_vgf, "wordclass"),
        (self.trasformer_vgf, "transformers"),
    )

    first_tier_learners = []
    for enabled, subdir in vgf_sources:
        if not enabled:
            continue
        # Pickles must be opened in binary mode; text mode makes pickle.load fail.
        with open(os.path.join("models", "vgfs", subdir, vgf_filename), "rb") as f:
            first_tier_learners.append(pickle.load(f))

    with open(os.path.join("models", "metaclassifier", "meta_todo.pkl"), "rb") as f:
        metaclassifier = pickle.load(f)

    with open(os.path.join("models", "vectorizer", "vectorizer_todo.pkl"), "rb") as f:
        vectorizer = pickle.load(f)

    return first_tier_learners, metaclassifier, vectorizer
def get_params(optimc=False):
if not optimc:

View File

@ -27,11 +27,16 @@ class VanillaFunGen(ViewGen):
n_jobs=self.n_jobs,
)
self.vectorizer = None
self.load_trained = False
def fit(self, lX, lY):
    """Fit the VanillaFun view-generating function.

    When ``self.load_trained`` is truthy the instance is treated as already
    trained and is returned untouched. Otherwise ``lX`` is passed through
    ``self.vectorizer.transform`` and the result, together with ``lY``, is
    used to fit ``self.doc_projector``.

    Args:
        lX: input documents, forwarded to ``self.vectorizer.transform``.
        lY: targets, forwarded to ``self.doc_projector.fit``.

    Returns:
        self, to allow call chaining.
    """
    if self.load_trained:
        # `load_trained` is a boolean flag, not a callable: calling it raised
        # TypeError. A loaded model simply skips fitting.
        return self

    print("- fitting VanillaFun View Generating Function")
    lX = self.vectorizer.transform(lX)
    self.doc_projector.fit(lX, lY)
    return self
def transform(self, lX):
@ -57,3 +62,18 @@ class VanillaFunGen(ViewGen):
"first_tier_parameters": self.first_tier_parameters,
"n_jobs": self.n_jobs,
}
def save_vgf(self):
    """Pickle this view-generating function to disk.

    The whole instance is written to
    ``models/vgfs/posteriors/vanillaFunGen_todo.pkl`` (relative to the current
    working directory); the directory is created when it does not exist.

    Returns:
        self, to allow call chaining.
    """
    import pickle
    from os import makedirs
    from os.path import join

    # TODO: the file name is a placeholder, so every save overwrites the last.
    vgf_name = "vanillaFunGen_todo"
    _basedir = join("models", "vgfs", "posteriors")
    makedirs(_basedir, exist_ok=True)
    _path = join(_basedir, f"{vgf_name}.pkl")
    with open(_path, "wb") as f:
        pickle.dump(self, f)
    return self

94
main.py
View File

@ -1,19 +1,22 @@
from os.path import expanduser
import pickle
from argparse import ArgumentParser
from os.path import expanduser
from time import time
from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.amazonDataset import AmazonDataset
from dataManager.multilingualDatset import MultilingualDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from evaluation.evaluate import evaluate, log_eval
from gfun.generalizedFunnelling import GeneralizedFunnelling
from evaluation.evaluate import evaluate, log_eval
from time import time
import pickle
# TODO: a cleaner way to save the model?
"""
TODO:
- a cleaner way to save the model? each VGF saved independently (together with
standardizer and feature2posteriors). What about the metaclassifier and the vectorizers?
- add Sphinx documentation
- zero-shot setup
"""
def main(args):
@ -26,7 +29,7 @@ def main(args):
dataset = (
MultilingualDataset(dataset_name="rcv1-2")
.load(RCV_DATAPATH)
.reduce_data(langs=["en", "it", "fr"], maxn=250)
.reduce_data(langs=["en", "it", "fr"], maxn=100)
)
if isinstance(dataset, MultilingualDataset):
@ -39,7 +42,7 @@ def main(args):
tinit = time()
if args.load_pretrained is None:
if not args.load_trained:
assert any(
[
args.posteriors,
@ -50,48 +53,38 @@ def main(args):
]
), "At least one of VGF must be True"
gfun = GeneralizedFunnelling(
posterior=args.posteriors,
multilingual=args.multilingual,
wce=args.wce,
transformer=args.transformer,
langs=dataset.langs(),
embed_dir="~/resources/muse_embeddings",
n_jobs=args.n_jobs,
max_length=args.max_length,
batch_size=args.batch_size,
epochs=args.epochs,
lr=args.lr,
patience=args.patience,
evaluate_step=args.evaluate_step,
transformer_name=args.transformer_name,
)
gfun = GeneralizedFunnelling(
posterior=args.posteriors,
multilingual=args.multilingual,
wce=args.wce,
transformer=args.transformer,
langs=dataset.langs(),
embed_dir="~/resources/muse_embeddings",
n_jobs=args.n_jobs,
max_length=args.max_length,
batch_size=args.batch_size,
epochs=args.epochs,
lr=args.lr,
patience=args.patience,
evaluate_step=args.evaluate_step,
transformer_name=args.transformer_name,
device="cuda",
optimc=args.optimc,
load_trained=args.load_trained,
)
gfun.get_config()
gfun.fit(lX, lY)
gfun.fit(lX, lY)
# if not args.load_model:
# gfun.save()
# Saving Model ------------------------
with open("models/gfun/gfun_model.pkl", "wb") as f:
print(f"- saving model to {f.name}")
pickle.dump(gfun, f)
# -------------------------------------
preds = gfun.transform(lX)
preds = gfun.transform(lX)
train_eval = evaluate(lY, preds)
log_eval(train_eval, phase="train")
train_eval = evaluate(lY, preds)
log_eval(train_eval, phase="train")
timetr = time()
print(f"- training completed in {timetr - tinit:.2f} seconds")
# Loading Model ------------------------
if args.load_pretrained is not None:
with open("models/gfun/gfun_model.pkl", "rb") as f:
print(f"- loading model from {f.name}")
gfun = pickle.load(f)
timetr = time()
# --------------------------------------
timetr = time()
print(f"- training completed in {timetr - tinit:.2f} seconds")
test_eval = evaluate(lY_te, gfun.transform(lX_te))
log_eval(test_eval, phase="test")
@ -102,7 +95,7 @@ def main(args):
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--load_pretrained", type=str, default=None)
parser.add_argument("-l", "--load_trained", action="store_true")
# Dataset parameters -------------------
parser.add_argument("--domains", type=str, default="all")
parser.add_argument("--nrows", type=int, default=10000)
@ -114,6 +107,7 @@ if __name__ == "__main__":
parser.add_argument("-w", "--wce", action="store_true")
parser.add_argument("-t", "--transformer", action="store_true")
parser.add_argument("--n_jobs", type=int, default=1)
parser.add_argument("--optimc", action="store_true")
# transformer parameters ---------------
parser.add_argument("--transformer_name", type=str, default="mbert")
parser.add_argument("--batch_size", type=int, default=32)