support for binary dataset; CLS dataset; updated gitignore

This commit is contained in:
Andrea Pedrotti 2023-03-06 11:59:47 +01:00
parent f9d4e50297
commit 77227bbe13
5 changed files with 17 additions and 6 deletions

1
.gitignore vendored
View File

@@ -181,3 +181,4 @@ models/*
scripts/ scripts/
logger/* logger/*
explore_data.ipynb explore_data.ipynb
run.sh

View File

@@ -53,7 +53,8 @@ def process_data(line):
# TODO: we are adding a space after each punctuation mark (e.g., ". es ich das , langweilig lustig") # TODO: we are adding a space after each punctuation mark (e.g., ". es ich das , langweilig lustig")
result = re.sub(regex, subst, line, 0, re.MULTILINE) result = re.sub(regex, subst, line, 0, re.MULTILINE)
text, label = result.split("#label#:") text, label = result.split("#label#:")
label = 0 if label == "negative" else 1 # label = 0 if label == "negative" else 1
label = [1, 0] if label == "negative" else [0, 1]
return text, label return text, label
@@ -64,9 +65,13 @@ if __name__ == "__main__":
for lang in LANGS: for lang in LANGS:
# TODO: just using book domain atm # TODO: just using book domain atm
Xtr = [text[0] for text in data[lang]["books"]["train"]] Xtr = [text[0] for text in data[lang]["books"]["train"]]
Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1) # Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1)
Ytr = np.vstack([text[1] for text in data[lang]["books"]["train"]])
Xte = [text[0] for text in data[lang]["books"]["test"]] Xte = [text[0] for text in data[lang]["books"]["test"]]
Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1) # Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1)
Yte = np.vstack([text[1] for text in data[lang]["books"]["test"]])
multilingualDataset.add( multilingualDataset.add(
lang=lang, lang=lang,
Xtr=Xtr, Xtr=Xtr,

View File

@@ -200,7 +200,7 @@ class gFunDataset:
return self.data_langs return self.data_langs
def num_labels(self): def num_labels(self):
if self.dataset_name not in ["rcv1-2", "jrc"]: if self.dataset_name not in ["rcv1-2", "jrc", "cls"]:
return len(self.labels) return len(self.labels)
else: else:
return self.labels return self.labels

View File

@@ -251,6 +251,10 @@ class GeneralizedFunnelling:
projections.append(l_posteriors) projections.append(l_posteriors)
agg = self.aggregate(projections) agg = self.aggregate(projections)
l_out = self.metaclassifier.predict_proba(agg) l_out = self.metaclassifier.predict_proba(agg)
# converting to binary predictions
# if self.dataset_name in ["cls"]: # TODO: better way to do this
# for lang, preds in l_out.items():
# l_out[lang] = np.expand_dims(np.argmax(preds, axis=1), axis=1)
return l_out return l_out
def fit_transform(self, lX, lY): def fit_transform(self, lX, lY):

View File

@@ -13,7 +13,8 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
""" """
TODO: TODO:
- [!] add support for Binary Datasets (e.g. cls) - [!] add support for Binary Datasets (e.g. cls) - NB: CLS dataset is loading only "books" domain data
- [!] documents should be trimmed to the same length (?)
- [!] logging - [!] logging
- add documentation (sphinx) - add documentation (sphinx)
- [!] zero-shot setup - [!] zero-shot setup