support for binary dataset; CLS dataset; updated gitignore
This commit is contained in:
parent
f9d4e50297
commit
77227bbe13
|
@ -181,3 +181,4 @@ models/*
|
||||||
scripts/
|
scripts/
|
||||||
logger/*
|
logger/*
|
||||||
explore_data.ipynb
|
explore_data.ipynb
|
||||||
|
run.sh
|
|
@ -53,7 +53,8 @@ def process_data(line):
|
||||||
# TODO: we are adding a space after each pucntuation mark (e.g., ". es ich das , langweilig lustig")
|
# TODO: we are adding a space after each pucntuation mark (e.g., ". es ich das , langweilig lustig")
|
||||||
result = re.sub(regex, subst, line, 0, re.MULTILINE)
|
result = re.sub(regex, subst, line, 0, re.MULTILINE)
|
||||||
text, label = result.split("#label#:")
|
text, label = result.split("#label#:")
|
||||||
label = 0 if label == "negative" else 1
|
# label = 0 if label == "negative" else 1
|
||||||
|
label = [1, 0] if label == "negative" else [0, 1]
|
||||||
return text, label
|
return text, label
|
||||||
|
|
||||||
|
|
||||||
|
@ -64,9 +65,13 @@ if __name__ == "__main__":
|
||||||
for lang in LANGS:
|
for lang in LANGS:
|
||||||
# TODO: just using book domain atm
|
# TODO: just using book domain atm
|
||||||
Xtr = [text[0] for text in data[lang]["books"]["train"]]
|
Xtr = [text[0] for text in data[lang]["books"]["train"]]
|
||||||
Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1)
|
# Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1)
|
||||||
|
Ytr = np.vstack([text[1] for text in data[lang]["books"]["train"]])
|
||||||
|
|
||||||
Xte = [text[0] for text in data[lang]["books"]["test"]]
|
Xte = [text[0] for text in data[lang]["books"]["test"]]
|
||||||
Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1)
|
# Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1)
|
||||||
|
Yte = np.vstack([text[1] for text in data[lang]["books"]["test"]])
|
||||||
|
|
||||||
multilingualDataset.add(
|
multilingualDataset.add(
|
||||||
lang=lang,
|
lang=lang,
|
||||||
Xtr=Xtr,
|
Xtr=Xtr,
|
||||||
|
|
|
@ -200,7 +200,7 @@ class gFunDataset:
|
||||||
return self.data_langs
|
return self.data_langs
|
||||||
|
|
||||||
def num_labels(self):
|
def num_labels(self):
|
||||||
if self.dataset_name not in ["rcv1-2", "jrc"]:
|
if self.dataset_name not in ["rcv1-2", "jrc", "cls"]:
|
||||||
return len(self.labels)
|
return len(self.labels)
|
||||||
else:
|
else:
|
||||||
return self.labels
|
return self.labels
|
||||||
|
|
|
@ -251,6 +251,10 @@ class GeneralizedFunnelling:
|
||||||
projections.append(l_posteriors)
|
projections.append(l_posteriors)
|
||||||
agg = self.aggregate(projections)
|
agg = self.aggregate(projections)
|
||||||
l_out = self.metaclassifier.predict_proba(agg)
|
l_out = self.metaclassifier.predict_proba(agg)
|
||||||
|
# converting to binary predictions
|
||||||
|
# if self.dataset_name in ["cls"]: # TODO: better way to do this
|
||||||
|
# for lang, preds in l_out.items():
|
||||||
|
# l_out[lang] = np.expand_dims(np.argmax(preds, axis=1), axis=1)
|
||||||
return l_out
|
return l_out
|
||||||
|
|
||||||
def fit_transform(self, lX, lY):
|
def fit_transform(self, lX, lY):
|
||||||
|
|
3
main.py
3
main.py
|
@ -13,7 +13,8 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
|
||||||
|
|
||||||
"""
|
"""
|
||||||
TODO:
|
TODO:
|
||||||
- [!] add support for Binary Datasets (e.g. cls)
|
- [!] add support for Binary Datasets (e.g. cls) - NB: CLS dataset is loading only "books" domain data
|
||||||
|
- [!] documents should be trimmed to the same length (?)
|
||||||
- [!] logging
|
- [!] logging
|
||||||
- add documentations sphinx
|
- add documentations sphinx
|
||||||
- [!] zero-shot setup
|
- [!] zero-shot setup
|
||||||
|
|
Loading…
Reference in New Issue