script for simpler inference
This commit is contained in: parent 5d07e579e4, commit 41ba20ad5c
dataManager/gFunDataset.py

@@ -25,7 +25,9 @@ class SimpleGfunDataset:
         multilabel=False,
         set_tr_langs=None,
         set_te_langs=None,
-        reduced=False
+        reduced=False,
+        only_inference=False,
+        labels=None
     ):
         self.name = dataset_name
         self.datadir = os.path.expanduser(datadir)
@@ -33,6 +35,8 @@ class SimpleGfunDataset:
         self.visual = visual
         self.multilabel = multilabel
         self.reduced = reduced
+        self.only_inference = only_inference
+        self.labels = labels
         self.load_csv(set_tr_langs, set_te_langs)
         self.print_stats()
 
@@ -42,8 +46,8 @@ class SimpleGfunDataset:
         va = 0
         te = 0
         for lang in self.all_langs:
-            n_tr = len(self.train_data[lang]) if lang in self.tr_langs else 0
-            n_va = len(self.val_data[lang]) if lang in self.tr_langs else 0
+            n_tr = len(self.train_data[lang]) if lang in self.tr_langs and self.train_data is not None else 0
+            n_va = len(self.val_data[lang]) if lang in self.tr_langs and self.val_data is not None else 0
             n_te = len(self.test_data[lang])
             tr += n_tr
             va += n_va
@@ -52,7 +56,19 @@ class SimpleGfunDataset:
         print(f"Total {'-' * 15}")
         print(f"tr: {tr} - va: {va} - te: {te}")
 
+    def load_csv_inference(self):
+        test = pd.read_csv(os.path.join(self.datadir, "test.csv" if not self.reduced else "test.small.csv"))
+        self._set_labels(test)
+        self._set_langs(train=None, test=test)
+        self.train_data = None
+        self.val_data = None
+        self.test_data = self._set_datalang(test)
+        return
+
     def load_csv(self, set_tr_langs, set_te_langs):
+        if self.only_inference:
+            self.load_csv_inference()
+            return
         _data_tr = pd.read_csv(os.path.join(self.datadir, "train.csv" if not self.reduced else "train.small.csv"))
         try:
             stratified = "class"
@@ -72,11 +88,14 @@ class SimpleGfunDataset:
         return
 
     def _set_labels(self, data):
-        self.labels = sorted(list(data.label.unique()))
+        if self.labels is not None:
+            self.labels = [i for i in range(self.labels)]
+        else:
+            self.labels = sorted(list(data.label.unique()))
 
     def _set_langs(self, train, test, set_tr_langs=None, set_te_langs=None):
-        self.tr_langs = set(train.lang.unique().tolist())
-        self.te_langs = set(test.lang.unique().tolist())
+        self.tr_langs = set(train.lang.unique().tolist()) if train is not None else set()
+        self.te_langs = set(test.lang.unique().tolist()) if test is not None else set()
         if set_tr_langs is not None:
             print(f"-- [SETTING TRAINING LANGS TO: {list(set_tr_langs)}]")
             self.tr_langs = self.tr_langs.intersection(set(set_tr_langs))
@@ -122,12 +141,15 @@ class SimpleGfunDataset:
             lang: {"text": apply_mask(self.test_data[lang].text.tolist())}
             for lang in self.te_langs
         }
-        lYte = {
-            lang: self.indices_to_one_hot(
-                indices=self.test_data[lang].label.tolist(),
-                n_labels=self.num_labels())
-            for lang in self.te_langs
-        }
+        if not self.only_inference:
+            lYte = {
+                lang: self.indices_to_one_hot(
+                    indices=self.test_data[lang].label.tolist(),
+                    n_labels=self.num_labels())
+                for lang in self.te_langs
+            }
+        else:
+            lYte = None
         return lXte, lYte
 
     def langs(self):
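Taken together, the SimpleGfunDataset changes add an inference-only path: no train.csv is read, the label set can be fixed from an integer count, and test() returns no targets. A minimal sketch of that path, assuming a datadir that holds the test csv (the dataset name, path, and label count below are illustrative):

from dataManager.gFunDataset import SimpleGfunDataset

# Inference-only construction: load_csv() short-circuits into
# load_csv_inference(), so no train.csv is required.
dataset = SimpleGfunDataset(
    dataset_name="inference-demo",   # illustrative name
    datadir="~/datasets/rai/csv",    # illustrative path; must contain the test csv
    multilabel=False,
    only_inference=True,
    labels=28,                       # an int: labels become list(range(28))
)

lX, lY = dataset.test()  # lY is None in inference-only mode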
gfun/generalizedFunnelling.py

@@ -14,31 +14,31 @@ from gfun.vgfs.wceGen import WceGen
 class GeneralizedFunnelling:
     def __init__(
         self,
-        posterior,
-        wce,
-        multilingual,
-        textual_transformer,
         langs,
         num_labels,
-        classification_type,
-        embed_dir,
-        n_jobs,
-        batch_size,
-        eval_batch_size,
-        max_length,
-        textual_lr,
-        epochs,
-        patience,
-        evaluate_step,
-        optimc,
-        device,
-        load_trained,
         dataset_name,
-        probabilistic,
-        aggfunc,
-        load_meta,
+        posterior=True,
+        wce=False,
+        multilingual=False,
+        textual_transformer=False,
+        classification_type="multilabel",
+        embed_dir="embeddings/muse",
+        n_jobs=-1,
+        batch_size=32,
+        eval_batch_size=128,
+        max_length=512,
+        textual_lr=1e-4,
+        epochs=50,
+        patience=10,
+        evaluate_step=25,
+        optimc=True,
+        device="cuda:0",
+        load_trained=None,
+        probabilistic=True,
+        aggfunc="mean",
+        load_meta=False,
         trained_text_trf=None,
-        textual_transformer_name=None,
+        textual_transformer_name="mbert",
     ):
         # Setting VFGs -----------
         self.posteriors_vgf = posterior
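The constructor refactor turns every VGF switch and hyperparameter into a keyword argument with a default, so callers only need to pass the data-dependent values. A hedged sketch of the slimmer call this enables (the language list and label count are placeholders, and it assumes the remaining defaults are mutually compatible):

from gfun.generalizedFunnelling import GeneralizedFunnelling

# Everything not listed falls back to the new defaults:
# posterior=True, wce=False, multilingual=False, device="cuda:0", ...
gfun = GeneralizedFunnelling(
    langs=["en", "it"],   # placeholder language list
    num_labels=28,        # placeholder label count
    dataset_name="demo",
)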
@@ -344,8 +344,10 @@ class GeneralizedFunnelling:
             self.save_first_tier_learners(model_id=self._model_id)
 
         if save_meta:
+            _basedir = os.path.join("models", "metaclassifier")
+            os.makedirs(_basedir, exist_ok=True)
             with open(
-                os.path.join("models", "metaclassifier", f"meta_{self._model_id}.pkl"),
+                os.path.join(_basedir, f"meta_{self._model_id}.pkl"),
                 "wb",
             ) as f:
                 pickle.dump(self.metaclassifier, f)
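save() now creates models/metaclassifier before pickling the metaclassifier into meta_<model_id>.pkl. The matching read side is not part of this commit, so the loader below is only an assumption that mirrors the save path (load_metaclassifier is a hypothetical helper name):

import os
import pickle

def load_metaclassifier(model_id, basedir=os.path.join("models", "metaclassifier")):
    # Hypothetical helper: reads back what save(save_meta=True) wrote,
    # i.e. models/metaclassifier/meta_<model_id>.pkl
    with open(os.path.join(basedir, f"meta_{model_id}.pkl"), "rb") as f:
        return pickle.load(f)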
new file: inference script (filename not shown in this view)

@@ -0,0 +1,48 @@
+from dataManager.gFunDataset import SimpleGfunDataset
+from gfun.generalizedFunnelling import GeneralizedFunnelling
+from main import save_preds
+
+def main(args):
+    dataset = SimpleGfunDataset(
+        dataset_name="inference-dataset",
+        datadir=args.datadir,
+        multilabel=False,
+        only_inference=True,
+        labels=args.nlabels,
+    )
+
+    lX, _ = dataset.test()
+    print("Ok")
+
+    gfun = GeneralizedFunnelling(
+        dataset_name="inference",
+        langs=dataset.langs(),
+        num_labels=dataset.num_labels(),
+        classification_type="singlelabel",
+        embed_dir=args.muse_dir,
+        posterior=True,
+        multilingual=True,
+        textual_transformer=True,
+        load_trained=args.trained_gfun,
+        load_meta=True,
+        trained_text_trf=args.trained_transformer,
+    )
+
+    predictions = gfun.transform(lX)
+    save_preds(
+        preds=predictions,
+        targets=None,
+        config="inference"
+    )
+
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+    parser = ArgumentParser()
+    parser.add_argument("--datadir", type=str, default="~/datasets/rai/csv", help="directory should contain train.csv and test.csv. train.csv is not required at inference time!")
+    parser.add_argument("--nlabels", type=int, default=28)
+    parser.add_argument("--muse_dir", type=str, default="~/resources/muse_embeddings", help="Path to MUSE embeddings")
+    parser.add_argument("--trained_gfun", type=str, default="rai_pmt_mean_231029", help="name of the trained gfun instance")
+    parser.add_argument("--trained_transformer", default="hf_models/mbert-fewshot-rai-full/checkpoint-5150-small", help="path to fine-tuned transformer")
+    args = parser.parse_args()
+    main(args)
+
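The new script is the point of the commit: it builds the inference-only dataset, restores a trained gfun instance (first-tier learners, metaclassifier, and fine-tuned transformer), and writes predictions via save_preds with targets=None. Judging from how save_preds consumes them, gfun.transform(lX) returns one score matrix per language; a sketch of the shapes involved (all values fabricated for illustration):

import numpy as np

# Input shape, as built by SimpleGfunDataset.test():
lX = {
    "en": {"text": ["a news item", "another news item"]},
    "it": {"text": ["una notizia"]},
}

# Output shape implied by preds[lang].argmax(axis=1) in save_preds:
# one (n_documents, n_labels) matrix per language.
predictions = {
    "en": np.random.rand(2, 28),
    "it": np.random.rand(1, 28),
}
predicted_classes = {lang: p.argmax(axis=1) for lang, p in predictions.items()}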
main.py (14 lines changed)
@@ -48,7 +48,7 @@ def main(args):
 
     gfun = GeneralizedFunnelling(
         # dataset params ----------------------
-        dataset_name=dataset,
+        dataset_name=dataset.name,
         langs=dataset.langs(),
         num_labels=dataset.num_labels(),
         classification_type=args.clf_type,
@@ -62,7 +62,8 @@ def main(args):
         # Transformer VGF params --------------
         textual_transformer=args.textual_transformer,
         textual_transformer_name=args.textual_trf_name,
-        trained_text_trf="hf_models/mbert-zeroshot-rai/checkpoint-1350",
+        # trained_text_trf="hf_models/mbert-zeroshot-rai/checkpoint-1350",
+        trained_text_trf="hf_models/mbert-fewshot-rai-full/checkpoint-5150",
         batch_size=args.batch_size,
         eval_batch_size=args.eval_batch_size,
         epochs=args.epochs,
@@ -88,7 +89,8 @@ def main(args):
 
     gfun.fit(lX, lY)
 
-    if args.load_trained is None and not args.nosave:
-        gfun.save(save_first_tier=True, save_meta=True)
+    # if args.load_trained is None and not args.nosave:
+    print("saving model")
+    gfun.save(save_first_tier=True, save_meta=True)
 
     timetr = time()
@ -143,12 +145,16 @@ def save_preds(preds, targets, config="unk", dataset="unk"):
|
||||||
_langs = []
|
_langs = []
|
||||||
for lang in langs:
|
for lang in langs:
|
||||||
_preds.extend(preds[lang].argmax(axis=1).tolist())
|
_preds.extend(preds[lang].argmax(axis=1).tolist())
|
||||||
|
if targets is None:
|
||||||
|
_targets.extend(["na" for i in range(len(preds[lang]))])
|
||||||
|
else:
|
||||||
_targets.extend(targets[lang].argmax(axis=1).tolist())
|
_targets.extend(targets[lang].argmax(axis=1).tolist())
|
||||||
_langs.extend([lang for i in range(len(preds[lang]))])
|
_langs.extend([lang for i in range(len(preds[lang]))])
|
||||||
df["langs"] = _langs
|
df["langs"] = _langs
|
||||||
df["labels"] = _targets
|
df["labels"] = _targets
|
||||||
df["preds"] = _preds
|
df["preds"] = _preds
|
||||||
df.to_csv(f"results/preds/preds.gfun.{config}.{dataset}.correct.csv", index=False)
|
print(f"- storing predictions in 'results/preds/preds.gfun.{config}.{dataset}.csv'")
|
||||||
|
df.to_csv(f"results/preds/preds.gfun.{config}.{dataset}.csv", index=False)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
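With the targets=None branch in place, save_preds can be exercised on its own, which is handy for checking the output format. A self-contained sketch with dummy scores (all values illustrative; it assumes the results/preds directory is writable):

import os
import numpy as np
from main import save_preds

os.makedirs("results/preds", exist_ok=True)  # save_preds writes into this directory

dummy_preds = {
    "en": np.random.rand(3, 28),  # 3 documents, 28 label scores each
    "it": np.random.rand(2, 28),
}

# targets=None fills the "labels" column with "na" placeholders,
# exactly the inference-time case handled by the new script.
# This writes results/preds/preds.gfun.inference.unk.csv (dataset defaults to "unk").
save_preds(preds=dummy_preds, targets=None, config="inference")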