script for simpler inference
This commit is contained in:
parent
5d07e579e4
commit
41ba20ad5c
|
|
@ -25,7 +25,9 @@ class SimpleGfunDataset:
|
|||
multilabel=False,
|
||||
set_tr_langs=None,
|
||||
set_te_langs=None,
|
||||
reduced=False
|
||||
reduced=False,
|
||||
only_inference=False,
|
||||
labels=None
|
||||
):
|
||||
self.name = dataset_name
|
||||
self.datadir = os.path.expanduser(datadir)
|
||||
|
|
@ -33,6 +35,8 @@ class SimpleGfunDataset:
|
|||
self.visual = visual
|
||||
self.multilabel = multilabel
|
||||
self.reduced = reduced
|
||||
self.only_inference = only_inference
|
||||
self.labels = labels
|
||||
self.load_csv(set_tr_langs, set_te_langs)
|
||||
self.print_stats()
|
||||
|
||||
|
|
@ -42,8 +46,8 @@ class SimpleGfunDataset:
|
|||
va = 0
|
||||
te = 0
|
||||
for lang in self.all_langs:
|
||||
n_tr = len(self.train_data[lang]) if lang in self.tr_langs else 0
|
||||
n_va = len(self.val_data[lang]) if lang in self.tr_langs else 0
|
||||
n_tr = len(self.train_data[lang]) if lang in self.tr_langs and self.train_data is not None else 0
|
||||
n_va = len(self.val_data[lang]) if lang in self.tr_langs and self.val_data is not None else 0
|
||||
n_te = len(self.test_data[lang])
|
||||
tr += n_tr
|
||||
va += n_va
|
||||
|
|
@ -51,8 +55,20 @@ class SimpleGfunDataset:
|
|||
print(f"{lang} - tr: {n_tr} - va: {n_va} - te: {n_te}")
|
||||
print(f"Total {'-' * 15}")
|
||||
print(f"tr: {tr} - va: {va} - te: {te}")
|
||||
|
||||
def load_csv_inference(self):
|
||||
test = pd.read_csv(os.path.join(self.datadir, "test.small.csv" if not self.reduced else "test.small.csv"))
|
||||
self._set_labels(test)
|
||||
self._set_langs(train=None, test=test)
|
||||
self.train_data = None
|
||||
self.val_data = None
|
||||
self.test_data = self._set_datalang(test)
|
||||
return
|
||||
|
||||
def load_csv(self, set_tr_langs, set_te_langs):
|
||||
if self.only_inference:
|
||||
self.load_csv_inference()
|
||||
return
|
||||
_data_tr = pd.read_csv(os.path.join(self.datadir, "train.csv" if not self.reduced else "train.small.csv"))
|
||||
try:
|
||||
stratified = "class"
|
||||
|
|
@ -72,11 +88,14 @@ class SimpleGfunDataset:
|
|||
return
|
||||
|
||||
def _set_labels(self, data):
|
||||
self.labels = sorted(list(data.label.unique()))
|
||||
if self.labels is not None:
|
||||
self.labels = [i for i in range(self.labels)]
|
||||
else:
|
||||
self.labels = sorted(list(data.label.unique()))
|
||||
|
||||
def _set_langs(self, train, test, set_tr_langs=None, set_te_langs=None):
|
||||
self.tr_langs = set(train.lang.unique().tolist())
|
||||
self.te_langs = set(test.lang.unique().tolist())
|
||||
self.tr_langs = set(train.lang.unique().tolist()) if train is not None else set()
|
||||
self.te_langs = set(test.lang.unique().tolist()) if test is not None else set()
|
||||
if set_tr_langs is not None:
|
||||
print(f"-- [SETTING TRAINING LANGS TO: {list(set_tr_langs)}]")
|
||||
self.tr_langs = self.tr_langs.intersection(set(set_tr_langs))
|
||||
|
|
@ -122,12 +141,15 @@ class SimpleGfunDataset:
|
|||
lang: {"text": apply_mask(self.test_data[lang].text.tolist())}
|
||||
for lang in self.te_langs
|
||||
}
|
||||
lYte = {
|
||||
lang: self.indices_to_one_hot(
|
||||
indices=self.test_data[lang].label.tolist(),
|
||||
n_labels=self.num_labels())
|
||||
for lang in self.te_langs
|
||||
}
|
||||
if not self.only_inference:
|
||||
lYte = {
|
||||
lang: self.indices_to_one_hot(
|
||||
indices=self.test_data[lang].label.tolist(),
|
||||
n_labels=self.num_labels())
|
||||
for lang in self.te_langs
|
||||
}
|
||||
else:
|
||||
lYte = None
|
||||
return lXte, lYte
|
||||
|
||||
def langs(self):
|
||||
|
|
|
|||
|
|
@ -14,31 +14,31 @@ from gfun.vgfs.wceGen import WceGen
|
|||
class GeneralizedFunnelling:
|
||||
def __init__(
|
||||
self,
|
||||
posterior,
|
||||
wce,
|
||||
multilingual,
|
||||
textual_transformer,
|
||||
langs,
|
||||
num_labels,
|
||||
classification_type,
|
||||
embed_dir,
|
||||
n_jobs,
|
||||
batch_size,
|
||||
eval_batch_size,
|
||||
max_length,
|
||||
textual_lr,
|
||||
epochs,
|
||||
patience,
|
||||
evaluate_step,
|
||||
optimc,
|
||||
device,
|
||||
load_trained,
|
||||
dataset_name,
|
||||
probabilistic,
|
||||
aggfunc,
|
||||
load_meta,
|
||||
posterior=True,
|
||||
wce=False,
|
||||
multilingual=False,
|
||||
textual_transformer=False,
|
||||
classification_type="multilabel",
|
||||
embed_dir="embeddings/muse",
|
||||
n_jobs=-1,
|
||||
batch_size=32,
|
||||
eval_batch_size=128,
|
||||
max_length=512,
|
||||
textual_lr=1e-4,
|
||||
epochs=50,
|
||||
patience=10,
|
||||
evaluate_step=25,
|
||||
optimc=True,
|
||||
device="cuda:0",
|
||||
load_trained=None,
|
||||
probabilistic=True,
|
||||
aggfunc="mean",
|
||||
load_meta=False,
|
||||
trained_text_trf=None,
|
||||
textual_transformer_name=None,
|
||||
textual_transformer_name="mbert",
|
||||
):
|
||||
# Setting VFGs -----------
|
||||
self.posteriors_vgf = posterior
|
||||
|
|
@ -344,8 +344,10 @@ class GeneralizedFunnelling:
|
|||
self.save_first_tier_learners(model_id=self._model_id)
|
||||
|
||||
if save_meta:
|
||||
_basedir = os.path.join("models", "metaclassifier")
|
||||
os.makedirs(_basedir)
|
||||
with open(
|
||||
os.path.join("models", "metaclassifier", f"meta_{self._model_id}.pkl"),
|
||||
os.path.join(_basedir, f"meta_{self._model_id}.pkl"),
|
||||
"wb",
|
||||
) as f:
|
||||
pickle.dump(self.metaclassifier, f)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,48 @@
|
|||
from dataManager.gFunDataset import SimpleGfunDataset
|
||||
from gfun.generalizedFunnelling import GeneralizedFunnelling
|
||||
from main import save_preds
|
||||
|
||||
def main(args):
|
||||
dataset = SimpleGfunDataset(
|
||||
dataset_name="inference-dataset",
|
||||
datadir=args.datadir,
|
||||
multilabel=False,
|
||||
only_inference=True,
|
||||
labels=args.nlabels,
|
||||
)
|
||||
|
||||
lX, _ = dataset.test()
|
||||
print("Ok")
|
||||
|
||||
gfun = GeneralizedFunnelling(
|
||||
dataset_name="inference",
|
||||
langs=dataset.langs(),
|
||||
num_labels=dataset.num_labels(),
|
||||
classification_type="singlelabel",
|
||||
embed_dir=args.muse_dir,
|
||||
posterior=True,
|
||||
multilingual=True,
|
||||
textual_transformer=True,
|
||||
load_trained=args.trained_gfun,
|
||||
load_meta=True,
|
||||
trained_text_trf=args.trained_transformer,
|
||||
)
|
||||
|
||||
predictions = gfun.transform(lX)
|
||||
save_preds(
|
||||
preds=predictions,
|
||||
targets=None,
|
||||
config="inference"
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
from argparse import ArgumentParser
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--datadir", type=str, default="~/datasets/rai/csv", help="directory should contain train.csv and test.csv. Train.csv is not required at inferece time!")
|
||||
parser.add_argument("--nlabels", type=int, default=28)
|
||||
parser.add_argument("--muse_dir", type=str, default="~/resources/muse_embeddings", help="Path to muse embeddings")
|
||||
parser.add_argument("--trained_gfun", type=str, default="rai_pmt_mean_231029", help="name of the trained gfun instance")
|
||||
parser.add_argument("--trained_transformer", default="hf_models/mbert-fewshot-rai-full/checkpoint-5150-small", help="path to fine-tuned transformer")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
18
main.py
18
main.py
|
|
@ -48,7 +48,7 @@ def main(args):
|
|||
|
||||
gfun = GeneralizedFunnelling(
|
||||
# dataset params ----------------------
|
||||
dataset_name=dataset,
|
||||
dataset_name=dataset.name,
|
||||
langs=dataset.langs(),
|
||||
num_labels=dataset.num_labels(),
|
||||
classification_type=args.clf_type,
|
||||
|
|
@ -62,7 +62,8 @@ def main(args):
|
|||
# Transformer VGF params --------------
|
||||
textual_transformer=args.textual_transformer,
|
||||
textual_transformer_name=args.textual_trf_name,
|
||||
trained_text_trf="hf_models/mbert-zeroshot-rai/checkpoint-1350",
|
||||
# trained_text_trf="hf_models/mbert-zeroshot-rai/checkpoint-1350",
|
||||
trained_text_trf="hf_models/mbert-fewshot-rai-full/checkpoint-5150",
|
||||
batch_size=args.batch_size,
|
||||
eval_batch_size=args.eval_batch_size,
|
||||
epochs=args.epochs,
|
||||
|
|
@ -88,8 +89,9 @@ def main(args):
|
|||
|
||||
gfun.fit(lX, lY)
|
||||
|
||||
if args.load_trained is None and not args.nosave:
|
||||
gfun.save(save_first_tier=True, save_meta=True)
|
||||
# if args.load_trained is None and not args.nosave:
|
||||
print("saving model")
|
||||
gfun.save(save_first_tier=True, save_meta=True)
|
||||
|
||||
timetr = time()
|
||||
print(f"- training completed in {timetr - tinit:.2f} seconds")
|
||||
|
|
@ -143,12 +145,16 @@ def save_preds(preds, targets, config="unk", dataset="unk"):
|
|||
_langs = []
|
||||
for lang in langs:
|
||||
_preds.extend(preds[lang].argmax(axis=1).tolist())
|
||||
_targets.extend(targets[lang].argmax(axis=1).tolist())
|
||||
if targets is None:
|
||||
_targets.extend(["na" for i in range(len(preds[lang]))])
|
||||
else:
|
||||
_targets.extend(targets[lang].argmax(axis=1).tolist())
|
||||
_langs.extend([lang for i in range(len(preds[lang]))])
|
||||
df["langs"] = _langs
|
||||
df["labels"] = _targets
|
||||
df["preds"] = _preds
|
||||
df.to_csv(f"results/preds/preds.gfun.{config}.{dataset}.correct.csv", index=False)
|
||||
print(f"- storing predictions in 'results/preds/preds.gfun.{config}.{dataset}.csv'")
|
||||
df.to_csv(f"results/preds/preds.gfun.{config}.{dataset}.csv", index=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue