fixed inference script
parent 4615bc3857
commit 35cc32e541
@@ -10,16 +10,16 @@ from tqdm import tqdm
 import pandas as pd
 from sklearn.model_selection import train_test_split
 
-from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
-from dataManager.glamiDataset import get_dataframe
-from dataManager.multilingualDataset import MultilingualDataset
+# from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
+# from dataManager.glamiDataset import get_dataframe
+# from dataManager.multilingualDataset import MultilingualDataset
 
 
 class SimpleGfunDataset:
     def __init__(
         self,
         dataset_name=None,
-        datadir="~/datasets/rai/csv/",
+        datapath="~/datasets/rai/csv/test.csv",
         textual=True,
         visual=False,
         multilabel=False,
@@ -30,7 +30,7 @@ class SimpleGfunDataset:
         labels=None
     ):
         self.name = dataset_name
-        self.datadir = os.path.expanduser(datadir)
+        self.datadir = os.path.expanduser(datapath)
         self.textual = textual
         self.visual = visual
         self.multilabel = multilabel
@@ -41,7 +41,7 @@ class SimpleGfunDataset:
         self.print_stats()
 
     def print_stats(self):
-        print(f"Dataset statistics {'-' * 15}")
+        print(f"Dataset statistics\n{'-' * 15}")
         tr = 0
         va = 0
         te = 0
@@ -53,11 +53,12 @@ class SimpleGfunDataset:
             va += n_va
             te += n_te
             print(f"{lang} - tr: {n_tr} - va: {n_va} - te: {n_te}")
-        print(f"Total {'-' * 15}")
+        print(f"{'-' * 15}\nTotal\n{'-' * 15}")
         print(f"tr: {tr} - va: {va} - te: {te}")
 
     def load_csv_inference(self):
-        test = pd.read_csv(os.path.join(self.datadir, "test.small.csv" if not self.reduced else "test.small.csv"))
+        # test = pd.read_csv(os.path.join(self.datadir, "test.small.csv" if not self.reduced else "test.small.csv"))
+        test = pd.read_csv(self.datadir)
        self._set_labels(test)
        self._set_langs(train=None, test=test)
        self.train_data = None
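A minimal sketch of how the reworked loader is driven after this change; `only_inference` and `labels` appear in later hunks of this commit, and the csv schema (text, id, and language columns) is an assumption:

    from dataManager.gFunDataset import SimpleGfunDataset

    # Sketch only: `datapath` now points at a single csv file rather than a directory.
    dataset = SimpleGfunDataset(
        dataset_name="my-run",
        datapath="~/datasets/rai/csv/test.csv",
        textual=True,
        visual=False,
        multilabel=False,
        only_inference=True,
        labels=28,
    )
    dataset.print_stats()  # per-language tr/va/te counts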
@@ -137,7 +138,10 @@ class SimpleGfunDataset:
     def test(self, mask_number=False, target_as_csr=False):
         apply_mask = lambda x: _mask_numbers(x) if mask_number else x
         lXte = {
-            lang: {"text": apply_mask(self.test_data[lang].text.tolist())}
+            lang: {
+                "text": apply_mask(self.test_data[lang].text.tolist()),
+                "id": self.test_data[lang].id.tolist()
+            }
             for lang in self.te_langs
         }
         if not self.only_inference:
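With the added "id" field, every per-language entry built by `test()` now carries document ids next to the text, so predictions can be traced back to input rows. The structure, with toy values:

    lXte = {
        "en": {"text": ["first document", "second document"], "id": ["doc-001", "doc-002"]},
        "it": {"text": ["primo documento"], "id": ["doc-101"]},
    }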
@@ -17,7 +17,7 @@ def load_from_pickle(path, dataset_name, nrows):
 def get_dataset(datasetp_path, args):
     dataset = SimpleGfunDataset(
         dataset_name="rai",
-        datadir=datasetp_path,
+        datapath=datasetp_path,
         textual=True,
         visual=False,
         multilabel=False,
@@ -240,8 +240,9 @@ class GeneralizedFunnelling:
 
         return self
 
-    def transform(self, lX):
+    def transform(self, lX, output_ids=False):
         projections = []
+        l_ids = {}
         for vgf in self.first_tier_learners:
             l_posteriors = vgf.transform(lX)
             projections.append(l_posteriors)
@@ -250,7 +251,11 @@ class GeneralizedFunnelling:
         if self.clf_type == "singlelabel":
             for lang, preds in l_out.items():
                 l_out[lang] = predict(preds, clf_type=self.clf_type)
-        return l_out
+                l_ids[lang] = lX[lang]["id"]
+        if output_ids:
+            return l_out, l_ids
+        else:
+            return l_out
 
     def fit_transform(self, lX, lY):
         return self.fit(lX, lY).transform(lX)
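The new `output_ids` flag keeps existing call sites unchanged while letting the inference path recover the ids it passed in; note that, as written above, `l_ids` is only filled on the `singlelabel` branch. A usage sketch:

    l_out = gfun.transform(lX)                          # old behaviour, unchanged
    l_out, l_ids = gfun.transform(lX, output_ids=True)  # new inference path
    for lang in l_out:
        assert len(l_out[lang]) == len(l_ids[lang])     # predictions stay aligned with ids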
@@ -104,7 +104,7 @@ class TfidfVectorizerMultilingual:
     def __init__(self, **kwargs):
         self.kwargs = kwargs
 
-    def update(self, X, lang):
+    def update_vectorizer(self, X, lang):
         self.langs.append(lang)
         self.vectorizer[lang] = TfidfVectorizer(**self.kwargs).fit(X["text"])
         return self
@@ -121,8 +121,9 @@ class TfidfVectorizerMultilingual:
         for in_l in in_langs:
             if in_l not in self.langs:
                 print(f"[NB: found unvectorized language! Updating vectorizer for {in_l=}]")
-                self.update(X=lX[in_l], lang=in_l)
-        return {l: self.vectorizer[l].transform(lX[l]["text"]) for l in self.langs}  # TODO we can update the vectorizer with new languages here!
+                self.update_vectorizer(X=lX[in_l], lang=in_l)
+        # return {l: self.vectorizer[l].transform(lX[l]["text"]) for l in self.langs}  # TODO we can update the vectorizer with new languages here!
+        return {l: self.vectorizer[l].transform(lX[l]["text"]) for l in in_langs}
 
     def fit_transform(self, lX, ly=None):
         return self.fit(lX, ly).transform(lX)
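Together with the renamed `update_vectorizer`, `transform` now vectorizes exactly the incoming languages, fitting a fresh per-language TfidfVectorizer on the fly for any language unseen at training time (a pragmatic fallback: the tf-idf statistics then come from the inference documents themselves). A hedged sketch, with illustrative language codes and documents:

    vect = TfidfVectorizerMultilingual(sublinear_tf=True)
    vect.fit({"en": {"text": ["a document"]}, "it": {"text": ["un documento"]}})

    # "de" was never seen at fit time: transform() updates the vectorizer for it,
    # and the result is keyed by the input languages only.
    lZ = vect.transform({"de": {"text": ["ein Dokument"]}})
    assert list(lZ.keys()) == ["de"]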
@@ -55,9 +55,11 @@ class MultilingualGen(ViewGen):
         return self
 
     def transform(self, lX):
+        _langs = sorted(lX.keys())
         lX = self.vectorizer.transform(lX)
-        if self.langs != sorted(self.vectorizer.vectorizer.keys()):
-            # new_langs = set(self.vectorizer.vectorizer.keys()) - set(self.langs)
+        if _langs != sorted(self.vectorizer.vectorizer.keys()):
+            """Loading word-embeddings for unseen languages at training time (zero-shot scenario),
+            excluding (exclude=old_langs) already loaded matrices."""
             old_langs = self.langs
             self.langs = sorted(self.vectorizer.vectorizer.keys())
             new_load, _ = self._load_embeddings(embed_dir=self.embed_dir, cached=self.cached, exclude=old_langs)
@@ -66,14 +68,14 @@ class MultilingualGen(ViewGen):
 
         XdotMulti = Parallel(n_jobs=self.n_jobs)(
             delayed(XdotM)(lX[lang], self.multi_embeddings[lang], sif=self.sif)
-            for lang in self.langs
+            for lang in _langs
         )
-        lZ = {lang: XdotMulti[i] for i, lang in enumerate(self.langs)}
+        lZ = {lang: XdotMulti[i] for i, lang in enumerate(_langs)}
         lZ = _normalize(lZ, l2=True)
         if self.probabilistic and self.fitted:
             lZ = self.feature2posterior_projector.transform(lZ)
         return lZ
 
     def fit_transform(self, lX, lY):
         return self.fit(lX, lY).transform(lX)
 
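Downstream, the same mechanism lets MultilingualGen project languages absent at fit time: the vectorizer grows first, the missing MUSE matrices are loaded once (`exclude=old_langs` skips what is already in memory), and only the incoming languages are projected. Roughly, assuming an already-fitted instance:

    # Zero-shot sketch: "de" was not among the training languages.
    lZ = multilingual_gen.transform({"de": {"text": ["ein Dokument"]}})
    assert set(lZ.keys()) == {"de"}  # output keyed by the input languages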
infer.py (73 changed lines)
@@ -1,20 +1,21 @@
+import os
+import pandas as pd
+from pathlib import Path
 
 from dataManager.gFunDataset import SimpleGfunDataset
 from gfun.generalizedFunnelling import GeneralizedFunnelling
-from main import save_preds
 
-DATASET_NAME = "inference-dataset"  # can be anything
 
 def main(args):
     dataset = SimpleGfunDataset(
-        dataset_name=DATASET_NAME,
-        datadir=args.datadir,
+        dataset_name=Path(args.datapath).stem,
+        datapath=args.datapath,
         multilabel=False,
         only_inference=True,
         labels=args.nlabels,
     )
 
     lX, _ = dataset.test()
-    print("Ok")
 
     gfun = GeneralizedFunnelling(
         dataset_name=dataset.get_name(),
@@ -29,17 +30,71 @@ def main(args):
         load_meta=True,
     )
 
-    predictions = gfun.transform(lX)
-    save_preds(
+    predictions, ids = gfun.transform(lX, output_ids=True)
+    save_inference_preds(
         preds=predictions,
+        doc_ids=ids,
+        dataset_name=dataset.get_name(),
         targets=None,
-        config=f"Inference dataset: {dataset.get_name()}"
+        category_mapper="models/category_mappers/rai-mapping.csv"
     )
 
+
+def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_mapper=None, output_dir="results/inference-preds"):
+    """
+    Parameters
+    ----------
+    preds : Dict[str, np.array]
+        Predictions produced by generalized-funnelling.
+    dataset_name : str
+        Dataset name used as the output file name. The file is stored in the directory defined by
+        the `output_dir` argument, e.g. "<output_dir>/<dataset_name>.csv".
+    doc_ids : Dict[str, List[str]]
+        Dictionary storing the list of document ids (as defined in the csv input file).
+    targets : Dict[str, np.array]
+        If available, target true labels will be written to the output file to ease performance evaluation.
+        (default = None)
+    category_mapper : Path
+        Path to the 'category_mapper' csv file storing the category name (str) for each target class (integer).
+        If not None, gFun predictions will be converted and stored as target string classes.
+        (default = None)
+    output_dir : Path
+        Directory where the output csv file is stored.
+        (default = "results/inference-preds")
+    """
+    os.makedirs(output_dir, exist_ok=True)
+    df = pd.DataFrame()
+    langs = sorted(preds.keys())
+    _ids = []
+    _preds = []
+    _targets = []
+    _langs = []
+    for lang in langs:
+        _preds.extend(preds[lang].argmax(axis=1).tolist())
+        _langs.extend([lang for i in range(len(preds[lang]))])
+        _ids.extend(doc_ids[lang])
+
+    df["doc_id"] = _ids
+    df["document_language"] = _langs
+    df["gfun_prediction"] = _preds
+
+    if targets is not None:
+        for lang in langs:
+            _targets.extend(targets[lang].argmax(axis=1).tolist())
+        df["document_true_label"] = _targets
+
+    if category_mapper is not None:
+        mapper = pd.read_csv(category_mapper).to_dict()["category"]
+        df["gfun_string_prediction"] = [mapper[p] for p in _preds]
+
+    output_file = f"{output_dir}/{dataset_name}.csv"
+    print(f"Storing predictions in: {output_file}")
+    df.to_csv(output_file, index=False)
+
 
 if __name__ == "__main__":
     from argparse import ArgumentParser
     parser = ArgumentParser()
-    parser.add_argument("--datadir", type=str, default="~/datasets/rai/csv", help="directory should contain train.csv and test.csv. Train.csv is not required at inferece time!")
+    parser.add_argument("--datapath", type=str, default="~/datasets/rai/csv", help="Path to the csv file containing the documents to be classified")
     parser.add_argument("--nlabels", type=int, default=28)
     parser.add_argument("--muse_dir", type=str, default="~/resources/muse_embeddings", help="Path to muse embeddings")
     parser.add_argument("--trained_gfun", type=str, default="rai_pmt_mean_231029", help="name of the trained gfun instance")
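End to end, the script is now pointed at a single csv (e.g. `python infer.py --datapath ~/datasets/rai/csv/test.csv`) and names its output after the input file's stem. A hedged sketch of calling the new `save_inference_preds` directly, with toy arrays:

    import numpy as np

    preds = {"en": np.array([[0.1, 0.9], [0.8, 0.2]])}  # two documents, two classes
    doc_ids = {"en": ["doc-001", "doc-002"]}

    save_inference_preds(
        preds=preds,
        dataset_name="toy",
        doc_ids=doc_ids,
        targets=None,
        category_mapper=None,  # point at a mapping csv to also get string labels
    )
    # writes results/inference-preds/toy.csv with columns:
    # doc_id, document_language, gfun_prediction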