diff --git a/dataManager/gFunDataset.py b/dataManager/gFunDataset.py index 1f38917..adf38fa 100644 --- a/dataManager/gFunDataset.py +++ b/dataManager/gFunDataset.py @@ -25,7 +25,9 @@ class SimpleGfunDataset: multilabel=False, set_tr_langs=None, set_te_langs=None, - reduced=False + reduced=False, + only_inference=False, + labels=None ): self.name = dataset_name self.datadir = os.path.expanduser(datadir) @@ -33,6 +35,8 @@ class SimpleGfunDataset: self.visual = visual self.multilabel = multilabel self.reduced = reduced + self.only_inference = only_inference + self.labels = labels self.load_csv(set_tr_langs, set_te_langs) self.print_stats() @@ -42,8 +46,8 @@ class SimpleGfunDataset: va = 0 te = 0 for lang in self.all_langs: - n_tr = len(self.train_data[lang]) if lang in self.tr_langs else 0 - n_va = len(self.val_data[lang]) if lang in self.tr_langs else 0 + n_tr = len(self.train_data[lang]) if lang in self.tr_langs and self.train_data is not None else 0 + n_va = len(self.val_data[lang]) if lang in self.tr_langs and self.val_data is not None else 0 n_te = len(self.test_data[lang]) tr += n_tr va += n_va @@ -51,8 +55,20 @@ class SimpleGfunDataset: print(f"{lang} - tr: {n_tr} - va: {n_va} - te: {n_te}") print(f"Total {'-' * 15}") print(f"tr: {tr} - va: {va} - te: {te}") + + def load_csv_inference(self): + test = pd.read_csv(os.path.join(self.datadir, "test.small.csv" if not self.reduced else "test.small.csv")) + self._set_labels(test) + self._set_langs(train=None, test=test) + self.train_data = None + self.val_data = None + self.test_data = self._set_datalang(test) + return def load_csv(self, set_tr_langs, set_te_langs): + if self.only_inference: + self.load_csv_inference() + return _data_tr = pd.read_csv(os.path.join(self.datadir, "train.csv" if not self.reduced else "train.small.csv")) try: stratified = "class" @@ -72,11 +88,14 @@ class SimpleGfunDataset: return def _set_labels(self, data): - self.labels = sorted(list(data.label.unique())) + if self.labels is not None: + 
self.labels = [i for i in range(self.labels)] + else: + self.labels = sorted(list(data.label.unique())) def _set_langs(self, train, test, set_tr_langs=None, set_te_langs=None): - self.tr_langs = set(train.lang.unique().tolist()) - self.te_langs = set(test.lang.unique().tolist()) + self.tr_langs = set(train.lang.unique().tolist()) if train is not None else set() + self.te_langs = set(test.lang.unique().tolist()) if test is not None else set() if set_tr_langs is not None: print(f"-- [SETTING TRAINING LANGS TO: {list(set_tr_langs)}]") self.tr_langs = self.tr_langs.intersection(set(set_tr_langs)) @@ -122,12 +141,15 @@ class SimpleGfunDataset: lang: {"text": apply_mask(self.test_data[lang].text.tolist())} for lang in self.te_langs } - lYte = { - lang: self.indices_to_one_hot( - indices=self.test_data[lang].label.tolist(), - n_labels=self.num_labels()) - for lang in self.te_langs - } + if not self.only_inference: + lYte = { + lang: self.indices_to_one_hot( + indices=self.test_data[lang].label.tolist(), + n_labels=self.num_labels()) + for lang in self.te_langs + } + else: + lYte = None return lXte, lYte def langs(self): diff --git a/gfun/generalizedFunnelling.py b/gfun/generalizedFunnelling.py index 67b9632..d4a9fc2 100644 --- a/gfun/generalizedFunnelling.py +++ b/gfun/generalizedFunnelling.py @@ -14,31 +14,31 @@ from gfun.vgfs.wceGen import WceGen class GeneralizedFunnelling: def __init__( self, - posterior, - wce, - multilingual, - textual_transformer, langs, num_labels, - classification_type, - embed_dir, - n_jobs, - batch_size, - eval_batch_size, - max_length, - textual_lr, - epochs, - patience, - evaluate_step, - optimc, - device, - load_trained, dataset_name, - probabilistic, - aggfunc, - load_meta, + posterior=True, + wce=False, + multilingual=False, + textual_transformer=False, + classification_type="multilabel", + embed_dir="embeddings/muse", + n_jobs=-1, + batch_size=32, + eval_batch_size=128, + max_length=512, + textual_lr=1e-4, + epochs=50, + patience=10, + 
evaluate_step=25, + optimc=True, + device="cuda:0", + load_trained=None, + probabilistic=True, + aggfunc="mean", + load_meta=False, trained_text_trf=None, - textual_transformer_name=None, + textual_transformer_name="mbert", ): # Setting VFGs ----------- self.posteriors_vgf = posterior @@ -344,8 +344,10 @@ class GeneralizedFunnelling: self.save_first_tier_learners(model_id=self._model_id) if save_meta: + _basedir = os.path.join("models", "metaclassifier") + os.makedirs(_basedir) with open( - os.path.join("models", "metaclassifier", f"meta_{self._model_id}.pkl"), + os.path.join(_basedir, f"meta_{self._model_id}.pkl"), "wb", ) as f: pickle.dump(self.metaclassifier, f) diff --git a/infer.py b/infer.py new file mode 100644 index 0000000..b15f94d --- /dev/null +++ b/infer.py @@ -0,0 +1,48 @@ +from dataManager.gFunDataset import SimpleGfunDataset +from gfun.generalizedFunnelling import GeneralizedFunnelling +from main import save_preds + +def main(args): + dataset = SimpleGfunDataset( + dataset_name="inference-dataset", + datadir=args.datadir, + multilabel=False, + only_inference=True, + labels=args.nlabels, + ) + + lX, _ = dataset.test() + print("Ok") + + gfun = GeneralizedFunnelling( + dataset_name="inference", + langs=dataset.langs(), + num_labels=dataset.num_labels(), + classification_type="singlelabel", + embed_dir=args.muse_dir, + posterior=True, + multilingual=True, + textual_transformer=True, + load_trained=args.trained_gfun, + load_meta=True, + trained_text_trf=args.trained_transformer, + ) + + predictions = gfun.transform(lX) + save_preds( + preds=predictions, + targets=None, + config="inference" + ) + +if __name__ == "__main__": + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument("--datadir", type=str, default="~/datasets/rai/csv", help="directory should contain train.csv and test.csv. 
Train.csv is not required at inference time!") + parser.add_argument("--nlabels", type=int, default=28) + parser.add_argument("--muse_dir", type=str, default="~/resources/muse_embeddings", help="Path to muse embeddings") + parser.add_argument("--trained_gfun", type=str, default="rai_pmt_mean_231029", help="name of the trained gfun instance") + parser.add_argument("--trained_transformer", default="hf_models/mbert-fewshot-rai-full/checkpoint-5150-small", help="path to fine-tuned transformer") + args = parser.parse_args() + main(args) + diff --git a/main.py b/main.py index 56affbe..8fa47a4 100644 --- a/main.py +++ b/main.py @@ -48,7 +48,7 @@ def main(args): gfun = GeneralizedFunnelling( # dataset params ---------------------- - dataset_name=dataset, + dataset_name=dataset.name, langs=dataset.langs(), num_labels=dataset.num_labels(), classification_type=args.clf_type, @@ -62,7 +62,8 @@ def main(args): # Transformer VGF params -------------- textual_transformer=args.textual_transformer, textual_transformer_name=args.textual_trf_name, - trained_text_trf="hf_models/mbert-zeroshot-rai/checkpoint-1350", + # trained_text_trf="hf_models/mbert-zeroshot-rai/checkpoint-1350", + trained_text_trf="hf_models/mbert-fewshot-rai-full/checkpoint-5150", batch_size=args.batch_size, eval_batch_size=args.eval_batch_size, epochs=args.epochs, @@ -88,8 +89,9 @@ def main(args): gfun.fit(lX, lY) - if args.load_trained is None and not args.nosave: - gfun.save(save_first_tier=True, save_meta=True) + # if args.load_trained is None and not args.nosave: + print("saving model") + gfun.save(save_first_tier=True, save_meta=True) timetr = time() print(f"- training completed in {timetr - tinit:.2f} seconds") @@ -143,12 +145,16 @@ def save_preds(preds, targets, config="unk", dataset="unk"): _langs = [] for lang in langs: _preds.extend(preds[lang].argmax(axis=1).tolist()) - _targets.extend(targets[lang].argmax(axis=1).tolist()) + if targets is None: + _targets.extend(["na" for i in range(len(preds[lang]))]) 
+ else: + _targets.extend(targets[lang].argmax(axis=1).tolist()) _langs.extend([lang for i in range(len(preds[lang]))]) df["langs"] = _langs df["labels"] = _targets df["preds"] = _preds - df.to_csv(f"results/preds/preds.gfun.{config}.{dataset}.correct.csv", index=False) + print(f"- storing predictions in 'results/preds/preds.gfun.{config}.{dataset}.csv'") + df.to_csv(f"results/preds/preds.gfun.{config}.{dataset}.csv", index=False) if __name__ == "__main__":