implemented fn to save/load trained gfun
This commit is contained in:
parent 6b75483b55
commit 31fb436cf0
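In brief: GeneralizedFunnelling gains save() and load() methods plus a load_trained constructor flag, VanillaFunGen gains save_vgf(), and main.py trades its ad-hoc pickling of the whole gfun object for the new flag. A minimal sketch of the save side introduced below; the hyperparameter values (max_length, epochs, lr, patience, evaluate_step) are illustrative only, and the "*_todo.pkl" target names are the commit's own placeholders:

```python
# Hedged sketch of the new save path, not literal repository code.
gfun = GeneralizedFunnelling(
    posterior=True, multilingual=False, wce=False, transformer=False,
    langs=["en", "it", "fr"], embed_dir="~/resources/muse_embeddings",
    n_jobs=1, max_length=512, batch_size=32, epochs=100, lr=1e-5,
    patience=5, evaluate_step=10, transformer_name="mbert",
    device="cuda", optimc=False, load_trained=False,
)
gfun.fit(lX, lY)
gfun.save()  # pickles each VGF, the metaclassifier, and the vectorizer
```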
@@ -4,6 +4,7 @@
 !.vscode/launch.json
 !.vscode/extensions.json
 !.vscode/*.code-snippets
+.vscode/
 
 # Local History for Visual Studio Code
 .history/
@@ -224,7 +224,7 @@ class MultilingualDataset:
         self.labels = labels
 
     def reduce_data(self, langs=["it", "en"], maxn=50):
-        print(f"- Reducing data: {langs} with max {maxn} documents...")
+        print(f"- Reducing data: {langs} with max {maxn} documents...\n")
         self.set_view(languages=langs)
 
         data = {
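For context, reduce_data is invoked downstream in main.py through method chaining (it returns the dataset); a sketch mirroring the main.py hunk further below, assuming RCV_DATAPATH points at the serialized rcv1-2 dataset:

```python
# Keeps at most 100 documents per language for the selected languages.
dataset = (
    MultilingualDataset(dataset_name="rcv1-2")
    .load(RCV_DATAPATH)
    .reduce_data(langs=["en", "it", "fr"], maxn=100)
)
```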
@@ -33,6 +33,9 @@ class GeneralizedFunnelling:
         patience,
         evaluate_step,
         transformer_name,
+        optimc,
+        device,
+        load_trained,
     ):
         # Forcing VFGs -----------
         self.posteriors_vgf = posterior
@@ -43,7 +46,7 @@ class GeneralizedFunnelling:
         self.langs = langs
         self.embed_dir = embed_dir
         self.cached = True
-        # Transformer VGF params
+        # Transformer VGF params ----------
         self.transformer_name = transformer_name
         self.epochs = epochs
         self.lr_transformer = lr
@@ -52,16 +55,26 @@ class GeneralizedFunnelling:
         self.early_stopping = True
         self.patience = patience
         self.evaluate_step = evaluate_step
+        self.device = device
+        # Metaclassifier params ------------
+        self.optimc = optimc
         # -------------------
         self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
         self.n_jobs = n_jobs
         self.first_tier_learners = []
         self.metaclassifier = None
         self.aggfunc = "mean"
-        self.init()
+        self.load_trained = load_trained
+        self._init()
 
-    def init(self):
+    def _init(self):
         print("[Init GeneralizedFunnelling]")
+        if self.load_trained:
+            print("- loading trained VGFs, metaclassifer and vectorizer")
+            self.first_tier_learners, self.metaclassifier, self.vectorizer = self.load()
+            # TODO: config like aggfunc, device, n_jobs, etc
+            return self
+
         if self.posteriors_vgf:
             fun = VanillaFunGen(
                 base_learner=get_learner(calibrate=True),
@@ -102,9 +115,10 @@ class GeneralizedFunnelling:
 
         self.metaclassifier = MetaClassifier(
             meta_learner=get_learner(calibrate=True, kernel="rbf"),
-            meta_parameters=get_params(),
+            meta_parameters=get_params(self.optimc),
             n_jobs=self.n_jobs,
         )
+        return self
 
     def init_vgfs_vectorizers(self):
         for vgf in self.first_tier_learners:
@@ -113,6 +127,14 @@ class GeneralizedFunnelling:
 
     def fit(self, lX, lY):
         print("[Fitting GeneralizedFunnelling]")
+        if self.load_trained:
+            print(f"- loaded trained model! Skipping training...")
+            load_only_first_tier = False  # TODO
+            if load_only_first_tier:
+                projections = []
+                # TODO project, aggregate and fit the metaclassifier
+            return self
+
         self.vectorizer.fit(lX)
         self.init_vgfs_vectorizers()
 
@@ -173,6 +195,49 @@ class GeneralizedFunnelling:
         for vgf in self.first_tier_learners:
             pprint(vgf.get_config())
 
+    def save(self):
+        for vgf in self.first_tier_learners:
+            vgf.save_vgf()
+        # Saving metaclassifier
+        with open(os.path.join("models", "metaclassifier", "meta_todo.pkl"), "wb") as f:
+            pickle.dump(self.metaclassifier, f)
+        # Saving vectorizer
+        with open(
+            os.path.join("models", "vectorizer", "vectorizer_todo.pkl"), "wb"
+        ) as f:
+            pickle.dump(self.vectorizer, f)
+        # TODO: save some config and perform sanity checks?
+        return
+
+    def load(self):
+        first_tier_learners = []
+        if self.posteriors_vgf:
+            # FIXME: hardcoded
+            with open(
+                os.path.join("models", "vgfs", "posteriors", "vanillaFunGen_todo.pkl"),
+                "rb",
+            ) as vgf:
+                first_tier_learners.append(pickle.load(vgf))
+        if self.multilingual_vgf:
+            # FIXME: hardcoded
+            with open("models/vgfs/multilingual/vanillaFunGen_todo.pkl") as vgf:
+                first_tier_learners.append(pickle.load(vgf))
+        if self.wce_vgf:
+            # FIXME: hardcoded
+            with open("models/vgfs/wordclass/vanillaFunGen_todo.pkl") as vgf:
+                first_tier_learners.append(pickle.load(vgf))
+        if self.trasformer_vgf:
+            # FIXME: hardcoded
+            with open("models/vgfs/transformers/vanillaFunGen_todo.pkl") as vgf:
+                first_tier_learners.append(pickle.load(vgf))
+        with open(os.path.join("models", "metaclassifier", "meta_todo.pkl"), "rb") as f:
+            metaclassifier = pickle.load(f)
+        with open(
+            os.path.join("models", "vectorizer", "vectorizer_todo.pkl"), "rb"
+        ) as f:
+            vectorizer = pickle.load(f)
+        return first_tier_learners, metaclassifier, vectorizer
+
 
 def get_params(optimc=False):
     if not optimc:
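The load() above hardcodes every artifact path, and three of the four VGF pickles are opened without "rb", which pickle.load requires. One possible direction for the FIXMEs, as a hypothetical helper that is not part of the commit:

```python
# Hypothetical replacement for the "FIXME: hardcoded" blocks in load();
# note the explicit "rb" mode, which pickle needs for binary data.
import os
import pickle

def vgf_path(kind: str, model_id: str = "todo") -> str:
    # kind is one of: "posteriors", "multilingual", "wordclass", "transformers"
    return os.path.join("models", "vgfs", kind, f"vanillaFunGen_{model_id}.pkl")

def load_vgf(kind: str, model_id: str = "todo"):
    with open(vgf_path(kind, model_id), "rb") as f:
        return pickle.load(f)
```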
@@ -27,11 +27,16 @@ class VanillaFunGen(ViewGen):
             n_jobs=self.n_jobs,
         )
         self.vectorizer = None
+        self.load_trained = False
 
     def fit(self, lX, lY):
+        if self.load_trained:
+            return self.load_trained()
+
         print("- fitting VanillaFun View Generating Function")
         lX = self.vectorizer.transform(lX)
         self.doc_projector.fit(lX, lY)
 
         return self
 
     def transform(self, lX):
@@ -57,3 +62,18 @@ class VanillaFunGen(ViewGen):
             "first_tier_parameters": self.first_tier_parameters,
             "n_jobs": self.n_jobs,
         }
+
+    def save_vgf(self):
+        import pickle
+        from os.path import join
+        from os import makedirs
+
+        model_id = "TODO"
+
+        vgf_name = "vanillaFunGen_todo"
+        _basedir = join("models", "vgfs", "posteriors")
+        makedirs(_basedir, exist_ok=True)
+        _path = join(_basedir, f"{vgf_name}.pkl")
+        with open(_path, "wb") as f:
+            pickle.dump(self, f)
+        return self
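save_vgf() pickles the entire VanillaFunGen instance under a fixed name; a sketch of the matching restore, assuming a fitted instance whose vectorizer was set via init_vgfs_vectorizers() (constructor arguments abridged, path is the commit's hardcoded placeholder):

```python
# Round trip for a single VGF via the save_vgf() added above.
import pickle

vgf = VanillaFunGen(base_learner=get_learner(calibrate=True))
vgf.fit(lX, lY).save_vgf()  # fit() returns self; save_vgf() pickles the instance

with open("models/vgfs/posteriors/vanillaFunGen_todo.pkl", "rb") as f:
    restored = pickle.load(f)
lZ = restored.transform(lX)  # posterior-probability views from the restored VGF
```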

main.py (92 lines changed)
@@ -1,19 +1,22 @@
-from os.path import expanduser
+import pickle
 from argparse import ArgumentParser
+from os.path import expanduser
+from time import time
 
-from dataManager.multiNewsDataset import MultiNewsDataset
 from dataManager.amazonDataset import AmazonDataset
 from dataManager.multilingualDatset import MultilingualDataset
+from dataManager.multiNewsDataset import MultiNewsDataset
+from evaluation.evaluate import evaluate, log_eval
 from gfun.generalizedFunnelling import GeneralizedFunnelling
 
-from evaluation.evaluate import evaluate, log_eval
-
-from time import time
-import pickle
-
-
-# TODO: a cleaner way to save the model?
+"""
+TODO:
+    - a cleaner way to save the model? each VGF saved independently (together with
+    standardizer and feature2posteriors). What about the metaclassifier and the vectorizers?
+    - add documentations sphinx
+    - zero-shot setup
+"""
 
 
 def main(args):
@@ -26,7 +29,7 @@ def main(args):
     dataset = (
         MultilingualDataset(dataset_name="rcv1-2")
         .load(RCV_DATAPATH)
-        .reduce_data(langs=["en", "it", "fr"], maxn=250)
+        .reduce_data(langs=["en", "it", "fr"], maxn=100)
     )
 
     if isinstance(dataset, MultilingualDataset):
@@ -39,7 +42,7 @@ def main(args):
 
     tinit = time()
 
-    if args.load_pretrained is None:
+    if not args.load_trained:
         assert any(
             [
                 args.posteriors,
@@ -50,48 +53,38 @@ def main(args):
             ]
         ), "At least one of VGF must be True"
 
         gfun = GeneralizedFunnelling(
             posterior=args.posteriors,
             multilingual=args.multilingual,
             wce=args.wce,
             transformer=args.transformer,
             langs=dataset.langs(),
             embed_dir="~/resources/muse_embeddings",
             n_jobs=args.n_jobs,
             max_length=args.max_length,
             batch_size=args.batch_size,
             epochs=args.epochs,
             lr=args.lr,
             patience=args.patience,
             evaluate_step=args.evaluate_step,
             transformer_name=args.transformer_name,
+            device="cuda",
+            optimc=args.optimc,
+            load_trained=args.load_trained,
         )
 
-    gfun.get_config()
-
     gfun.fit(lX, lY)
 
-    # Saving Model ------------------------
-    with open("models/gfun/gfun_model.pkl", "wb") as f:
-        print(f"- saving model to {f.name}")
-        pickle.dump(gfun, f)
-    # -------------------------------------
+    # if not args.load_model:
+    #     gfun.save()
 
     preds = gfun.transform(lX)
 
     train_eval = evaluate(lY, preds)
     log_eval(train_eval, phase="train")
 
     timetr = time()
     print(f"- training completed in {timetr - tinit:.2f} seconds")
 
-    # Loading Model ------------------------
-    if args.load_pretrained is not None:
-        with open("models/gfun/gfun_model.pkl", "rb") as f:
-            print(f"- loading model from {f.name}")
-            gfun = pickle.load(f)
-        timetr = time()
-    # --------------------------------------
-
     test_eval = evaluate(lY_te, gfun.transform(lX_te))
     log_eval(test_eval, phase="test")
@@ -102,7 +95,7 @@ def main(args):
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("--load_pretrained", type=str, default=None)
+    parser.add_argument("-l", "--load_trained", action="store_true")
     # Dataset parameters -------------------
     parser.add_argument("--domains", type=str, default="all")
     parser.add_argument("--nrows", type=int, default=10000)
@@ -114,6 +107,7 @@ if __name__ == "__main__":
     parser.add_argument("-w", "--wce", action="store_true")
     parser.add_argument("-t", "--transformer", action="store_true")
     parser.add_argument("--n_jobs", type=int, default=1)
+    parser.add_argument("--optimc", action="store_true")
     # transformer parameters ---------------
     parser.add_argument("--transformer_name", type=str, default="mbert")
     parser.add_argument("--batch_size", type=int, default=32)
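Taken together, the intended two-phase usage appears to be the following; this is a hedged reading of the argparse hunk above, not documented behavior:

```python
# Sketch of the two phases enabled by this commit (CLI flags from the
# argparse hunk above; the VGF flags also tell load() which pickles to read):
#
#   python main.py --posteriors        # train from scratch
#   python main.py --posteriors -l     # -l/--load_trained: restore, skip fit
#
# Note: gfun.save() is still commented out in main() in this commit, so the
# "*_todo.pkl" artifacts must be produced by calling gfun.save() manually.
```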