support for binary dataset; CLS dataset; updated gitignore

This commit is contained in:
Andrea Pedrotti 2023-03-06 11:59:47 +01:00
parent f9d4e50297
commit 77227bbe13
5 changed files with 17 additions and 6 deletions

1
.gitignore vendored
View File

@ -181,3 +181,4 @@ models/*
scripts/
logger/*
explore_data.ipynb
run.sh

View File

@ -53,7 +53,8 @@ def process_data(line):
# TODO: we are adding a space after each pucntuation mark (e.g., ". es ich das , langweilig lustig")
result = re.sub(regex, subst, line, 0, re.MULTILINE)
text, label = result.split("#label#:")
label = 0 if label == "negative" else 1
# label = 0 if label == "negative" else 1
label = [1, 0] if label == "negative" else [0, 1]
return text, label
@ -64,9 +65,13 @@ if __name__ == "__main__":
for lang in LANGS:
# TODO: just using book domain atm
Xtr = [text[0] for text in data[lang]["books"]["train"]]
Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1)
# Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1)
Ytr = np.vstack([text[1] for text in data[lang]["books"]["train"]])
Xte = [text[0] for text in data[lang]["books"]["test"]]
Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1)
# Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1)
Yte = np.vstack([text[1] for text in data[lang]["books"]["test"]])
multilingualDataset.add(
lang=lang,
Xtr=Xtr,

View File

@ -200,7 +200,7 @@ class gFunDataset:
return self.data_langs
def num_labels(self):
if self.dataset_name not in ["rcv1-2", "jrc"]:
if self.dataset_name not in ["rcv1-2", "jrc", "cls"]:
return len(self.labels)
else:
return self.labels

View File

@ -251,6 +251,10 @@ class GeneralizedFunnelling:
projections.append(l_posteriors)
agg = self.aggregate(projections)
l_out = self.metaclassifier.predict_proba(agg)
# converting to binary predictions
# if self.dataset_name in ["cls"]: # TODO: better way to do this
# for lang, preds in l_out.items():
# l_out[lang] = np.expand_dims(np.argmax(preds, axis=1), axis=1)
return l_out
def fit_transform(self, lX, lY):

View File

@ -13,7 +13,8 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
"""
TODO:
- [!] add support for Binary Datasets (e.g. cls)
- [!] add support for Binary Datasets (e.g. cls) - NB: CLS dataset is loading only "books" domain data
- [!] documents should be trimmed to the same length (?)
- [!] logging
- add documentations sphinx
- [!] zero-shot setup