handling new data
This commit is contained in:
parent
1a1c48e136
commit
86fbd90bd4
|
|
@ -27,8 +27,8 @@ class gFunDataset:
|
|||
self._load_dataset()
|
||||
|
||||
def get_label_binarizer(self, labels):
|
||||
if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
|
||||
mlb = "Labels are already binarized for rcv1-2 dataset"
|
||||
if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
|
||||
mlb = f"Labels are already binarized for {self.dataset_name} dataset"
|
||||
elif self.is_multilabel:
|
||||
mlb = MultiLabelBinarizer()
|
||||
mlb.fit([labels])
|
||||
|
|
@ -85,8 +85,16 @@ class gFunDataset:
|
|||
self.dataset_name, self.dataset_dir, self.nrows
|
||||
)
|
||||
self.mlb = self.get_label_binarizer(self.labels)
|
||||
self.show_dimension()
|
||||
|
||||
elif "rai" in self.dataset_dir.lower():
|
||||
print(f"- Loading RAI-CORPUS dataset from {self.dataset_dir}")
|
||||
self.dataset_name = "rai"
|
||||
self.dataset, self.labels, self.data_langs = self._load_multilingual(dataset_name=self.dataset_name,
|
||||
dataset_dir="~/datasets/rai/csv/train-split-rai.csv",
|
||||
nrows=self.nrows)
|
||||
self.mlb = self.get_label_binarizer(self.labels)
|
||||
|
||||
self.show_dimension()
|
||||
return
|
||||
|
||||
def show_dimension(self):
|
||||
|
|
@ -95,12 +103,17 @@ class gFunDataset:
|
|||
print(
|
||||
f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}"
|
||||
)
|
||||
if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
|
||||
if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
|
||||
print(f"-- Labels: {self.labels}")
|
||||
else:
|
||||
print(f"-- Labels: {len(self.labels)}")
|
||||
|
||||
def _load_multilingual(self, dataset_name, dataset_dir, nrows):
|
||||
if "csv" in dataset_dir:
|
||||
old_dataset = MultilingualDataset(dataset_name=dataset_name).from_csv(
|
||||
path_tr="~/datasets/rai/csv/train-rai-multilingual-2000.csv",
|
||||
path_te="~/datasets/rai/csv/test-rai-multilingual-2000.csv")
|
||||
else:
|
||||
old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir)
|
||||
if nrows is not None:
|
||||
if dataset_name == "cls":
|
||||
|
|
@ -154,7 +167,7 @@ class gFunDataset:
|
|||
return dataset, labels, data_langs
|
||||
|
||||
def binarize_labels(self, labels):
|
||||
if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
|
||||
if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
|
||||
# labels are already binarized for rcv1-2 dataset
|
||||
return labels
|
||||
if hasattr(self, "mlb"):
|
||||
|
|
@ -192,7 +205,7 @@ class gFunDataset:
|
|||
return self.data_langs
|
||||
|
||||
def num_labels(self):
    """Return the label information for the loaded dataset.

    For the datasets whose labels are handled as already binarized
    ("rcv1-2", "jrc", "cls" and, since this commit, "rai"),
    ``self.labels`` is returned unchanged. For every other dataset the
    length of ``self.labels`` is returned.

    NOTE(review): presumably ``self.labels`` already holds the label count
    for the pre-binarized datasets -- confirm against the loader.
    """
    # Post-commit condition: "rai" joins the pre-binarized datasets.
    if self.dataset_name in ("rcv1-2", "jrc", "cls", "rai"):
        return self.labels
    return len(self.labels)
|
||||
|
|
@ -219,8 +232,6 @@ if __name__ == "__main__":
|
|||
|
||||
print("Hello gFunDataset")
|
||||
dataset = gFunDataset(
|
||||
# dataset_dir=GLAMI_DATAPATH,
|
||||
# dataset_dir=RCV_DATAPATH,
|
||||
dataset_dir=JRC_DATAPATH,
|
||||
data_langs=None,
|
||||
is_textual=True,
|
||||
|
|
|
|||
|
|
@ -222,6 +222,41 @@ class MultilingualDataset:
|
|||
new_data.append((docs[:maxn], labels[:maxn], None))
|
||||
return new_data
|
||||
|
||||
def from_csv(self, path_tr, path_te, num_labels=28):
    """Populate this dataset from a train CSV and a test CSV.

    Both files are read with pandas and must provide the columns ``lang``,
    ``text``, ``label`` and ``id``. Each ``label`` value is used as the
    column index of the single active entry in a one-hot row.

    One ``self.add(...)`` call is made per language found in the training
    file, in order of first appearance. Languages that occur only in the
    test file are ignored; a training language with no test rows gets an
    empty test split.

    Args:
        path_tr: Path of the training CSV (``~`` is expanded).
        path_te: Path of the test CSV (``~`` is expanded).
        num_labels: Width of the one-hot label matrices. Defaults to 28,
            the value that was previously hard-coded.

    Returns:
        ``self``, so the call can be chained after the constructor.

    Raises:
        IndexError: If a label is negative or not smaller than
            ``num_labels``.
    """
    import pandas as pd
    from os.path import expanduser

    def _one_hot(labels):
        # Build a (n_docs, num_labels) 0/1 matrix with one active column per row.
        idx = np.asarray(labels, dtype=int)
        # Reject negatives explicitly: numpy would silently wrap them around
        # and mark the wrong class.
        if idx.size and (idx.min() < 0 or idx.max() >= num_labels):
            raise IndexError(
                f"label index out of range [0, {num_labels}): "
                f"found values between {idx.min()} and {idx.max()}"
            )
        onehot = np.zeros((len(idx), num_labels), dtype=int)
        onehot[np.arange(len(idx)), idx] = 1
        return onehot

    train = pd.read_csv(expanduser(path_tr))
    test = pd.read_csv(expanduser(path_te))

    for lang in train["lang"].unique():
        tr_rows = train[train["lang"] == lang]
        te_rows = test[test["lang"] == lang]
        self.add(
            lang=lang,
            Xtr=tr_rows["text"].to_list(),
            Ytr=_one_hot(tr_rows["label"].to_list()),
            Xte=te_rows["text"].to_list(),
            Yte=_one_hot(te_rows["label"].to_list()),
            tr_ids=tr_rows["id"].to_list(),
            te_ids=te_rows["id"].to_list(),
        )
    return self
|
||||
|
||||
|
||||
|
||||
def _mask_numbers(data):
|
||||
mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b")
|
||||
|
|
@ -240,7 +275,6 @@ def _mask_numbers(data):
|
|||
masked.append(text.replace(".", "").replace(",", "").strip())
|
||||
return masked
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
DATAPATH = expanduser(
|
||||
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ def get_dataset(dataset_name, args):
|
|||
"glami",
|
||||
"cls",
|
||||
"webis",
|
||||
"rai",
|
||||
], "dataset not supported"
|
||||
|
||||
RCV_DATAPATH = expanduser(
|
||||
|
|
@ -42,6 +43,8 @@ def get_dataset(dataset_name, args):
|
|||
"~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
|
||||
)
|
||||
|
||||
RAI_DATAPATH = expanduser("~/datasets/rai/rai_corpus.pkl")
|
||||
|
||||
if dataset_name == "multinews":
|
||||
# TODO: convert to gFunDataset
|
||||
raise NotImplementedError
|
||||
|
|
@ -87,7 +90,6 @@ def get_dataset(dataset_name, args):
|
|||
)
|
||||
|
||||
dataset.save_as_pickle(GLAMI_DATAPATH)
|
||||
|
||||
elif dataset_name == "cls":
|
||||
dataset = gFunDataset(
|
||||
dataset_dir=CLS_DATAPATH,
|
||||
|
|
@ -96,7 +98,6 @@ def get_dataset(dataset_name, args):
|
|||
is_multilabel=False,
|
||||
nrows=args.nrows,
|
||||
)
|
||||
|
||||
elif dataset_name == "webis":
|
||||
dataset = gFunDataset(
|
||||
dataset_dir=WEBIS_CLS,
|
||||
|
|
@ -105,6 +106,14 @@ def get_dataset(dataset_name, args):
|
|||
is_multilabel=False,
|
||||
nrows=args.nrows,
|
||||
)
|
||||
elif dataset_name == "rai":
|
||||
dataset = gFunDataset(
|
||||
dataset_dir=RAI_DATAPATH,
|
||||
is_textual=True,
|
||||
is_visual=False,
|
||||
is_multilabel=False,
|
||||
nrows=args.nrows
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return dataset
|
||||
|
|
|
|||
Loading…
Reference in New Issue