fixed inference script

Andrea Pedrotti 2024-03-12 11:38:12 +01:00
parent 4615bc3857
commit 35cc32e541
6 changed files with 96 additions and 29 deletions

View File

@@ -10,16 +10,16 @@ from tqdm import tqdm
 import pandas as pd
 from sklearn.model_selection import train_test_split
-from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
-from dataManager.glamiDataset import get_dataframe
-from dataManager.multilingualDataset import MultilingualDataset
+# from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
+# from dataManager.glamiDataset import get_dataframe
+# from dataManager.multilingualDataset import MultilingualDataset
 
 class SimpleGfunDataset:
     def __init__(
         self,
         dataset_name=None,
-        datadir="~/datasets/rai/csv/",
+        datapath="~/datasets/rai/csv/test.csv",
         textual=True,
         visual=False,
         multilabel=False,
@@ -30,7 +30,7 @@ class SimpleGfunDataset:
         labels=None
     ):
         self.name = dataset_name
-        self.datadir = os.path.expanduser(datadir)
+        self.datadir = os.path.expanduser(datapath)
         self.textual = textual
         self.visual = visual
         self.multilabel = multilabel
@@ -41,7 +41,7 @@
         self.print_stats()
 
     def print_stats(self):
-        print(f"Dataset statistics {'-' * 15}")
+        print(f"Dataset statistics\n{'-' * 15}")
         tr = 0
         va = 0
         te = 0
@@ -53,11 +53,12 @@
            va += n_va
            te += n_te
            print(f"{lang} - tr: {n_tr} - va: {n_va} - te: {n_te}")
-        print(f"Total {'-' * 15}")
+        print(f"{'-' * 15}\nTotal\n{'-' * 15}")
        print(f"tr: {tr} - va: {va} - te: {te}")
 
    def load_csv_inference(self):
-        test = pd.read_csv(os.path.join(self.datadir, "test.small.csv" if not self.reduced else "test.small.csv"))
+        # test = pd.read_csv(os.path.join(self.datadir, "test.small.csv" if not self.reduced else "test.small.csv"))
+        test = pd.read_csv(self.datadir)
        self._set_labels(test)
        self._set_langs(train=None, test=test)
        self.train_data = None
@@ -137,7 +138,10 @@ class SimpleGfunDataset:
    def test(self, mask_number=False, target_as_csr=False):
        apply_mask = lambda x: _mask_numbers(x) if _mask_numbers else x
        lXte = {
-            lang: {"text": apply_mask(self.test_data[lang].text.tolist())}
+            lang: {
+                "text": apply_mask(self.test_data[lang].text.tolist()),
+                "id": self.test_data[lang].id.tolist()
+            }
            for lang in self.te_langs
        }
        if not self.only_inference:
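A minimal usage sketch of the updated loader (illustrative, not part of the diff): the new datapath argument points at a single csv file rather than a directory, and the input file is assumed to provide at least the text and id columns used above, with the remaining constructor defaults unchanged.

from dataManager.gFunDataset import SimpleGfunDataset

# Hypothetical input file: a csv with the documents to classify.
dataset = SimpleGfunDataset(
    dataset_name="my-inference-set",          # illustrative name
    datapath="~/datasets/rai/csv/test.csv",
    textual=True,
    visual=False,
    multilabel=False,
    only_inference=True,
    labels=28,
)
lXte, _ = dataset.test()
# lXte is a per-language dict: {"it": {"text": [...], "id": [...]}, ...}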

View File

@@ -17,7 +17,7 @@ def load_from_pickle(path, dataset_name, nrows):
 def get_dataset(datasetp_path, args):
     dataset = SimpleGfunDataset(
         dataset_name="rai",
-        datadir=datasetp_path,
+        datapath=datasetp_path,
         textual=True,
         visual=False,
         multilabel=False,

View File

@@ -240,8 +240,9 @@ class GeneralizedFunnelling:
         return self
 
-    def transform(self, lX):
+    def transform(self, lX, output_ids=False):
         projections = []
+        l_ids = {}
         for vgf in self.first_tier_learners:
             l_posteriors = vgf.transform(lX)
             projections.append(l_posteriors)
@@ -250,7 +251,11 @@
         if self.clf_type == "singlelabel":
             for lang, preds in l_out.items():
                 l_out[lang] = predict(preds, clf_type=self.clf_type)
-        return l_out
+                l_ids[lang] = lX[lang]["id"]
+        if output_ids:
+            return l_out, l_ids
+        else:
+            return l_out
 
     def fit_transform(self, lX, lY):
         return self.fit(lX, lY).transform(lX)
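A short sketch of the two call styles enabled by the new output_ids flag (illustrative, not part of the diff; gfun stands for an already trained or loaded GeneralizedFunnelling instance, and lX for the per-language dict carrying "text" and "id" produced by SimpleGfunDataset.test()).

# Backwards compatible: the default still returns only the predictions.
l_preds = gfun.transform(lX)

# With output_ids=True the per-language document ids are returned as well,
# aligned with the prediction arrays. Note that ids are collected inside the
# singlelabel branch, so this path assumes clf_type == "singlelabel".
l_preds, l_ids = gfun.transform(lX, output_ids=True)
for lang in l_preds:
    assert len(l_ids[lang]) == len(l_preds[lang])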

View File

@@ -104,7 +104,7 @@ class TfidfVectorizerMultilingual:
     def __init__(self, **kwargs):
         self.kwargs = kwargs
 
-    def update(self, X, lang):
+    def update_vectorizer(self, X, lang):
         self.langs.append(lang)
         self.vectorizer[lang] = TfidfVectorizer(**self.kwargs).fit(X["text"])
         return self
@@ -121,8 +121,9 @@ class TfidfVectorizerMultilingual:
         for in_l in in_langs:
             if in_l not in self.langs:
                 print(f"[NB: found unvectorized language! Updatding vectorizer for {in_l=}]")
-                self.update(X=lX[in_l], lang=in_l)
+                self.update_vectorizer(X=lX[in_l], lang=in_l)
-        return {l: self.vectorizer[l].transform(lX[l]["text"]) for l in self.langs} # TODO we can update the vectorizer with new languages here!
+        # return {l: self.vectorizer[l].transform(lX[l]["text"]) for l in self.langs} # TODO we can update the vectorizer with new languages here!
+        return {l: self.vectorizer[l].transform(lX[l]["text"]) for l in in_langs}
 
     def fit_transform(self, lX, ly=None):
         return self.fit(lX, ly).transform(lX)
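A toy sketch of the new zero-shot vectorization path (illustrative, not part of the diff; it assumes fit() populates self.langs and self.vectorizer per language, as update_vectorizer() does): transform() now returns matrices for exactly the incoming languages and fits a missing per-language vectorizer on the fly.

vect = TfidfVectorizerMultilingual(sublinear_tf=True)
vect.fit({"en": {"text": ["a cat", "a dog"]},
          "it": {"text": ["un gatto", "un cane"]}})

# "es" was never seen at fit time: update_vectorizer() is triggered and the
# result covers the languages passed in (in_langs) rather than self.langs.
lZ = vect.transform({"en": {"text": ["another cat"]},
                     "es": {"text": ["otro perro"]}})
assert set(lZ.keys()) == {"en", "es"}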

View File

@@ -55,9 +55,11 @@ class MultilingualGen(ViewGen):
         return self
 
     def transform(self, lX):
+        _langs = lX.keys()
         lX = self.vectorizer.transform(lX)
-        if self.langs != sorted(self.vectorizer.vectorizer.keys()):
-            # new_langs = set(self.vectorizer.vectorizer.keys()) - set(self.langs)
+        if _langs != sorted(self.vectorizer.vectorizer.keys()):
+            """Loading word-embeddings for unseen languages at training time (zero-shot scenario),
+            excluding (exclude=old_langs) already loaded matrices."""
             old_langs = self.langs
             self.langs = sorted(self.vectorizer.vectorizer.keys())
             new_load, _ = self._load_embeddings(embed_dir=self.embed_dir, cached=self.cached, exclude=old_langs)
@@ -66,14 +68,14 @@
         XdotMulti = Parallel(n_jobs=self.n_jobs)(
             delayed(XdotM)(lX[lang], self.multi_embeddings[lang], sif=self.sif)
-            for lang in self.langs
+            for lang in _langs
         )
-        lZ = {lang: XdotMulti[i] for i, lang in enumerate(self.langs)}
+        lZ = {lang: XdotMulti[i] for i, lang in enumerate(_langs)}
         lZ = _normalize(lZ, l2=True)
         if self.probabilistic and self.fitted:
             lZ = self.feature2posterior_projector.transform(lZ)
         return lZ
 
     def fit_transform(self, lX, lY):
         return self.fit(lX, lY).transform(lX)
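The same zero-shot path seen from the view generator (illustrative, not part of the diff; mg stands for a MultilingualGen instance already fitted on, say, "en" and "it", with its multilingual vectorizer and MUSE embeddings loaded).

# Documents in a language unseen at fit time: the vectorizer is updated, the
# missing MUSE matrix is loaded on the fly (exclude=old_langs skips the
# already-loaded ones), and projections are returned for the requested
# languages only.
lX_new = {"es": {"text": ["un ejemplo en español"]}}
lZ = mg.transform(lX_new)
assert set(lZ.keys()) == {"es"}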

View File

@@ -1,20 +1,21 @@
+import os
+import pandas as pd
+from pathlib import Path
+
 from dataManager.gFunDataset import SimpleGfunDataset
 from gfun.generalizedFunnelling import GeneralizedFunnelling
-from main import save_preds
-
-DATASET_NAME = "inference-dataset" # can be anything
 
 def main(args):
     dataset = SimpleGfunDataset(
-        dataset_name=DATASET_NAME,
-        datadir=args.datadir,
+        dataset_name=Path(args.datapath).stem,
+        datapath=args.datapath,
         multilabel=False,
         only_inference=True,
         labels=args.nlabels,
     )
     lX, _ = dataset.test()
-    print("Ok")
 
     gfun = GeneralizedFunnelling(
         dataset_name=dataset.get_name(),
@@ -29,17 +30,71 @@ def main(args):
         load_meta=True,
     )
-    predictions = gfun.transform(lX)
+    predictions, ids = gfun.transform(lX, output_ids=True)
 
-    save_preds(
+    save_inference_preds(
         preds=predictions,
+        doc_ids=ids,
+        dataset_name=dataset.get_name(),
         targets=None,
-        config=f"Inference dataset: {dataset.get_name()}"
+        category_mapper="models/category_mappers/rai-mapping.csv"
     )
+
+
+def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_mapper=None, output_dir="results/inference-preds"):
+    """
+    Parameters
+    ----------
+    preds : Dict[str: np.array]
+        Predictions produced by generalized-funnelling.
+    dataset_name: str
+        Dataset name used as the output file name. The file is stored in the directory given by the
+        `output_dir` argument, e.g. "<output_dir>/<dataset_name>.csv".
+    doc_ids: Dict[str: List[str]]
+        Dictionary storing the list of document ids per language (as defined in the input csv file).
+    targets: Dict[str: np.array]
+        If available, the true target labels are also written to the output file to ease performance
+        evaluation. (default = None)
+    category_mapper: Path
+        Path to the 'category_mapper' csv file storing the category names (str) for each target class
+        (integer). If not None, gFun predictions are also converted to and stored as category name
+        strings. (default = None)
+    output_dir: Path
+        Directory where the output csv file is stored. (default = results/inference-preds)
+    """
+    os.makedirs(output_dir, exist_ok=True)
+    df = pd.DataFrame()
+    langs = sorted(preds.keys())
+    _ids = []
+    _preds = []
+    _targets = []
+    _langs = []
+    for lang in langs:
+        _preds.extend(preds[lang].argmax(axis=1).tolist())
+        _langs.extend([lang for i in range(len(preds[lang]))])
+        _ids.extend(doc_ids[lang])
+    df["doc_id"] = _ids
+    df["document_language"] = _langs
+    df["gfun_prediction"] = _preds
+    if targets is not None:
+        for lang in langs:
+            _targets.extend(targets[lang].argmax(axis=1).tolist())
+        df["document_true_label"] = _targets
+    if category_mapper is not None:
+        mapper = pd.read_csv(category_mapper).to_dict()["category"]
+        df["gfun_string_prediction"] = [mapper[p] for p in _preds]
+    output_file = f"{output_dir}/{dataset_name}.csv"
+    print(f"Storing predictions in: {output_file}")
+    df.to_csv(output_file, index=False)
+
 if __name__ == "__main__":
     from argparse import ArgumentParser
 
     parser = ArgumentParser()
-    parser.add_argument("--datadir", type=str, default="~/datasets/rai/csv", help="directory should contain train.csv and test.csv. Train.csv is not required at inferece time!")
+    parser.add_argument("--datapath", type=str, default="~/datasets/rai/csv", help="Path to csv file containing the documents to be classified")
     parser.add_argument("--nlabels", type=int, default=28)
     parser.add_argument("--muse_dir", type=str, default="~/resources/muse_embeddings", help="Path to muse embeddings")
     parser.add_argument("--trained_gfun", type=str, default="rai_pmt_mean_231029", help="name of the trained gfun instance")