diff --git a/dataManager/gFunDataset.py b/dataManager/gFunDataset.py
index ab6c3c6..1f38917 100644
--- a/dataManager/gFunDataset.py
+++ b/dataManager/gFunDataset.py
@@ -24,13 +24,15 @@ class SimpleGfunDataset:
         visual=False,
         multilabel=False,
         set_tr_langs=None,
-        set_te_langs=None
+        set_te_langs=None,
+        reduced=False
     ):
         self.name = dataset_name
         self.datadir = os.path.expanduser(datadir)
         self.textual = textual
         self.visual = visual
         self.multilabel = multilabel
+        self.reduced = reduced
         self.load_csv(set_tr_langs, set_te_langs)
         self.print_stats()

@@ -51,7 +53,7 @@ class SimpleGfunDataset:
         print(f"tr: {tr} - va: {va} - te: {te}")

     def load_csv(self, set_tr_langs, set_te_langs):
-        _data_tr = pd.read_csv(os.path.join(self.datadir, "train.small.csv"))
+        _data_tr = pd.read_csv(os.path.join(self.datadir, "train.csv" if not self.reduced else "train.small.csv"))
         try:
             stratified = "class"
             train, val = train_test_split(_data_tr, test_size=0.2, random_state=42, stratify=_data_tr.label)
@@ -59,7 +61,7 @@ class SimpleGfunDataset:
             stratified = "lang"
             train, val = train_test_split(_data_tr, test_size=0.2, random_state=42, stratify=_data_tr.lang)
         print(f"- dataset stratified by {stratified}")
-        test = pd.read_csv(os.path.join(self.datadir, "test.small.csv"))
+        test = pd.read_csv(os.path.join(self.datadir, "test.csv" if not self.reduced else "test.small.csv"))
         self._set_langs (train, test, set_tr_langs, set_te_langs)
         self._set_labels(_data_tr)
         self.full_train = _data_tr
@@ -140,168 +142,6 @@ class SimpleGfunDataset:
         return one_hot_matrix


-class gFunDataset:
-    def __init__(
-        self,
-        dataset_dir,
-        is_textual,
-        is_visual,
-        is_multilabel,
-        labels=None,
-        nrows=None,
-        data_langs=None,
-    ):
-        self.dataset_dir = dataset_dir
-        self.data_langs = data_langs
-        self.is_textual = is_textual
-        self.is_visual = is_visual
-        self.is_multilabel = is_multilabel
-        self.labels = labels
-        self.nrows = nrows
-        self.dataset = {}
-        self._load_dataset()
-
-    def get_label_binarizer(self, labels):
-        if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
-            mlb = f"Labels are already binarized for {self.dataset_name} dataset"
-        elif self.is_multilabel:
-            mlb = MultiLabelBinarizer()
-            mlb.fit([labels])
-        else:
-            mlb = LabelBinarizer()
-            mlb.fit(labels)
-        return mlb
-
-    def _load_dataset(self):
-        print(f"- Loading dataset from {self.dataset_dir}")
-        self.dataset_name = "rai"
-        self.dataset, self.labels, self.data_langs = self._load_multilingual(dataset_name=self.dataset_name,
-                                                                             dataset_dir=self.dataset_dir,
-                                                                             nrows=self.nrows)
-        self.mlb = self.get_label_binarizer(self.labels)
-        self.show_dimension()
-        return
-
-    def show_dimension(self):
-        print(f"\n[Dataset: {self.dataset_name.upper()}]")
-        for lang, data in self.dataset.items():
-            print(
-                f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}"
-            )
-        if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
-            print(f"-- Labels: {self.labels}")
-        else:
-            print(f"-- Labels: {len(self.labels)}")
-
-    def _load_multilingual(self, dataset_dir, nrows, dataset_name="rai"):
-        if "csv" in dataset_dir:
-            old_dataset = MultilingualDataset(dataset_name="rai").from_csv(
-                path_tr=os.path.expanduser(os.path.join(dataset_dir, "train.small.csv")),
-                path_te=os.path.expanduser(os.path.join(dataset_dir, "test.small.csv"))
-            )
-        if nrows is not None:
-            if dataset_name == "cls":
-                old_dataset.reduce_data(langs=["de", "en", "fr"], maxn=nrows)
-            else:
-                old_dataset.reduce_data(langs=["en", "it", "fr"], maxn=nrows)
-        labels = old_dataset.num_labels()
-        data_langs = old_dataset.langs()
-
-        def _format_multilingual(data):
-            text = data[0]
-            image = None
-            labels = data[1]
-            return {"text": text, "image": image, "label": labels}
-
-        dataset = {
-            k: {"train": _format_multilingual(v[0]), "test": _format_multilingual(v[1])}
-            for k, v in old_dataset.multiling_dataset.items()
-        }
-        return dataset, labels, data_langs
-
-    def _load_glami(self, dataset_dir, nrows):
-        train_split = get_dataframe("train", dataset_dir=dataset_dir).sample(n=nrows)
-        test_split = get_dataframe("test", dataset_dir=dataset_dir).sample(
-            n=int(nrows / 10)
-        )
-
-        gb_train = train_split.groupby("geo")
-        gb_test = test_split.groupby("geo")
-
-        if self.data_langs is None:
-            data_langs = sorted(train_split.geo.unique().tolist())
-        if self.labels is None:
-            labels = train_split.category_name.unique().tolist()
-
-        def _format_glami(data_df):
-            text = (data_df.name + " " + data_df.description).tolist()
-            image = data_df.image_file.tolist()
-            labels = data_df.category_name.tolist()
-            return {"text": text, "image": image, "label": labels}
-
-        dataset = {
-            lang: {
-                "train": _format_glami(data_tr),
-                "test": _format_glami(gb_test.get_group(lang)),
-            }
-            for lang, data_tr in gb_train
-            if lang in data_langs
-        }
-
-        return dataset, labels, data_langs
-
-    def binarize_labels(self, labels):
-        if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
-            # labels are already binarized for rcv1-2 dataset
-            return labels
-        if hasattr(self, "mlb"):
-            return self.mlb.transform(labels)
-        else:
-            raise AttributeError("Label binarizer not found")
-
-    def training(self):
-        lXtr = {}
-        lYtr = {}
-        for lang in self.data_langs:
-            text = self.dataset[lang]["train"]["text"] if self.is_textual else None
-            img = self.dataset[lang]["train"]["image"] if self.is_visual else None
-            labels = self.dataset[lang]["train"]["label"]
-
-            lXtr[lang] = {"text": text, "image": img}
-            lYtr[lang] = self.binarize_labels(labels)
-
-        return lXtr, lYtr
-
-    def test(self):
-        lXte = {}
-        lYte = {}
-        for lang in self.data_langs:
-            text = self.dataset[lang]["test"]["text"] if self.is_textual else None
-            img = self.dataset[lang]["test"]["image"] if self.is_visual else None
-            labels = self.dataset[lang]["test"]["label"]
-
-            lXte[lang] = {"text": text, "image": img}
-            lYte[lang] = self.binarize_labels(labels)
-
-        return lXte, lYte
-
-    def langs(self):
-        return self.data_langs
-
-    def num_labels(self):
-        if self.dataset_name not in ["rcv1-2", "jrc", "cls", "rai"]:
-            return len(self.labels)
-        else:
-            return self.labels
-
-    def save_as_pickle(self, path):
-        import pickle
-
-        filepath = os.path.join(path, f"{self.dataset_name}_{self.nrows}.pkl")
-        with open(filepath, "wb") as f:
-            print(f"- saving dataset in {filepath}")
-            pickle.dump(self, f)
-

 def _mask_numbers(data):
     mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b")
     mask_4digit = re.compile(r"\s[\+-]?\d{4}([\.,]\d*)*\b")
diff --git a/dataManager/utils.py b/dataManager/utils.py
index 4da6257..57a7018 100644
--- a/dataManager/utils.py
+++ b/dataManager/utils.py
@@ -22,6 +22,7 @@ def get_dataset(datasetp_path, args):
         visual=False,
         multilabel=False,
         set_tr_langs=args.tr_langs,
-        set_te_langs=args.te_langs
+        set_te_langs=args.te_langs,
+        reduced=args.reduced
     )
     return dataset
diff --git a/main.py b/main.py
index ad4abea..56affbe 100644
--- a/main.py
+++ b/main.py
@@ -185,6 +185,7 @@ if __name__ == "__main__":
     parser.add_argument("--max_length", type=int, default=128)
     parser.add_argument("--patience", type=int, default=5)
     parser.add_argument("--evaluate_step", type=int, default=10)
+    parser.add_argument("--reduced", action="store_true", help="run on reduced set of documents")

     # logging
     parser.add_argument("--wandb", action="store_true")
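
Usage note (not part of the patch): a minimal sketch of how the new flag is meant to flow from the CLI into the dataset loader. The import path and the keyword names dataset_name/datadir/textual are inferred from the constructor body shown above, and the data directory is a placeholder, so treat this as an assumption rather than a verified call.

# Hypothetical sketch: with --reduced set, SimpleGfunDataset reads the *.small.csv
# files instead of the full train/test CSVs.
from argparse import ArgumentParser
from dataManager.gFunDataset import SimpleGfunDataset  # assumed import path

parser = ArgumentParser()
parser.add_argument("--reduced", action="store_true", help="run on reduced set of documents")
args = parser.parse_args()  # e.g. invoked as: python main.py --reduced

dataset = SimpleGfunDataset(
    dataset_name="rai",        # placeholder dataset name
    datadir="~/data/rai",      # placeholder dir holding train.csv / train.small.csv etc.
    textual=True,
    visual=False,
    multilabel=False,
    reduced=args.reduced,      # True -> train.small.csv and test.small.csv are loaded
)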