script for simpler inference

Andrea Pedrotti 2023-10-29 18:15:01 +01:00
parent 5d07e579e4
commit 41ba20ad5c
4 changed files with 118 additions and 40 deletions

dataManager/gFunDataset.py

@@ -25,7 +25,9 @@ class SimpleGfunDataset:
multilabel=False,
set_tr_langs=None,
set_te_langs=None,
reduced=False
reduced=False,
only_inference=False,
labels=None
):
self.name = dataset_name
self.datadir = os.path.expanduser(datadir)
@@ -33,6 +35,8 @@ class SimpleGfunDataset:
self.visual = visual
self.multilabel = multilabel
self.reduced = reduced
self.only_inference = only_inference
self.labels = labels
self.load_csv(set_tr_langs, set_te_langs)
self.print_stats()
@@ -42,8 +46,8 @@ class SimpleGfunDataset:
va = 0
te = 0
for lang in self.all_langs:
n_tr = len(self.train_data[lang]) if lang in self.tr_langs else 0
n_va = len(self.val_data[lang]) if lang in self.tr_langs else 0
n_tr = len(self.train_data[lang]) if self.train_data is not None and lang in self.tr_langs else 0
n_va = len(self.val_data[lang]) if self.val_data is not None and lang in self.tr_langs else 0
n_te = len(self.test_data[lang])
tr += n_tr
va += n_va
@@ -52,7 +56,19 @@ class SimpleGfunDataset:
print(f"Total {'-' * 15}")
print(f"tr: {tr} - va: {va} - te: {te}")
def load_csv_inference(self):
test = pd.read_csv(os.path.join(self.datadir, "test.csv" if not self.reduced else "test.small.csv"))
self._set_labels(test)
self._set_langs(train=None, test=test)
self.train_data = None
self.val_data = None
self.test_data = self._set_datalang(test)
return
def load_csv(self, set_tr_langs, set_te_langs):
if self.only_inference:
self.load_csv_inference()
return
_data_tr = pd.read_csv(os.path.join(self.datadir, "train.csv" if not self.reduced else "train.small.csv"))
try:
stratified = "class"
@@ -72,11 +88,14 @@ class SimpleGfunDataset:
return
def _set_labels(self, data):
self.labels = sorted(list(data.label.unique()))
if self.labels is not None:
# labels was passed as an int count: materialize label ids 0..n-1
self.labels = list(range(self.labels))
else:
self.labels = sorted(list(data.label.unique()))
def _set_langs(self, train, test, set_tr_langs=None, set_te_langs=None):
self.tr_langs = set(train.lang.unique().tolist())
self.te_langs = set(test.lang.unique().tolist())
self.tr_langs = set(train.lang.unique().tolist()) if train is not None else set()
self.te_langs = set(test.lang.unique().tolist()) if test is not None else set()
if set_tr_langs is not None:
print(f"-- [SETTING TRAINING LANGS TO: {list(set_tr_langs)}]")
self.tr_langs = self.tr_langs.intersection(set(set_tr_langs))
@@ -122,12 +141,15 @@ class SimpleGfunDataset:
lang: {"text": apply_mask(self.test_data[lang].text.tolist())}
for lang in self.te_langs
}
lYte = {
lang: self.indices_to_one_hot(
indices=self.test_data[lang].label.tolist(),
n_labels=self.num_labels())
for lang in self.te_langs
}
if not self.only_inference:
lYte = {
lang: self.indices_to_one_hot(
indices=self.test_data[lang].label.tolist(),
n_labels=self.num_labels())
for lang in self.te_langs
}
else:
lYte = None
return lXte, lYte
def langs(self):
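The inference path only reads test.csv from datadir. A minimal sketch of a compatible file, assuming the column names used by the loader above (text and lang; the label column can be omitted because targets are never read when a label count is passed):

import pandas as pd

# a minimal sketch: a test.csv that SimpleGfunDataset can load with
# only_inference=True; the "label" column is left out since targets
# are not read at inference time
pd.DataFrame(
    {
        "text": ["an english document", "un documento italiano"],
        "lang": ["en", "it"],
    }
).to_csv("test.csv", index=False)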

gfun/generalizedFunnelling.py

@@ -14,31 +14,31 @@ from gfun.vgfs.wceGen import WceGen
class GeneralizedFunnelling:
def __init__(
self,
posterior,
wce,
multilingual,
textual_transformer,
langs,
num_labels,
classification_type,
embed_dir,
n_jobs,
batch_size,
eval_batch_size,
max_length,
textual_lr,
epochs,
patience,
evaluate_step,
optimc,
device,
load_trained,
dataset_name,
probabilistic,
aggfunc,
load_meta,
posterior=True,
wce=False,
multilingual=False,
textual_transformer=False,
classification_type="multilabel",
embed_dir="embeddings/muse",
n_jobs=-1,
batch_size=32,
eval_batch_size=128,
max_length=512,
textual_lr=1e-4,
epochs=50,
patience=10,
evaluate_step=25,
optimc=True,
device="cuda:0",
load_trained=None,
probabilistic=True,
aggfunc="mean",
load_meta=False,
trained_text_trf=None,
textual_transformer_name=None,
textual_transformer_name="mbert",
):
# Setting VFGs -----------
self.posteriors_vgf = posterior
@@ -344,8 +344,10 @@ class GeneralizedFunnelling:
self.save_first_tier_learners(model_id=self._model_id)
if save_meta:
_basedir = os.path.join("models", "metaclassifier")
os.makedirs(_basedir, exist_ok=True)
with open(
os.path.join("models", "metaclassifier", f"meta_{self._model_id}.pkl"),
os.path.join(_basedir, f"meta_{self._model_id}.pkl"),
"wb",
) as f:
pickle.dump(self.metaclassifier, f)
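Since the metaclassifier is a plain pickle under models/metaclassifier, it can be restored with the matching pickle.load. A minimal sketch, with my_model_id standing in as a hypothetical placeholder for the instance's actual _model_id:

import os
import pickle

# a minimal sketch: restoring a metaclassifier saved by gfun.save(save_meta=True)
model_id = "my_model_id"  # placeholder, not a real saved id
with open(os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb") as f:
    metaclassifier = pickle.load(f)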

infer.py Normal file

@@ -0,0 +1,48 @@
from dataManager.gFunDataset import SimpleGfunDataset
from gfun.generalizedFunnelling import GeneralizedFunnelling
from main import save_preds
def main(args):
dataset = SimpleGfunDataset(
dataset_name="inference-dataset",
datadir=args.datadir,
multilabel=False,
only_inference=True,
labels=args.nlabels,
)
lX, _ = dataset.test()
print("Ok")
gfun = GeneralizedFunnelling(
dataset_name="inference",
langs=dataset.langs(),
num_labels=dataset.num_labels(),
classification_type="singlelabel",
embed_dir=args.muse_dir,
posterior=True,
multilingual=True,
textual_transformer=True,
load_trained=args.trained_gfun,
load_meta=True,
trained_text_trf=args.trained_transformer,
)
predictions = gfun.transform(lX)
save_preds(
preds=predictions,
targets=None,
config="inference"
)
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("--datadir", type=str, default="~/datasets/rai/csv", help="directory should contain train.csv and test.csv. Train.csv is not required at inferece time!")
parser.add_argument("--nlabels", type=int, default=28)
parser.add_argument("--muse_dir", type=str, default="~/resources/muse_embeddings", help="Path to muse embeddings")
parser.add_argument("--trained_gfun", type=str, default="rai_pmt_mean_231029", help="name of the trained gfun instance")
parser.add_argument("--trained_transformer", default="hf_models/mbert-fewshot-rai-full/checkpoint-5150-small", help="path to fine-tuned transformer")
args = parser.parse_args()
main(args)
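Every argument has a default, so a run only needs to override what differs from them, e.g.:

python infer.py --datadir ~/datasets/rai/csv --nlabels 28 --trained_gfun rai_pmt_mean_231029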

main.py

@@ -48,7 +48,7 @@ def main(args):
gfun = GeneralizedFunnelling(
# dataset params ----------------------
dataset_name=dataset,
dataset_name=dataset.name,
langs=dataset.langs(),
num_labels=dataset.num_labels(),
classification_type=args.clf_type,
@@ -62,7 +62,8 @@ def main(args):
# Transformer VGF params --------------
textual_transformer=args.textual_transformer,
textual_transformer_name=args.textual_trf_name,
trained_text_trf="hf_models/mbert-zeroshot-rai/checkpoint-1350",
# trained_text_trf="hf_models/mbert-zeroshot-rai/checkpoint-1350",
trained_text_trf="hf_models/mbert-fewshot-rai-full/checkpoint-5150",
batch_size=args.batch_size,
eval_batch_size=args.eval_batch_size,
epochs=args.epochs,
@@ -88,8 +89,9 @@ def main(args):
gfun.fit(lX, lY)
if args.load_trained is None and not args.nosave:
gfun.save(save_first_tier=True, save_meta=True)
# if args.load_trained is None and not args.nosave:
print("saving model")
gfun.save(save_first_tier=True, save_meta=True)
timetr = time()
print(f"- training completed in {timetr - tinit:.2f} seconds")
@@ -143,12 +145,16 @@ def save_preds(preds, targets, config="unk", dataset="unk"):
_langs = []
for lang in langs:
_preds.extend(preds[lang].argmax(axis=1).tolist())
_targets.extend(targets[lang].argmax(axis=1).tolist())
if targets is None:
_targets.extend(["na" for i in range(len(preds[lang]))])
else:
_targets.extend(targets[lang].argmax(axis=1).tolist())
_langs.extend([lang] * len(preds[lang]))
df["langs"] = _langs
df["labels"] = _targets
df["preds"] = _preds
df.to_csv(f"results/preds/preds.gfun.{config}.{dataset}.correct.csv", index=False)
print(f"- storing predictions in 'results/preds/preds.gfun.{config}.{dataset}.csv'")
df.to_csv(f"results/preds/preds.gfun.{config}.{dataset}.csv", index=False)
if __name__ == "__main__":
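For the inference run above (config "inference", dataset left at its "unk" default) the predictions land in results/preds/preds.gfun.inference.unk.csv. A minimal sketch of reading them back; at inference time the labels column only holds the "na" placeholder:

import pandas as pd

# a minimal sketch: loading the predictions written by save_preds
preds = pd.read_csv("results/preds/preds.gfun.inference.unk.csv")
print(preds.groupby("langs")["preds"].value_counts())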