handling new data
This commit is contained in:
parent
1a1c48e136
commit
86fbd90bd4
|
|
@ -27,8 +27,8 @@ class gFunDataset:
|
||||||
self._load_dataset()
|
self._load_dataset()
|
||||||
|
|
||||||
def get_label_binarizer(self, labels):
|
def get_label_binarizer(self, labels):
|
||||||
if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
|
if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
|
||||||
mlb = "Labels are already binarized for rcv1-2 dataset"
|
mlb = f"Labels are already binarized for {self.dataset_name} dataset"
|
||||||
elif self.is_multilabel:
|
elif self.is_multilabel:
|
||||||
mlb = MultiLabelBinarizer()
|
mlb = MultiLabelBinarizer()
|
||||||
mlb.fit([labels])
|
mlb.fit([labels])
|
||||||
|
|
@ -85,8 +85,16 @@ class gFunDataset:
|
||||||
self.dataset_name, self.dataset_dir, self.nrows
|
self.dataset_name, self.dataset_dir, self.nrows
|
||||||
)
|
)
|
||||||
self.mlb = self.get_label_binarizer(self.labels)
|
self.mlb = self.get_label_binarizer(self.labels)
|
||||||
self.show_dimension()
|
|
||||||
|
|
||||||
|
elif "rai" in self.dataset_dir.lower():
|
||||||
|
print(f"- Loading RAI-CORPUS dataset from {self.dataset_dir}")
|
||||||
|
self.dataset_name = "rai"
|
||||||
|
self.dataset, self.labels, self.data_langs = self._load_multilingual(dataset_name=self.dataset_name,
|
||||||
|
dataset_dir="~/datasets/rai/csv/train-split-rai.csv",
|
||||||
|
nrows=self.nrows)
|
||||||
|
self.mlb = self.get_label_binarizer(self.labels)
|
||||||
|
|
||||||
|
self.show_dimension()
|
||||||
return
|
return
|
||||||
|
|
||||||
def show_dimension(self):
|
def show_dimension(self):
|
||||||
|
|
@ -95,12 +103,17 @@ class gFunDataset:
|
||||||
print(
|
print(
|
||||||
f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}"
|
f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}"
|
||||||
)
|
)
|
||||||
if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
|
if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
|
||||||
print(f"-- Labels: {self.labels}")
|
print(f"-- Labels: {self.labels}")
|
||||||
else:
|
else:
|
||||||
print(f"-- Labels: {len(self.labels)}")
|
print(f"-- Labels: {len(self.labels)}")
|
||||||
|
|
||||||
def _load_multilingual(self, dataset_name, dataset_dir, nrows):
|
def _load_multilingual(self, dataset_name, dataset_dir, nrows):
|
||||||
|
if "csv" in dataset_dir:
|
||||||
|
old_dataset = MultilingualDataset(dataset_name=dataset_name).from_csv(
|
||||||
|
path_tr="~/datasets/rai/csv/train-rai-multilingual-2000.csv",
|
||||||
|
path_te="~/datasets/rai/csv/test-rai-multilingual-2000.csv")
|
||||||
|
else:
|
||||||
old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir)
|
old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir)
|
||||||
if nrows is not None:
|
if nrows is not None:
|
||||||
if dataset_name == "cls":
|
if dataset_name == "cls":
|
||||||
|
|
@ -154,7 +167,7 @@ class gFunDataset:
|
||||||
return dataset, labels, data_langs
|
return dataset, labels, data_langs
|
||||||
|
|
||||||
def binarize_labels(self, labels):
|
def binarize_labels(self, labels):
|
||||||
if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
|
if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
|
||||||
# labels are already binarized for rcv1-2 dataset
|
# labels are already binarized for rcv1-2 dataset
|
||||||
return labels
|
return labels
|
||||||
if hasattr(self, "mlb"):
|
if hasattr(self, "mlb"):
|
||||||
|
|
@ -192,7 +205,7 @@ class gFunDataset:
|
||||||
return self.data_langs
|
return self.data_langs
|
||||||
|
|
||||||
def num_labels(self):
|
def num_labels(self):
|
||||||
if self.dataset_name not in ["rcv1-2", "jrc", "cls"]:
|
if self.dataset_name not in ["rcv1-2", "jrc", "cls", "rai"]:
|
||||||
return len(self.labels)
|
return len(self.labels)
|
||||||
else:
|
else:
|
||||||
return self.labels
|
return self.labels
|
||||||
|
|
@ -219,8 +232,6 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
print("Hello gFunDataset")
|
print("Hello gFunDataset")
|
||||||
dataset = gFunDataset(
|
dataset = gFunDataset(
|
||||||
# dataset_dir=GLAMI_DATAPATH,
|
|
||||||
# dataset_dir=RCV_DATAPATH,
|
|
||||||
dataset_dir=JRC_DATAPATH,
|
dataset_dir=JRC_DATAPATH,
|
||||||
data_langs=None,
|
data_langs=None,
|
||||||
is_textual=True,
|
is_textual=True,
|
||||||
|
|
|
||||||
|
|
@ -222,6 +222,41 @@ class MultilingualDataset:
|
||||||
new_data.append((docs[:maxn], labels[:maxn], None))
|
new_data.append((docs[:maxn], labels[:maxn], None))
|
||||||
return new_data
|
return new_data
|
||||||
|
|
||||||
|
def from_csv(self, path_tr, path_te):
|
||||||
|
import pandas as pd
|
||||||
|
from os.path import expanduser
|
||||||
|
train = pd.read_csv(expanduser(path_tr))
|
||||||
|
test = pd.read_csv(expanduser(path_te))
|
||||||
|
all_labels = set(train.label.to_list()).union(set(test.label.to_list()))
|
||||||
|
for lang in train.lang.unique():
|
||||||
|
tr_datalang = train.loc[train["lang"] == lang]
|
||||||
|
Xtr = tr_datalang.text.to_list()
|
||||||
|
tr_labels = tr_datalang.label.to_list()
|
||||||
|
# Ytr = np.zeros((len(Xtr), len(all_labels)), dtype=int)
|
||||||
|
Ytr = np.zeros((len(Xtr), 28), dtype=int)
|
||||||
|
for j, i in enumerate(tr_labels):
|
||||||
|
Ytr[j, i] = 1
|
||||||
|
tr_ids = tr_datalang.id.to_list()
|
||||||
|
te_datalang = test.loc[test["lang"] == lang]
|
||||||
|
Xte = te_datalang.text.to_list()
|
||||||
|
te_labels = te_datalang.label.to_list()
|
||||||
|
# Yte = np.zeros((len(Xte), len(all_labels)), dtype=int)
|
||||||
|
Yte = np.zeros((len(Xte), 28), dtype=int)
|
||||||
|
for j, i in enumerate(te_labels):
|
||||||
|
Yte[j, i] = 1
|
||||||
|
te_ids = te_datalang.id.to_list()
|
||||||
|
self.add(
|
||||||
|
lang=lang,
|
||||||
|
Xtr=Xtr,
|
||||||
|
Ytr=Ytr,
|
||||||
|
Xte=Xte,
|
||||||
|
Yte=Yte,
|
||||||
|
tr_ids=tr_ids,
|
||||||
|
te_ids=te_ids
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _mask_numbers(data):
|
def _mask_numbers(data):
|
||||||
mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b")
|
mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b")
|
||||||
|
|
@ -240,7 +275,6 @@ def _mask_numbers(data):
|
||||||
masked.append(text.replace(".", "").replace(",", "").strip())
|
masked.append(text.replace(".", "").replace(",", "").strip())
|
||||||
return masked
|
return masked
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
DATAPATH = expanduser(
|
DATAPATH = expanduser(
|
||||||
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
|
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ def get_dataset(dataset_name, args):
|
||||||
"glami",
|
"glami",
|
||||||
"cls",
|
"cls",
|
||||||
"webis",
|
"webis",
|
||||||
|
"rai",
|
||||||
], "dataset not supported"
|
], "dataset not supported"
|
||||||
|
|
||||||
RCV_DATAPATH = expanduser(
|
RCV_DATAPATH = expanduser(
|
||||||
|
|
@ -42,6 +43,8 @@ def get_dataset(dataset_name, args):
|
||||||
"~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
|
"~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
RAI_DATAPATH = expanduser("~/datasets/rai/rai_corpus.pkl")
|
||||||
|
|
||||||
if dataset_name == "multinews":
|
if dataset_name == "multinews":
|
||||||
# TODO: convert to gFunDataset
|
# TODO: convert to gFunDataset
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
@ -87,7 +90,6 @@ def get_dataset(dataset_name, args):
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset.save_as_pickle(GLAMI_DATAPATH)
|
dataset.save_as_pickle(GLAMI_DATAPATH)
|
||||||
|
|
||||||
elif dataset_name == "cls":
|
elif dataset_name == "cls":
|
||||||
dataset = gFunDataset(
|
dataset = gFunDataset(
|
||||||
dataset_dir=CLS_DATAPATH,
|
dataset_dir=CLS_DATAPATH,
|
||||||
|
|
@ -96,7 +98,6 @@ def get_dataset(dataset_name, args):
|
||||||
is_multilabel=False,
|
is_multilabel=False,
|
||||||
nrows=args.nrows,
|
nrows=args.nrows,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif dataset_name == "webis":
|
elif dataset_name == "webis":
|
||||||
dataset = gFunDataset(
|
dataset = gFunDataset(
|
||||||
dataset_dir=WEBIS_CLS,
|
dataset_dir=WEBIS_CLS,
|
||||||
|
|
@ -105,6 +106,14 @@ def get_dataset(dataset_name, args):
|
||||||
is_multilabel=False,
|
is_multilabel=False,
|
||||||
nrows=args.nrows,
|
nrows=args.nrows,
|
||||||
)
|
)
|
||||||
|
elif dataset_name == "rai":
|
||||||
|
dataset = gFunDataset(
|
||||||
|
dataset_dir=RAI_DATAPATH,
|
||||||
|
is_textual=True,
|
||||||
|
is_visual=False,
|
||||||
|
is_multilabel=False,
|
||||||
|
nrows=args.nrows
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
return dataset
|
return dataset
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue