Implemented function to save/load a trained gfun model

This commit is contained in:
Andrea Pedrotti 2023-02-08 14:51:56 +01:00
parent 6b75483b55
commit 31fb436cf0
5 changed files with 135 additions and 55 deletions

1
.gitignore vendored
View File

@ -4,6 +4,7 @@
!.vscode/launch.json !.vscode/launch.json
!.vscode/extensions.json !.vscode/extensions.json
!.vscode/*.code-snippets !.vscode/*.code-snippets
.vscode/
# Local History for Visual Studio Code # Local History for Visual Studio Code
.history/ .history/

View File

@ -224,7 +224,7 @@ class MultilingualDataset:
self.labels = labels self.labels = labels
def reduce_data(self, langs=["it", "en"], maxn=50): def reduce_data(self, langs=["it", "en"], maxn=50):
print(f"- Reducing data: {langs} with max {maxn} documents...") print(f"- Reducing data: {langs} with max {maxn} documents...\n")
self.set_view(languages=langs) self.set_view(languages=langs)
data = { data = {

View File

@ -33,6 +33,9 @@ class GeneralizedFunnelling:
patience, patience,
evaluate_step, evaluate_step,
transformer_name, transformer_name,
optimc,
device,
load_trained,
): ):
# Forcing VFGs ----------- # Forcing VFGs -----------
self.posteriors_vgf = posterior self.posteriors_vgf = posterior
@ -43,7 +46,7 @@ class GeneralizedFunnelling:
self.langs = langs self.langs = langs
self.embed_dir = embed_dir self.embed_dir = embed_dir
self.cached = True self.cached = True
# Transformer VGF params # Transformer VGF params ----------
self.transformer_name = transformer_name self.transformer_name = transformer_name
self.epochs = epochs self.epochs = epochs
self.lr_transformer = lr self.lr_transformer = lr
@ -52,16 +55,26 @@ class GeneralizedFunnelling:
self.early_stopping = True self.early_stopping = True
self.patience = patience self.patience = patience
self.evaluate_step = evaluate_step self.evaluate_step = evaluate_step
self.device = device
# Metaclassifier params ------------
self.optimc = optimc
# ------------------- # -------------------
self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True) self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
self.n_jobs = n_jobs self.n_jobs = n_jobs
self.first_tier_learners = [] self.first_tier_learners = []
self.metaclassifier = None self.metaclassifier = None
self.aggfunc = "mean" self.aggfunc = "mean"
self.init() self.load_trained = load_trained
self._init()
def init(self): def _init(self):
print("[Init GeneralizedFunnelling]") print("[Init GeneralizedFunnelling]")
if self.load_trained:
print("- loading trained VGFs, metaclassifier and vectorizer")
self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load()
# TODO: config like aggfunc, device, n_jobs, etc
return self
if self.posteriors_vgf: if self.posteriors_vgf:
fun = VanillaFunGen( fun = VanillaFunGen(
base_learner=get_learner(calibrate=True), base_learner=get_learner(calibrate=True),
@ -102,9 +115,10 @@ class GeneralizedFunnelling:
self.metaclassifier = MetaClassifier( self.metaclassifier = MetaClassifier(
meta_learner=get_learner(calibrate=True, kernel="rbf"), meta_learner=get_learner(calibrate=True, kernel="rbf"),
meta_parameters=get_params(), meta_parameters=get_params(self.optimc),
n_jobs=self.n_jobs, n_jobs=self.n_jobs,
) )
return self
def init_vgfs_vectorizers(self): def init_vgfs_vectorizers(self):
for vgf in self.first_tier_learners: for vgf in self.first_tier_learners:
@ -113,6 +127,14 @@ class GeneralizedFunnelling:
def fit(self, lX, lY): def fit(self, lX, lY):
print("[Fitting GeneralizedFunnelling]") print("[Fitting GeneralizedFunnelling]")
if self.load_trained:
print(f"- loaded trained model! Skipping training...")
load_only_first_tier = False # TODO
if load_only_first_tier:
projections = []
# TODO project, aggregate and fit the metaclassifier
return self
self.vectorizer.fit(lX) self.vectorizer.fit(lX)
self.init_vgfs_vectorizers() self.init_vgfs_vectorizers()
@ -173,6 +195,49 @@ class GeneralizedFunnelling:
for vgf in self.first_tier_learners: for vgf in self.first_tier_learners:
pprint(vgf.get_config()) pprint(vgf.get_config())
def save(self):
for vgf in self.first_tier_learners:
vgf.save_vgf()
# Saving metaclassifier
with open(os.path.join("models", "metaclassifier", "meta_todo.pkl"), "wb") as f:
pickle.dump(self.metaclassifier, f)
# Saving vectorizer
with open(
os.path.join("models", "vectorizer", "vectorizer_todo.pkl"), "wb"
) as f:
pickle.dump(self.vectorizer, f)
# TODO: save some config and perform sanity checks?
return
def load(self):
first_tier_learners = []
if self.posteriors_vgf:
# FIXME: hardcoded
with open(
os.path.join("models", "vgfs", "posteriors", "vanillaFunGen_todo.pkl"),
"rb",
) as vgf:
first_tier_learners.append(pickle.load(vgf))
if self.multilingual_vgf:
# FIXME: hardcoded
with open("models/vgfs/multilingual/vanillaFunGen_todo.pkl") as vgf:
first_tier_learners.append(pickle.load(vgf))
if self.wce_vgf:
# FIXME: hardcoded
with open("models/vgfs/wordclass/vanillaFunGen_todo.pkl") as vgf:
first_tier_learners.append(pickle.load(vgf))
if self.trasformer_vgf:
# FIXME: hardcoded
with open("models/vgfs/transformers/vanillaFunGen_todo.pkl") as vgf:
first_tier_learners.append(pickle.load(vgf))
with open(os.path.join("models", "metaclassifier", "meta_todo.pkl"), "rb") as f:
metaclassifier = pickle.load(f)
with open(
os.path.join("models", "vectorizer", "vectorizer_todo.pkl"), "rb"
) as f:
vectorizer = pickle.load(f)
return first_tier_learners, metaclassifier, vectorizer
def get_params(optimc=False): def get_params(optimc=False):
if not optimc: if not optimc:

View File

@ -27,11 +27,16 @@ class VanillaFunGen(ViewGen):
n_jobs=self.n_jobs, n_jobs=self.n_jobs,
) )
self.vectorizer = None self.vectorizer = None
self.load_trained = False
def fit(self, lX, lY): def fit(self, lX, lY):
if self.load_trained:
return self.load_trained()
print("- fitting VanillaFun View Generating Function") print("- fitting VanillaFun View Generating Function")
lX = self.vectorizer.transform(lX) lX = self.vectorizer.transform(lX)
self.doc_projector.fit(lX, lY) self.doc_projector.fit(lX, lY)
return self return self
def transform(self, lX): def transform(self, lX):
@ -57,3 +62,18 @@ class VanillaFunGen(ViewGen):
"first_tier_parameters": self.first_tier_parameters, "first_tier_parameters": self.first_tier_parameters,
"n_jobs": self.n_jobs, "n_jobs": self.n_jobs,
} }
def save_vgf(self):
import pickle
from os.path import join
from os import makedirs
model_id = "TODO"
vgf_name = "vanillaFunGen_todo"
_basedir = join("models", "vgfs", "posteriors")
makedirs(_basedir, exist_ok=True)
_path = join(_basedir, f"{vgf_name}.pkl")
with open(_path, "wb") as f:
pickle.dump(self, f)
return self

94
main.py
View File

@ -1,19 +1,22 @@
from os.path import expanduser import pickle
from argparse import ArgumentParser from argparse import ArgumentParser
from os.path import expanduser
from time import time
from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.amazonDataset import AmazonDataset from dataManager.amazonDataset import AmazonDataset
from dataManager.multilingualDatset import MultilingualDataset from dataManager.multilingualDatset import MultilingualDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from evaluation.evaluate import evaluate, log_eval
from gfun.generalizedFunnelling import GeneralizedFunnelling from gfun.generalizedFunnelling import GeneralizedFunnelling
from evaluation.evaluate import evaluate, log_eval """
TODO:
from time import time - a cleaner way to save the model? each VGF saved independently (together with
import pickle standardizer and feature2posteriors). What about the metaclassifier and the vectorizers?
- add documentation (Sphinx)
- zero-shot setup
# TODO: a cleaner way to save the model?
"""
def main(args): def main(args):
@ -26,7 +29,7 @@ def main(args):
dataset = ( dataset = (
MultilingualDataset(dataset_name="rcv1-2") MultilingualDataset(dataset_name="rcv1-2")
.load(RCV_DATAPATH) .load(RCV_DATAPATH)
.reduce_data(langs=["en", "it", "fr"], maxn=250) .reduce_data(langs=["en", "it", "fr"], maxn=100)
) )
if isinstance(dataset, MultilingualDataset): if isinstance(dataset, MultilingualDataset):
@ -39,7 +42,7 @@ def main(args):
tinit = time() tinit = time()
if args.load_pretrained is None: if not args.load_trained:
assert any( assert any(
[ [
args.posteriors, args.posteriors,
@ -50,48 +53,38 @@ def main(args):
] ]
), "At least one of VGF must be True" ), "At least one of VGF must be True"
gfun = GeneralizedFunnelling( gfun = GeneralizedFunnelling(
posterior=args.posteriors, posterior=args.posteriors,
multilingual=args.multilingual, multilingual=args.multilingual,
wce=args.wce, wce=args.wce,
transformer=args.transformer, transformer=args.transformer,
langs=dataset.langs(), langs=dataset.langs(),
embed_dir="~/resources/muse_embeddings", embed_dir="~/resources/muse_embeddings",
n_jobs=args.n_jobs, n_jobs=args.n_jobs,
max_length=args.max_length, max_length=args.max_length,
batch_size=args.batch_size, batch_size=args.batch_size,
epochs=args.epochs, epochs=args.epochs,
lr=args.lr, lr=args.lr,
patience=args.patience, patience=args.patience,
evaluate_step=args.evaluate_step, evaluate_step=args.evaluate_step,
transformer_name=args.transformer_name, transformer_name=args.transformer_name,
) device="cuda",
optimc=args.optimc,
load_trained=args.load_trained,
)
gfun.get_config() gfun.fit(lX, lY)
gfun.fit(lX, lY) # if not args.load_model:
# gfun.save()
# Saving Model ------------------------ preds = gfun.transform(lX)
with open("models/gfun/gfun_model.pkl", "wb") as f:
print(f"- saving model to {f.name}")
pickle.dump(gfun, f)
# -------------------------------------
preds = gfun.transform(lX) train_eval = evaluate(lY, preds)
log_eval(train_eval, phase="train")
train_eval = evaluate(lY, preds) timetr = time()
log_eval(train_eval, phase="train") print(f"- training completed in {timetr - tinit:.2f} seconds")
timetr = time()
print(f"- training completed in {timetr - tinit:.2f} seconds")
# Loading Model ------------------------
if args.load_pretrained is not None:
with open("models/gfun/gfun_model.pkl", "rb") as f:
print(f"- loading model from {f.name}")
gfun = pickle.load(f)
timetr = time()
# --------------------------------------
test_eval = evaluate(lY_te, gfun.transform(lX_te)) test_eval = evaluate(lY_te, gfun.transform(lX_te))
log_eval(test_eval, phase="test") log_eval(test_eval, phase="test")
@ -102,7 +95,7 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("--load_pretrained", type=str, default=None) parser.add_argument("-l", "--load_trained", action="store_true")
# Dataset parameters ------------------- # Dataset parameters -------------------
parser.add_argument("--domains", type=str, default="all") parser.add_argument("--domains", type=str, default="all")
parser.add_argument("--nrows", type=int, default=10000) parser.add_argument("--nrows", type=int, default=10000)
@ -114,6 +107,7 @@ if __name__ == "__main__":
parser.add_argument("-w", "--wce", action="store_true") parser.add_argument("-w", "--wce", action="store_true")
parser.add_argument("-t", "--transformer", action="store_true") parser.add_argument("-t", "--transformer", action="store_true")
parser.add_argument("--n_jobs", type=int, default=1) parser.add_argument("--n_jobs", type=int, default=1)
parser.add_argument("--optimc", action="store_true")
# transformer parameters --------------- # transformer parameters ---------------
parser.add_argument("--transformer_name", type=str, default="mbert") parser.add_argument("--transformer_name", type=str, default="mbert")
parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--batch_size", type=int, default=32)