support for binary dataset; CLS dataset; updated gitignore
This commit is contained in:
parent
f9d4e50297
commit
77227bbe13
|
@ -180,4 +180,5 @@ amazon_cateogories.bu.txt
|
|||
models/*
|
||||
scripts/
|
||||
logger/*
|
||||
explore_data.ipynb
|
||||
explore_data.ipynb
|
||||
run.sh
|
|
@ -53,7 +53,8 @@ def process_data(line):
|
|||
# TODO: we are adding a space after each punctuation mark (e.g., ". es ich das , langweilig lustig")
|
||||
result = re.sub(regex, subst, line, 0, re.MULTILINE)
|
||||
text, label = result.split("#label#:")
|
||||
label = 0 if label == "negative" else 1
|
||||
# label = 0 if label == "negative" else 1
|
||||
label = [1, 0] if label == "negative" else [0, 1]
|
||||
return text, label
|
||||
|
||||
|
||||
|
@ -64,9 +65,13 @@ if __name__ == "__main__":
|
|||
for lang in LANGS:
|
||||
# TODO: just using book domain atm
|
||||
Xtr = [text[0] for text in data[lang]["books"]["train"]]
|
||||
Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1)
|
||||
# Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1)
|
||||
Ytr = np.vstack([text[1] for text in data[lang]["books"]["train"]])
|
||||
|
||||
Xte = [text[0] for text in data[lang]["books"]["test"]]
|
||||
Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1)
|
||||
# Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1)
|
||||
Yte = np.vstack([text[1] for text in data[lang]["books"]["test"]])
|
||||
|
||||
multilingualDataset.add(
|
||||
lang=lang,
|
||||
Xtr=Xtr,
|
||||
|
|
|
@ -200,7 +200,7 @@ class gFunDataset:
|
|||
return self.data_langs
|
||||
|
||||
def num_labels(self):
|
||||
if self.dataset_name not in ["rcv1-2", "jrc"]:
|
||||
if self.dataset_name not in ["rcv1-2", "jrc", "cls"]:
|
||||
return len(self.labels)
|
||||
else:
|
||||
return self.labels
|
||||
|
|
|
@ -251,6 +251,10 @@ class GeneralizedFunnelling:
|
|||
projections.append(l_posteriors)
|
||||
agg = self.aggregate(projections)
|
||||
l_out = self.metaclassifier.predict_proba(agg)
|
||||
# converting to binary predictions
|
||||
# if self.dataset_name in ["cls"]: # TODO: better way to do this
|
||||
# for lang, preds in l_out.items():
|
||||
# l_out[lang] = np.expand_dims(np.argmax(preds, axis=1), axis=1)
|
||||
return l_out
|
||||
|
||||
def fit_transform(self, lX, lY):
|
||||
|
|
3
main.py
3
main.py
|
@ -13,7 +13,8 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
|
|||
|
||||
"""
|
||||
TODO:
|
||||
- [!] add support for Binary Datasets (e.g. cls)
|
||||
- [!] add support for Binary Datasets (e.g. cls) - NB: CLS dataset is loading only "books" domain data
|
||||
- [!] documents should be trimmed to the same length (?)
|
||||
- [!] logging
|
||||
- add Sphinx documentation
|
||||
- [!] zero-shot setup
|
||||
|
|
Loading…
Reference in New Issue