implemented fn to save/load trained gfun
This commit is contained in: parent 6b75483b55 · commit 31fb436cf0
.gitignore
@@ -4,6 +4,7 @@
 !.vscode/launch.json
 !.vscode/extensions.json
 !.vscode/*.code-snippets
+.vscode/

 # Local History for Visual Studio Code
 .history/
dataManager/multilingualDatset.py
@@ -224,7 +224,7 @@ class MultilingualDataset:
         self.labels = labels

     def reduce_data(self, langs=["it", "en"], maxn=50):
-        print(f"- Reducing data: {langs} with max {maxn} documents...")
+        print(f"- Reducing data: {langs} with max {maxn} documents...\n")
         self.set_view(languages=langs)

         data = {
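Only the progress message changes here (a trailing newline is added). For orientation, a minimal sketch of what a per-language truncation like reduce_data plausibly looks like — the body below is illustrative, and the multiling_dataset attribute in particular is hypothetical; only the print, the set_view call, and the opening of the data dict are visible in the hunk:

    def reduce_data(self, langs=["it", "en"], maxn=50):
        print(f"- Reducing data: {langs} with max {maxn} documents...\n")
        self.set_view(languages=langs)
        # illustrative: keep at most `maxn` documents per requested language
        data = {
            lang: docs[:maxn]
            for lang, docs in self.multiling_dataset.items()  # hypothetical attribute
            if lang in langs
        }
        self.multiling_dataset = data
        return self  # main.py chains .load(...).reduce_data(...), so returning self matters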
gfun/generalizedFunnelling.py
@@ -33,6 +33,9 @@ class GeneralizedFunnelling:
         patience,
         evaluate_step,
         transformer_name,
+        optimc,
+        device,
+        load_trained,
     ):
         # Forcing VFGs -----------
         self.posteriors_vgf = posterior
@@ -43,7 +46,7 @@ class GeneralizedFunnelling:
         self.langs = langs
         self.embed_dir = embed_dir
         self.cached = True
-        # Transformer VGF params
+        # Transformer VGF params ----------
         self.transformer_name = transformer_name
         self.epochs = epochs
         self.lr_transformer = lr
@@ -52,16 +55,26 @@ class GeneralizedFunnelling:
         self.early_stopping = True
         self.patience = patience
         self.evaluate_step = evaluate_step
+        self.device = device
+        # Metaclassifier params ------------
+        self.optimc = optimc
         # -------------------
         self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
         self.n_jobs = n_jobs
         self.first_tier_learners = []
         self.metaclassifier = None
         self.aggfunc = "mean"
-        self.init()
+        self.load_trained = load_trained
+        self._init()

-    def init(self):
+    def _init(self):
+        print("[Init GeneralizedFunnelling]")
+        if self.load_trained:
+            print("- loading trained VGFs, metaclassifer and vectorizer")
+            self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load()
+            # TODO: config like aggfunc, device, n_jobs, etc
+            return self

         if self.posteriors_vgf:
             fun = VanillaFunGen(
                 base_learner=get_learner(calibrate=True),
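With the new load_trained flag, _init() short-circuits into load() instead of building fresh VGFs. A usage sketch — the keyword names are exactly those in this diff, while the concrete values and the lX/lY per-language dicts are placeholders:

    from gfun.generalizedFunnelling import GeneralizedFunnelling

    gfun = GeneralizedFunnelling(
        posterior=True, multilingual=False, wce=False, transformer=False,
        langs=["en", "it"], embed_dir="~/resources/muse_embeddings",
        n_jobs=1, max_length=512, batch_size=32, epochs=10, lr=1e-5,
        patience=5, evaluate_step=10, transformer_name="mbert",
        device="cuda", optimc=False,
        load_trained=True,  # _init() rehydrates VGFs, metaclassifier and vectorizer
    )
    gfun.fit(lX, lY)  # returns immediately: training is skipped for a loaded model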
@@ -102,9 +115,10 @@ class GeneralizedFunnelling:

         self.metaclassifier = MetaClassifier(
             meta_learner=get_learner(calibrate=True, kernel="rbf"),
-            meta_parameters=get_params(),
+            meta_parameters=get_params(self.optimc),
             n_jobs=self.n_jobs,
         )
+        return self

     def init_vgfs_vectorizers(self):
         for vgf in self.first_tier_learners:
@@ -113,6 +127,14 @@ class GeneralizedFunnelling:

     def fit(self, lX, lY):
+        print("[Fitting GeneralizedFunnelling]")
+        if self.load_trained:
+            print(f"- loaded trained model! Skipping training...")
+            load_only_first_tier = False  # TODO
+            if load_only_first_tier:
+                projections = []
+                # TODO project, aggregate and fit the metaclassifier
+            return self

         self.vectorizer.fit(lX)
         self.init_vgfs_vectorizers()
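The load_only_first_tier branch is still a stub. A sketch of what its TODO describes — project with the loaded first-tier VGFs, aggregate (the class defaults to aggfunc = "mean"), and refit only the metaclassifier. The transform/fit calls mirror how these objects are used elsewhere in the class; the aggregation code itself is an assumption:

    import numpy as np

    # hypothetical completion of the stub: reuse trained VGFs, retrain the meta level
    projections = [vgf.transform(lX) for vgf in self.first_tier_learners]
    # aggregate the first-tier views language-wise (aggfunc == "mean")
    aggregated = {
        lang: np.mean([proj[lang] for proj in projections], axis=0)
        for lang in lX.keys()
    }
    self.metaclassifier.fit(aggregated, lY)
    return self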
@@ -173,6 +195,49 @@ class GeneralizedFunnelling:
         for vgf in self.first_tier_learners:
             pprint(vgf.get_config())

+    def save(self):
+        for vgf in self.first_tier_learners:
+            vgf.save_vgf()
+        # Saving metaclassifier
+        with open(os.path.join("models", "metaclassifier", "meta_todo.pkl"), "wb") as f:
+            pickle.dump(self.metaclassifier, f)
+        # Saving vectorizer
+        with open(
+            os.path.join("models", "vectorizer", "vectorizer_todo.pkl"), "wb"
+        ) as f:
+            pickle.dump(self.vectorizer, f)
+        # TODO: save some config and perform sanity checks?
+        return
+
+    def load(self):
+        first_tier_learners = []
+        if self.posteriors_vgf:
+            # FIXME: hardcoded
+            with open(
+                os.path.join("models", "vgfs", "posteriors", "vanillaFunGen_todo.pkl"),
+                "rb",
+            ) as vgf:
+                first_tier_learners.append(pickle.load(vgf))
+        if self.multilingual_vgf:
+            # FIXME: hardcoded
+            with open("models/vgfs/multilingual/vanillaFunGen_todo.pkl") as vgf:
+                first_tier_learners.append(pickle.load(vgf))
+        if self.wce_vgf:
+            # FIXME: hardcoded
+            with open("models/vgfs/wordclass/vanillaFunGen_todo.pkl") as vgf:
+                first_tier_learners.append(pickle.load(vgf))
+        if self.trasformer_vgf:
+            # FIXME: hardcoded
+            with open("models/vgfs/transformers/vanillaFunGen_todo.pkl") as vgf:
+                first_tier_learners.append(pickle.load(vgf))
+        with open(os.path.join("models", "metaclassifier", "meta_todo.pkl"), "rb") as f:
+            metaclassifier = pickle.load(f)
+        with open(
+            os.path.join("models", "vectorizer", "vectorizer_todo.pkl"), "rb"
+        ) as f:
+            vectorizer = pickle.load(f)
+        return first_tier_learners, metaclassifier, vectorizer
+
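save() and load() above hardcode every *_todo.pkl path (flagged by the FIXMEs), and three of the four load() branches open their pickle in text mode and point at vanillaFunGen_todo.pkl even for the multilingual/WCE/transformer VGFs — pickle.load fails on a text-mode handle. A sketch of a path helper that would centralize this; the model_id parameter and the helper names are hypothetical:

    import os
    import pickle

    def _artifact_paths(model_id, basedir="models"):
        # hypothetical: derive every artifact path from a single identifier
        return {
            "posteriors": os.path.join(basedir, "vgfs", "posteriors", f"vanillaFunGen_{model_id}.pkl"),
            "metaclassifier": os.path.join(basedir, "metaclassifier", f"meta_{model_id}.pkl"),
            "vectorizer": os.path.join(basedir, "vectorizer", f"vectorizer_{model_id}.pkl"),
        }

    def _load_pickle(path):
        with open(path, "rb") as f:  # binary mode is required for pickle
            return pickle.load(f)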

 def get_params(optimc=False):
     if not optimc:
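The extraction cuts off inside get_params. A plausible completion, consistent with the calibrated rbf meta-learner it feeds — returning None when no optimization is requested, and an SVM grid otherwise. The concrete C/gamma values are an assumption, not the committed body:

    def get_params(optimc=False):
        if not optimc:
            return None
        c_range = [1e4, 1e3, 1e2, 1e1, 1, 1e-1]  # assumed grid
        return [{"kernel": ["rbf"], "C": c_range, "gamma": ["auto"]}]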
@@ -27,11 +27,16 @@ class VanillaFunGen(ViewGen):
             n_jobs=self.n_jobs,
         )
         self.vectorizer = None
+        self.load_trained = False

     def fit(self, lX, lY):
+        if self.load_trained:
+            return self.load_trained()
+
         print("- fitting VanillaFun View Generating Function")
         lX = self.vectorizer.transform(lX)
         self.doc_projector.fit(lX, lY)
+        return self

     def transform(self, lX):
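One caveat in the hunk above: __init__ sets self.load_trained = False (a boolean), yet fit() executes return self.load_trained(). If the flag were ever set to True, calling it would raise TypeError: 'bool' object is not callable. A sketch of the presumably intended guard, with a hypothetical _load_trained() loader method:

    def fit(self, lX, lY):
        if self.load_trained:
            return self._load_trained()  # hypothetical loader; do not call the flag itself
        print("- fitting VanillaFun View Generating Function")
        lX = self.vectorizer.transform(lX)
        self.doc_projector.fit(lX, lY)
        return self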
@@ -57,3 +62,18 @@ class VanillaFunGen(ViewGen):
             "first_tier_parameters": self.first_tier_parameters,
             "n_jobs": self.n_jobs,
         }
+
+    def save_vgf(self):
+        import pickle
+        from os.path import join
+        from os import makedirs
+
+        model_id = "TODO"
+
+        vgf_name = "vanillaFunGen_todo"
+        _basedir = join("models", "vgfs", "posteriors")
+        makedirs(_basedir, exist_ok=True)
+        _path = join(_basedir, f"{vgf_name}.pkl")
+        with open(_path, "wb") as f:
+            pickle.dump(self, f)
+        return self
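save_vgf pickles the entire VanillaFunGen instance — doc_projector, vectorizer and all — so restoring needs nothing but the pickle file. A round-trip sketch using the exact path written above (vgf stands for a fitted instance; the file name is still the hardcoded _todo placeholder):

    import pickle
    from os.path import join

    vgf.save_vgf()  # writes models/vgfs/posteriors/vanillaFunGen_todo.pkl

    path = join("models", "vgfs", "posteriors", "vanillaFunGen_todo.pkl")
    with open(path, "rb") as f:  # binary mode, matching the "wb" used on save
        restored_vgf = pickle.load(f)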
main.py (94 changed lines)
@@ -1,19 +1,22 @@
-from os.path import expanduser
+import pickle
 from argparse import ArgumentParser
+from os.path import expanduser
+from time import time

-from dataManager.multiNewsDataset import MultiNewsDataset
 from dataManager.amazonDataset import AmazonDataset
 from dataManager.multilingualDatset import MultilingualDataset
-
+from dataManager.multiNewsDataset import MultiNewsDataset
+from evaluation.evaluate import evaluate, log_eval
 from gfun.generalizedFunnelling import GeneralizedFunnelling

-from evaluation.evaluate import evaluate, log_eval
-
-from time import time
-import pickle
-
-
-# TODO: a cleaner way to save the model?
+"""
+TODO:
+    - a cleaner way to save the model? each VGF saved independently (together with
+      standardizer and feature2posteriors). What about the metaclassifier and the vectorizers?
+    - add documentations sphinx
+    - zero-shot setup
+
+"""


 def main(args):
@@ -26,7 +29,7 @@ def main(args):
     dataset = (
         MultilingualDataset(dataset_name="rcv1-2")
         .load(RCV_DATAPATH)
-        .reduce_data(langs=["en", "it", "fr"], maxn=250)
+        .reduce_data(langs=["en", "it", "fr"], maxn=100)
     )

     if isinstance(dataset, MultilingualDataset):
@@ -39,7 +42,7 @@ def main(args):

     tinit = time()

-    if args.load_pretrained is None:
+    if not args.load_trained:
         assert any(
             [
                 args.posteriors,
@@ -50,48 +53,38 @@ def main(args):
             ]
         ), "At least one of VGF must be True"

-        gfun = GeneralizedFunnelling(
-            posterior=args.posteriors,
-            multilingual=args.multilingual,
-            wce=args.wce,
-            transformer=args.transformer,
-            langs=dataset.langs(),
-            embed_dir="~/resources/muse_embeddings",
-            n_jobs=args.n_jobs,
-            max_length=args.max_length,
-            batch_size=args.batch_size,
-            epochs=args.epochs,
-            lr=args.lr,
-            patience=args.patience,
-            evaluate_step=args.evaluate_step,
-            transformer_name=args.transformer_name,
-        )
+    gfun = GeneralizedFunnelling(
+        posterior=args.posteriors,
+        multilingual=args.multilingual,
+        wce=args.wce,
+        transformer=args.transformer,
+        langs=dataset.langs(),
+        embed_dir="~/resources/muse_embeddings",
+        n_jobs=args.n_jobs,
+        max_length=args.max_length,
+        batch_size=args.batch_size,
+        epochs=args.epochs,
+        lr=args.lr,
+        patience=args.patience,
+        evaluate_step=args.evaluate_step,
+        transformer_name=args.transformer_name,
+        device="cuda",
+        optimc=args.optimc,
+        load_trained=args.load_trained,
+    )

     gfun.get_config()
-    gfun.fit(lX, lY)

+    gfun.fit(lX, lY)
+    # if not args.load_model:
+    #     gfun.save()

-    # Saving Model ------------------------
-    with open("models/gfun/gfun_model.pkl", "wb") as f:
-        print(f"- saving model to {f.name}")
-        pickle.dump(gfun, f)
-    # -------------------------------------
-    preds = gfun.transform(lX)

+    preds = gfun.transform(lX)
+    train_eval = evaluate(lY, preds)
+    log_eval(train_eval, phase="train")

-    train_eval = evaluate(lY, preds)
-    log_eval(train_eval, phase="train")
-
-    timetr = time()
-    print(f"- training completed in {timetr - tinit:.2f} seconds")
-
-    # Loading Model ------------------------
-    if args.load_pretrained is not None:
-        with open("models/gfun/gfun_model.pkl", "rb") as f:
-            print(f"- loading model from {f.name}")
-            gfun = pickle.load(f)
-        timetr = time()
-    # --------------------------------------
+    timetr = time()
+    print(f"- training completed in {timetr - tinit:.2f} seconds")

     test_eval = evaluate(lY_te, gfun.transform(lX_te))
     log_eval(test_eval, phase="test")
|
@ -102,7 +95,7 @@ def main(args):
|
|||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--load_pretrained", type=str, default=None)
|
||||
parser.add_argument("-l", "--load_trained", action="store_true")
|
||||
# Dataset parameters -------------------
|
||||
parser.add_argument("--domains", type=str, default="all")
|
||||
parser.add_argument("--nrows", type=int, default=10000)
|
||||
|
@@ -114,6 +107,7 @@ if __name__ == "__main__":
     parser.add_argument("-w", "--wce", action="store_true")
     parser.add_argument("-t", "--transformer", action="store_true")
     parser.add_argument("--n_jobs", type=int, default=1)
+    parser.add_argument("--optimc", action="store_true")
     # transformer parameters ---------------
     parser.add_argument("--transformer_name", type=str, default="mbert")
     parser.add_argument("--batch_size", type=int, default=32)
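Taken together, the new -l/--load_trained and --optimc flags split the CLI into a training run and a reload run. A small sketch of the reload path — parse_args with an explicit argument list is standard argparse, and the at-least-one-VGF assert in main() is skipped when load_trained is set:

    # reload a previously saved model instead of training
    args = parser.parse_args(["-l", "--optimc"])
    main(args)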