handling new data

Andrea Pedrotti 2023-06-29 11:41:22 +02:00
parent 1a1c48e136
commit 86fbd90bd4
3 changed files with 66 additions and 12 deletions

View File

@@ -27,8 +27,8 @@ class gFunDataset:
         self._load_dataset()

     def get_label_binarizer(self, labels):
-        if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
-            mlb = "Labels are already binarized for rcv1-2 dataset"
+        if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
+            mlb = f"Labels are already binarized for {self.dataset_name} dataset"
         elif self.is_multilabel:
             mlb = MultiLabelBinarizer()
             mlb.fit([labels])
@@ -85,8 +85,16 @@ class gFunDataset:
                 self.dataset_name, self.dataset_dir, self.nrows
             )
             self.mlb = self.get_label_binarizer(self.labels)
+            self.show_dimension()
+        elif "rai" in self.dataset_dir.lower():
+            print(f"- Loading RAI-CORPUS dataset from {self.dataset_dir}")
+            self.dataset_name = "rai"
+            self.dataset, self.labels, self.data_langs = self._load_multilingual(
+                dataset_name=self.dataset_name,
+                dataset_dir="~/datasets/rai/csv/train-split-rai.csv",
+                nrows=self.nrows)
+            self.mlb = self.get_label_binarizer(self.labels)
+            self.show_dimension()
         return

     def show_dimension(self):
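
Note: the new branch dispatches on the directory name, so any `dataset_dir` containing "rai" selects the RAI loader, and the train-split path handed to `_load_multilingual` is hard-coded rather than derived from `dataset_dir`. A minimal usage sketch, assuming the constructor arguments shown in the `__main__` block and in the third file below (the import path is hypothetical):

    from gfun_dataset import gFunDataset  # hypothetical module name

    dataset = gFunDataset(
        dataset_dir="~/datasets/rai/",  # any path containing "rai" selects this branch
        data_langs=None,
        is_textual=True,
        is_visual=False,
        is_multilabel=False,
        nrows=None,
    )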
@@ -95,12 +103,17 @@ class gFunDataset:
             print(
                 f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}"
             )
-        if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
+        if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
             print(f"-- Labels: {self.labels}")
         else:
             print(f"-- Labels: {len(self.labels)}")

     def _load_multilingual(self, dataset_name, dataset_dir, nrows):
-        old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir)
+        if "csv" in dataset_dir:
+            old_dataset = MultilingualDataset(dataset_name=dataset_name).from_csv(
+                path_tr="~/datasets/rai/csv/train-rai-multilingual-2000.csv",
+                path_te="~/datasets/rai/csv/test-rai-multilingual-2000.csv")
+        else:
+            old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir)
         if nrows is not None:
             if dataset_name == "cls":
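
Note: `"csv" in dataset_dir` is a substring check, and the two CSV split paths are fixed in the source, so the same RAI files are loaded whenever the incoming path mentions "csv", whatever it points at (the `train-split-rai.csv` path passed from `_load_dataset` effectively acts only as a switch). A sketch of the equivalent direct call, using the paths from the diff:

    ds = MultilingualDataset(dataset_name="rai").from_csv(
        path_tr="~/datasets/rai/csv/train-rai-multilingual-2000.csv",
        path_te="~/datasets/rai/csv/test-rai-multilingual-2000.csv",
    )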
@@ -154,7 +167,7 @@ class gFunDataset:
         return dataset, labels, data_langs

     def binarize_labels(self, labels):
-        if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
+        if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
             # labels are already binarized for rcv1-2 dataset
             return labels
         if hasattr(self, "mlb"):
@@ -192,7 +205,7 @@ class gFunDataset:
         return self.data_langs

     def num_labels(self):
-        if self.dataset_name not in ["rcv1-2", "jrc", "cls"]:
+        if self.dataset_name not in ["rcv1-2", "jrc", "cls", "rai"]:
            return len(self.labels)
         else:
             return self.labels
@@ -219,8 +232,6 @@ if __name__ == "__main__":
     print("Hello gFunDataset")

     dataset = gFunDataset(
-        # dataset_dir=GLAMI_DATAPATH,
-        # dataset_dir=RCV_DATAPATH,
         dataset_dir=JRC_DATAPATH,
         data_langs=None,
         is_textual=True,

View File

@@ -222,6 +222,41 @@ class MultilingualDataset:
                 new_data.append((docs[:maxn], labels[:maxn], None))
         return new_data

+    def from_csv(self, path_tr, path_te):
+        import pandas as pd
+        from os.path import expanduser
+
+        train = pd.read_csv(expanduser(path_tr))
+        test = pd.read_csv(expanduser(path_te))
+        all_labels = set(train.label.to_list()).union(set(test.label.to_list()))
+        for lang in train.lang.unique():
+            tr_datalang = train.loc[train["lang"] == lang]
+            Xtr = tr_datalang.text.to_list()
+            tr_labels = tr_datalang.label.to_list()
+            # Ytr = np.zeros((len(Xtr), len(all_labels)), dtype=int)
+            Ytr = np.zeros((len(Xtr), 28), dtype=int)
+            for j, i in enumerate(tr_labels):
+                Ytr[j, i] = 1
+            tr_ids = tr_datalang.id.to_list()
+
+            te_datalang = test.loc[test["lang"] == lang]
+            Xte = te_datalang.text.to_list()
+            te_labels = te_datalang.label.to_list()
+            # Yte = np.zeros((len(Xte), len(all_labels)), dtype=int)
+            Yte = np.zeros((len(Xte), 28), dtype=int)
+            for j, i in enumerate(te_labels):
+                Yte[j, i] = 1
+            te_ids = te_datalang.id.to_list()
+
+            self.add(
+                lang=lang,
+                Xtr=Xtr,
+                Ytr=Ytr,
+                Xte=Xte,
+                Yte=Yte,
+                tr_ids=tr_ids,
+                te_ids=te_ids
+            )
+        return self

 def _mask_numbers(data):
     mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b")
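
Note: judging from the column accessors (`id`, `lang`, `text`, `label`), `from_csv` expects train and test CSVs shaped roughly like the sketch below, with integer label ids in [0, 28), since the width of the label matrix is hard-coded to 28 (`all_labels` is computed but never used). Also, languages that appear only in the test CSV are skipped, because iteration runs over `train.lang.unique()`. The rows here are invented for illustration:

    id,lang,text,label
    101,en,"an example broadcast transcript ...",3
    102,it,"una trascrizione di esempio ...",17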
@@ -240,7 +275,6 @@ def _mask_numbers(data):
         masked.append(text.replace(".", "").replace(",", "").strip())
     return masked

-
 if __name__ == "__main__":
     DATAPATH = expanduser(
         "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"

View File

@@ -24,6 +24,7 @@ def get_dataset(dataset_name, args):
         "glami",
         "cls",
         "webis",
+        "rai",
     ], "dataset not supported"

     RCV_DATAPATH = expanduser(
@ -42,6 +43,8 @@ def get_dataset(dataset_name, args):
"~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl" "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
) )
RAI_DATAPATH = expanduser("~/datasets/rai/rai_corpus.pkl")
if dataset_name == "multinews": if dataset_name == "multinews":
# TODO: convert to gFunDataset # TODO: convert to gFunDataset
raise NotImplementedError raise NotImplementedError
@@ -87,7 +90,6 @@ def get_dataset(dataset_name, args):
         )
         dataset.save_as_pickle(GLAMI_DATAPATH)
-
     elif dataset_name == "cls":
         dataset = gFunDataset(
             dataset_dir=CLS_DATAPATH,
@ -96,7 +98,6 @@ def get_dataset(dataset_name, args):
is_multilabel=False, is_multilabel=False,
nrows=args.nrows, nrows=args.nrows,
) )
elif dataset_name == "webis": elif dataset_name == "webis":
dataset = gFunDataset( dataset = gFunDataset(
dataset_dir=WEBIS_CLS, dataset_dir=WEBIS_CLS,
@@ -105,6 +106,14 @@ def get_dataset(dataset_name, args):
             is_multilabel=False,
             nrows=args.nrows,
         )
+    elif dataset_name == "rai":
+        dataset = gFunDataset(
+            dataset_dir=RAI_DATAPATH,
+            is_textual=True,
+            is_visual=False,
+            is_multilabel=False,
+            nrows=args.nrows
+        )
     else:
         raise NotImplementedError
     return dataset
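
Note: end to end, `get_dataset("rai", args)` resolves to `RAI_DATAPATH`, whose path contains "rai" and therefore triggers the new branch in `gFunDataset._load_dataset`; the pickle at `RAI_DATAPATH` is never actually read, since that branch substitutes the hard-coded CSV path. A minimal driver sketch; the `args` object is a stand-in, and only its `nrows` attribute is read in this function:

    from types import SimpleNamespace

    args = SimpleNamespace(nrows=None)  # stand-in for the real CLI namespace
    dataset = get_dataset("rai", args)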