fixed inference script
This commit is contained in:
parent 4615bc3857
commit 35cc32e541
dataManager/gFunDataset.py

@@ -10,16 +10,16 @@ from tqdm import tqdm
 import pandas as pd
 from sklearn.model_selection import train_test_split
 
-from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
-from dataManager.glamiDataset import get_dataframe
-from dataManager.multilingualDataset import MultilingualDataset
+# from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
+# from dataManager.glamiDataset import get_dataframe
+# from dataManager.multilingualDataset import MultilingualDataset
 
 
 class SimpleGfunDataset:
     def __init__(
         self,
         dataset_name=None,
-        datadir="~/datasets/rai/csv/",
+        datapath="~/datasets/rai/csv/test.csv",
         textual=True,
         visual=False,
         multilabel=False,
@@ -30,7 +30,7 @@ class SimpleGfunDataset:
         labels=None
     ):
         self.name = dataset_name
-        self.datadir = os.path.expanduser(datadir)
+        self.datadir = os.path.expanduser(datapath)
         self.textual = textual
         self.visual = visual
         self.multilabel = multilabel
@@ -41,7 +41,7 @@ class SimpleGfunDataset:
         self.print_stats()
 
     def print_stats(self):
-        print(f"Dataset statistics {'-' * 15}")
+        print(f"Dataset statistics\n{'-' * 15}")
         tr = 0
         va = 0
         te = 0
@@ -53,11 +53,12 @@ class SimpleGfunDataset:
             va += n_va
             te += n_te
             print(f"{lang} - tr: {n_tr} - va: {n_va} - te: {n_te}")
-        print(f"Total {'-' * 15}")
+        print(f"{'-' * 15}\nTotal\n{'-' * 15}")
         print(f"tr: {tr} - va: {va} - te: {te}")
 
     def load_csv_inference(self):
-        test = pd.read_csv(os.path.join(self.datadir, "test.small.csv" if not self.reduced else "test.small.csv"))
+        # test = pd.read_csv(os.path.join(self.datadir, "test.small.csv" if not self.reduced else "test.small.csv"))
+        test = pd.read_csv(self.datadir)
        self._set_labels(test)
        self._set_langs(train=None, test=test)
        self.train_data = None
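With this change, `load_csv_inference` reads whatever csv file `datapath` points to instead of a hard-coded `test.small.csv`. A minimal sketch of a compatible input file, assuming the columns this class touches elsewhere in the diff (`id` and `text`, per `test()` below); the exact name of the language column consumed by `_set_langs` is not visible here:

```csv
id,lang,text
doc-001,en,first english document
doc-002,it,primo documento italiano
```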
@@ -137,7 +138,10 @@ class SimpleGfunDataset:
     def test(self, mask_number=False, target_as_csr=False):
         apply_mask = lambda x: _mask_numbers(x) if _mask_numbers else x
         lXte = {
-            lang: {"text": apply_mask(self.test_data[lang].text.tolist())}
+            lang: {
+                "text": apply_mask(self.test_data[lang].text.tolist()),
+                "id": self.test_data[lang].id.tolist()
+            }
             for lang in self.te_langs
         }
         if not self.only_inference:
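After this change `test()` carries the document ids alongside the (optionally number-masked) texts. A sketch of the structure downstream code now receives, with language codes and values purely illustrative:

```python
lXte = {
    "en": {"text": ["first english doc", "second english doc"],
           "id": ["doc-001", "doc-002"]},
    "it": {"text": ["primo documento italiano"],
           "id": ["doc-003"]},
}
```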
@@ -17,7 +17,7 @@ def load_from_pickle(path, dataset_name, nrows):
 def get_dataset(datasetp_path, args):
     dataset = SimpleGfunDataset(
         dataset_name="rai",
-        datadir=datasetp_path,
+        datapath=datasetp_path,
         textual=True,
         visual=False,
         multilabel=False,
gfun/generalizedFunnelling.py

@@ -240,8 +240,9 @@ class GeneralizedFunnelling:
 
         return self
 
-    def transform(self, lX):
+    def transform(self, lX, output_ids=False):
         projections = []
+        l_ids = {}
         for vgf in self.first_tier_learners:
             l_posteriors = vgf.transform(lX)
             projections.append(l_posteriors)
@@ -250,7 +251,11 @@ class GeneralizedFunnelling:
         if self.clf_type == "singlelabel":
             for lang, preds in l_out.items():
                 l_out[lang] = predict(preds, clf_type=self.clf_type)
-        return l_out
+                l_ids[lang] = lX[lang]["id"]
+        if output_ids:
+            return l_out, l_ids
+        else:
+            return l_out
 
     def fit_transform(self, lX, lY):
         return self.fit(lX, lY).transform(lX)
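Existing callers are unaffected since `output_ids` defaults to `False`; a sketch of both call styles (variable names assumed):

```python
l_out = gfun.transform(lX)                           # as before: {lang: predictions}
l_out, l_ids = gfun.transform(lX, output_ids=True)   # additionally {lang: [doc ids]}
for lang, preds in l_out.items():
    assert len(preds) == len(l_ids[lang])            # one id per predicted document
```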
@@ -104,7 +104,7 @@ class TfidfVectorizerMultilingual:
     def __init__(self, **kwargs):
         self.kwargs = kwargs
 
-    def update(self, X, lang):
+    def update_vectorizer(self, X, lang):
         self.langs.append(lang)
         self.vectorizer[lang] = TfidfVectorizer(**self.kwargs).fit(X["text"])
         return self
@@ -121,8 +121,9 @@ class TfidfVectorizerMultilingual:
         for in_l in in_langs:
             if in_l not in self.langs:
                 print(f"[NB: found unvectorized language! Updating vectorizer for {in_l=}]")
-                self.update(X=lX[in_l], lang=in_l)
-        return {l: self.vectorizer[l].transform(lX[l]["text"]) for l in self.langs}  # TODO we can update the vectorizer with new languages here!
+                self.update_vectorizer(X=lX[in_l], lang=in_l)
+        # return {l: self.vectorizer[l].transform(lX[l]["text"]) for l in self.langs}  # TODO we can update the vectorizer with new languages here!
+        return {l: self.vectorizer[l].transform(lX[l]["text"]) for l in in_langs}
 
     def fit_transform(self, lX, ly=None):
         return self.fit(lX, ly).transform(lX)
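The rename plus the new return expression let `transform` cover languages never seen during fitting, returning matrices only for the requested input languages. A minimal sketch, assuming `fit` accepts the same language-keyed dict used elsewhere and that constructor kwargs are forwarded to sklearn's `TfidfVectorizer`:

```python
vect = TfidfVectorizerMultilingual(sublinear_tf=True)
vect.fit({"en": {"text": ["a document", "another document"]}})

# "fr" was not fitted: transform() now calls update_vectorizer() on the fly.
lZ = vect.transform({"fr": {"text": ["un document"]}})
print(lZ.keys())  # dict_keys(['fr'])
```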
@@ -55,9 +55,11 @@ class MultilingualGen(ViewGen):
         return self
 
     def transform(self, lX):
+        _langs = lX.keys()
         lX = self.vectorizer.transform(lX)
-        if self.langs != sorted(self.vectorizer.vectorizer.keys()):
+        # new_langs = set(self.vectorizer.vectorizer.keys()) - set(self.langs)
+        if _langs != sorted(self.vectorizer.vectorizer.keys()):
             """Loading word-embeddings for unseen languages at training time (zero-shot scenario),
             excluding (exclude=old_langs) already loaded matrices."""
             old_langs = self.langs
             self.langs = sorted(self.vectorizer.vectorizer.keys())
             new_load, _ = self._load_embeddings(embed_dir=self.embed_dir, cached=self.cached, exclude=old_langs)
@@ -66,9 +68,9 @@ class MultilingualGen(ViewGen):
 
         XdotMulti = Parallel(n_jobs=self.n_jobs)(
             delayed(XdotM)(lX[lang], self.multi_embeddings[lang], sif=self.sif)
-            for lang in self.langs
+            for lang in _langs
         )
-        lZ = {lang: XdotMulti[i] for i, lang in enumerate(self.langs)}
+        lZ = {lang: XdotMulti[i] for i, lang in enumerate(_langs)}
         lZ = _normalize(lZ, l2=True)
         if self.probabilistic and self.fitted:
             lZ = self.feature2posterior_projector.transform(lZ)
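Keying the projection loop on the incoming `_langs` instead of the fitted `self.langs` lets a zero-shot language flow through the whole view. A sketch of such a call, with the instance name and data assumed, and assuming the matching word-embedding matrix is available on disk:

```python
# "de" was absent at fit time: the vectorizer is updated on the fly and the
# corresponding embedding matrix is loaded before projecting.
lX = {"de": {"text": ["ein kurzes deutsches Dokument"], "id": ["doc-001"]}}
lZ = multilingual_vgf.transform(lX)  # multilingual_vgf: a fitted MultilingualGen
```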
infer.py (73 changed lines)
@@ -1,20 +1,21 @@
 import os
 import pandas as pd
+from pathlib import Path
 
 from dataManager.gFunDataset import SimpleGfunDataset
 from gfun.generalizedFunnelling import GeneralizedFunnelling
 from main import save_preds
 
-DATASET_NAME = "inference-dataset" # can be anything
 
 def main(args):
     dataset = SimpleGfunDataset(
-        dataset_name=DATASET_NAME,
-        datadir=args.datadir,
+        dataset_name=Path(args.datapath).stem,
+        datapath=args.datapath,
         multilabel=False,
         only_inference=True,
         labels=args.nlabels,
     )
 
     lX, _ = dataset.test()
     print("Ok")
 
     gfun = GeneralizedFunnelling(
         dataset_name=dataset.get_name(),
@@ -29,17 +30,71 @@ def main(args):
         load_meta=True,
     )
 
-    predictions = gfun.transform(lX)
-    save_preds(
+    predictions, ids = gfun.transform(lX, output_ids=True)
+    save_inference_preds(
         preds=predictions,
+        doc_ids=ids,
         dataset_name=dataset.get_name(),
         targets=None,
-        config=f"Inference dataset: {dataset.get_name()}"
+        category_mapper="models/category_mappers/rai-mapping.csv"
     )
 
+def save_inference_preds(preds, dataset_name, doc_ids, targets=None, category_mapper=None, output_dir="results/inference-preds"):
+    """
+    Parameters
+    ----------
+    preds : Dict[str, np.array]
+        Predictions produced by generalized-funnelling.
+    dataset_name : str
+        Dataset name used as the output file name. The file is stored in the directory defined by the
+        `output_dir` argument, e.g. "<output_dir>/<dataset_name>.csv".
+    doc_ids : Dict[str, List[str]]
+        Dictionary storing the list of document ids (as defined in the csv input file) per language.
+    targets : Dict[str, np.array]
+        If available, target true labels will be written to the output file to ease performance evaluation.
+        (default = None)
+    category_mapper : Path
+        Path to the 'category_mapper' csv file storing the category name (str) for each target class (integer).
+        If not None, gFun predictions will be converted and stored as target string classes.
+        (default = None)
+    output_dir : Path
+        Directory where the output csv file is stored.
+        (default = "results/inference-preds")
+    """
+    os.makedirs(output_dir, exist_ok=True)
+    df = pd.DataFrame()
+    langs = sorted(preds.keys())
+    _ids = []
+    _preds = []
+    _targets = []
+    _langs = []
+    for lang in langs:
+        _preds.extend(preds[lang].argmax(axis=1).tolist())
+        _langs.extend([lang for i in range(len(preds[lang]))])
+        _ids.extend(doc_ids[lang])
+
+    df["doc_id"] = _ids
+    df["document_language"] = _langs
+    df["gfun_prediction"] = _preds
+
+    if targets is not None:
+        for lang in langs:
+            _targets.extend(targets[lang].argmax(axis=1).tolist())
+        df["document_true_label"] = _targets
+
+    if category_mapper is not None:
+        mapper = pd.read_csv(category_mapper).to_dict()["category"]
+        df["gfun_string_prediction"] = [mapper[p] for p in _preds]
+
+    output_file = f"{output_dir}/{dataset_name}.csv"
+    print(f"Storing predictions in: {output_file}")
+    df.to_csv(output_file, index=False)
+
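Since the mapper is read with `pd.read_csv(category_mapper).to_dict()["category"]`, any csv with a `category` column whose row order follows the integer class ids works; a sketch with purely illustrative class names, not the actual contents of rai-mapping.csv:

```csv
category
news
sports
culture
```

Row 0 maps class 0 to "news", row 1 maps class 1 to "sports", and so on.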
 if __name__ == "__main__":
     from argparse import ArgumentParser
     parser = ArgumentParser()
-    parser.add_argument("--datadir", type=str, default="~/datasets/rai/csv", help="directory should contain train.csv and test.csv. Train.csv is not required at inferece time!")
+    parser.add_argument("--datapath", type=str, default="~/datasets/rai/csv", help="Path to csv file containing the documents to be classified")
     parser.add_argument("--nlabels", type=int, default=28)
     parser.add_argument("--muse_dir", type=str, default="~/resources/muse_embeddings", help="Path to muse embeddings")
     parser.add_argument("--trained_gfun", type=str, default="rai_pmt_mean_231029", help="name of the trained gfun instance")
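A minimal invocation of the reworked script (paths illustrative; note that `--datapath` now points at a single csv file rather than a directory):

```sh
python infer.py --datapath ~/datasets/rai/csv/test.csv --trained_gfun rai_pmt_mean_231029
```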