From 77227bbe1392f5436901d95ed8805c5d213224d9 Mon Sep 17 00:00:00 2001 From: Andrea Pedrotti Date: Mon, 6 Mar 2023 11:59:47 +0100 Subject: [PATCH] support for binary dataset; CLS dataset; updated gitignore --- .gitignore | 3 ++- dataManager/clsDataset.py | 11 ++++++++--- dataManager/gFunDataset.py | 2 +- gfun/generalizedFunnelling.py | 4 ++++ main.py | 3 ++- 5 files changed, 17 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 25278f6..59aee03 100644 --- a/.gitignore +++ b/.gitignore @@ -180,4 +180,5 @@ amazon_cateogories.bu.txt models/* scripts/ logger/* -explore_data.ipynb \ No newline at end of file +explore_data.ipynb +run.sh \ No newline at end of file diff --git a/dataManager/clsDataset.py b/dataManager/clsDataset.py index 517b395..e81d126 100644 --- a/dataManager/clsDataset.py +++ b/dataManager/clsDataset.py @@ -53,7 +53,8 @@ def process_data(line): # TODO: we are adding a space after each pucntuation mark (e.g., ". es ich das , langweilig lustig") result = re.sub(regex, subst, line, 0, re.MULTILINE) text, label = result.split("#label#:") - label = 0 if label == "negative" else 1 + # label = 0 if label == "negative" else 1 + label = [1, 0] if label == "negative" else [0, 1] return text, label @@ -64,9 +65,13 @@ if __name__ == "__main__": for lang in LANGS: # TODO: just using book domain atm Xtr = [text[0] for text in data[lang]["books"]["train"]] - Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1) + # Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1) + Ytr = np.vstack([text[1] for text in data[lang]["books"]["train"]]) + Xte = [text[0] for text in data[lang]["books"]["test"]] - Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1) + # Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1) + Yte = np.vstack([text[1] for text in data[lang]["books"]["test"]]) + multilingualDataset.add( lang=lang, Xtr=Xtr, diff --git a/dataManager/gFunDataset.py b/dataManager/gFunDataset.py index 0bbf4c9..a0040ec 100644 --- a/dataManager/gFunDataset.py +++ b/dataManager/gFunDataset.py @@ -200,7 +200,7 @@ class gFunDataset: return self.data_langs def num_labels(self): - if self.dataset_name not in ["rcv1-2", "jrc"]: + if self.dataset_name not in ["rcv1-2", "jrc", "cls"]: return len(self.labels) else: return self.labels diff --git a/gfun/generalizedFunnelling.py b/gfun/generalizedFunnelling.py index 3600999..52f57a3 100644 --- a/gfun/generalizedFunnelling.py +++ b/gfun/generalizedFunnelling.py @@ -251,6 +251,10 @@ class GeneralizedFunnelling: projections.append(l_posteriors) agg = self.aggregate(projections) l_out = self.metaclassifier.predict_proba(agg) + # converting to binary predictions + # if self.dataset_name in ["cls"]: # TODO: better way to do this + # for lang, preds in l_out.items(): + # l_out[lang] = np.expand_dims(np.argmax(preds, axis=1), axis=1) return l_out def fit_transform(self, lX, lY): diff --git a/main.py b/main.py index a21a364..7198356 100644 --- a/main.py +++ b/main.py @@ -13,7 +13,8 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling """ TODO: - - [!] add support for Binary Datasets (e.g. cls) + - [!] add support for Binary Datasets (e.g. cls) - NB: CLS dataset is loading only "books" domain data + - [!] documents should be trimmed to the same length (?) - [!] logging - add documentations sphinx - [!] zero-shot setup