minor updates: add a reduced flag to SimpleGfunDataset (exposed as --reduced in main.py) for loading the small CSV splits; remove the gFunDataset class
This commit is contained in:
parent 875af6d362
commit 5d07e579e4
@@ -24,13 +24,15 @@ class SimpleGfunDataset:
         visual=False,
         multilabel=False,
         set_tr_langs=None,
-        set_te_langs=None
+        set_te_langs=None,
+        reduced=False
     ):
         self.name = dataset_name
         self.datadir = os.path.expanduser(datadir)
         self.textual = textual
         self.visual = visual
         self.multilabel = multilabel
+        self.reduced = reduced
         self.load_csv(set_tr_langs, set_te_langs)
         self.print_stats()
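After this change, a reduced run can be requested directly at construction time. A minimal sketch (assuming SimpleGfunDataset is imported from the module patched above; the dataset_name, datadir and textual keywords are inferred from the attribute assignments, and the path is a placeholder):

dataset = SimpleGfunDataset(
    dataset_name="rai",           # placeholder name
    datadir="~/datasets/rai",     # placeholder path, expanded via os.path.expanduser
    textual=True,
    visual=False,
    multilabel=False,
    reduced=True,                 # new flag: load the *.small.csv splits
)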
@@ -51,7 +53,7 @@
         print(f"tr: {tr} - va: {va} - te: {te}")

     def load_csv(self, set_tr_langs, set_te_langs):
-        _data_tr = pd.read_csv(os.path.join(self.datadir, "train.small.csv"))
+        _data_tr = pd.read_csv(os.path.join(self.datadir, "train.csv" if not self.reduced else "train.small.csv"))
         try:
             stratified = "class"
             train, val = train_test_split(_data_tr, test_size=0.2, random_state=42, stratify=_data_tr.label)

@@ -59,7 +61,7 @@
             stratified = "lang"
             train, val = train_test_split(_data_tr, test_size=0.2, random_state=42, stratify=_data_tr.lang)
         print(f"- dataset stratified by {stratified}")
-        test = pd.read_csv(os.path.join(self.datadir, "test.small.csv"))
+        test = pd.read_csv(os.path.join(self.datadir, "test.csv" if not self.reduced else "test.small.csv"))
         self._set_langs(train, test, set_tr_langs, set_te_langs)
         self._set_labels(_data_tr)
         self.full_train = _data_tr
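For context, the split logic that load_csv wraps can be sketched in isolation as follows; only the column names label and lang and the class-then-language fallback mirror the snippet above, while the helper name and the ValueError handling are assumptions:

import pandas as pd
from sklearn.model_selection import train_test_split

def split_train_val(df: pd.DataFrame, test_size: float = 0.2, seed: int = 42):
    # Prefer stratifying by class label; fall back to language when that fails
    # (e.g. a class with fewer than two documents raises ValueError).
    try:
        train, val = train_test_split(df, test_size=test_size, random_state=seed, stratify=df.label)
        stratified = "class"
    except ValueError:
        train, val = train_test_split(df, test_size=test_size, random_state=seed, stratify=df.lang)
        stratified = "lang"
    print(f"- dataset stratified by {stratified}")
    return train, val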
@@ -140,168 +142,6 @@ class SimpleGfunDataset:
         return one_hot_matrix


-class gFunDataset:
-    def __init__(
-        self,
-        dataset_dir,
-        is_textual,
-        is_visual,
-        is_multilabel,
-        labels=None,
-        nrows=None,
-        data_langs=None,
-    ):
-        self.dataset_dir = dataset_dir
-        self.data_langs = data_langs
-        self.is_textual = is_textual
-        self.is_visual = is_visual
-        self.is_multilabel = is_multilabel
-        self.labels = labels
-        self.nrows = nrows
-        self.dataset = {}
-        self._load_dataset()
-
-    def get_label_binarizer(self, labels):
-        if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
-            mlb = f"Labels are already binarized for {self.dataset_name} dataset"
-        elif self.is_multilabel:
-            mlb = MultiLabelBinarizer()
-            mlb.fit([labels])
-        else:
-            mlb = LabelBinarizer()
-            mlb.fit(labels)
-        return mlb
-
-    def _load_dataset(self):
-        print(f"- Loading dataset from {self.dataset_dir}")
-        self.dataset_name = "rai"
-        self.dataset, self.labels, self.data_langs = self._load_multilingual(dataset_name=self.dataset_name,
-                                                                             dataset_dir=self.dataset_dir,
-                                                                             nrows=self.nrows)
-        self.mlb = self.get_label_binarizer(self.labels)
-        self.show_dimension()
-        return
-
-    def show_dimension(self):
-        print(f"\n[Dataset: {self.dataset_name.upper()}]")
-        for lang, data in self.dataset.items():
-            print(
-                f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}"
-            )
-        if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
-            print(f"-- Labels: {self.labels}")
-        else:
-            print(f"-- Labels: {len(self.labels)}")
-
-    def _load_multilingual(self, dataset_dir, nrows, dataset_name="rai"):
-        if "csv" in dataset_dir:
-            old_dataset = MultilingualDataset(dataset_name="rai").from_csv(
-                path_tr=os.path.expanduser(os.path.join(dataset_dir, "train.small.csv")),
-                path_te=os.path.expanduser(os.path.join(dataset_dir, "test.small.csv"))
-            )
-        if nrows is not None:
-            if dataset_name == "cls":
-                old_dataset.reduce_data(langs=["de", "en", "fr"], maxn=nrows)
-            else:
-                old_dataset.reduce_data(langs=["en", "it", "fr"], maxn=nrows)
-        labels = old_dataset.num_labels()
-        data_langs = old_dataset.langs()
-
-        def _format_multilingual(data):
-            text = data[0]
-            image = None
-            labels = data[1]
-            return {"text": text, "image": image, "label": labels}
-
-        dataset = {
-            k: {"train": _format_multilingual(v[0]), "test": _format_multilingual(v[1])}
-            for k, v in old_dataset.multiling_dataset.items()
-        }
-        return dataset, labels, data_langs
-
-    def _load_glami(self, dataset_dir, nrows):
-        train_split = get_dataframe("train", dataset_dir=dataset_dir).sample(n=nrows)
-        test_split = get_dataframe("test", dataset_dir=dataset_dir).sample(
-            n=int(nrows / 10)
-        )
-
-        gb_train = train_split.groupby("geo")
-        gb_test = test_split.groupby("geo")
-
-        if self.data_langs is None:
-            data_langs = sorted(train_split.geo.unique().tolist())
-        if self.labels is None:
-            labels = train_split.category_name.unique().tolist()
-
-        def _format_glami(data_df):
-            text = (data_df.name + " " + data_df.description).tolist()
-            image = data_df.image_file.tolist()
-            labels = data_df.category_name.tolist()
-            return {"text": text, "image": image, "label": labels}
-
-        dataset = {
-            lang: {
-                "train": _format_glami(data_tr),
-                "test": _format_glami(gb_test.get_group(lang)),
-            }
-            for lang, data_tr in gb_train
-            if lang in data_langs
-        }
-
-        return dataset, labels, data_langs
-
-    def binarize_labels(self, labels):
-        if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
-            # labels are already binarized for rcv1-2 dataset
-            return labels
-        if hasattr(self, "mlb"):
-            return self.mlb.transform(labels)
-        else:
-            raise AttributeError("Label binarizer not found")
-
-    def training(self):
-        lXtr = {}
-        lYtr = {}
-        for lang in self.data_langs:
-            text = self.dataset[lang]["train"]["text"] if self.is_textual else None
-            img = self.dataset[lang]["train"]["image"] if self.is_visual else None
-            labels = self.dataset[lang]["train"]["label"]
-
-            lXtr[lang] = {"text": text, "image": img}
-            lYtr[lang] = self.binarize_labels(labels)
-
-        return lXtr, lYtr
-
-    def test(self):
-        lXte = {}
-        lYte = {}
-        for lang in self.data_langs:
-            text = self.dataset[lang]["test"]["text"] if self.is_textual else None
-            img = self.dataset[lang]["test"]["image"] if self.is_visual else None
-            labels = self.dataset[lang]["test"]["label"]
-
-            lXte[lang] = {"text": text, "image": img}
-            lYte[lang] = self.binarize_labels(labels)
-
-        return lXte, lYte
-
-    def langs(self):
-        return self.data_langs
-
-    def num_labels(self):
-        if self.dataset_name not in ["rcv1-2", "jrc", "cls", "rai"]:
-            return len(self.labels)
-        else:
-            return self.labels
-
-    def save_as_pickle(self, path):
-        import pickle
-
-        filepath = os.path.join(path, f"{self.dataset_name}_{self.nrows}.pkl")
-        with open(filepath, "wb") as f:
-            print(f"- saving dataset in {filepath}")
-            pickle.dump(self, f)
-

 def _mask_numbers(data):
     mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b")
     mask_4digit = re.compile(r"\s[\+-]?\d{4}([\.,]\d*)*\b")
@@ -22,6 +22,7 @@ def get_dataset(datasetp_path, args):
         visual=False,
         multilabel=False,
         set_tr_langs=args.tr_langs,
-        set_te_langs=args.te_langs
+        set_te_langs=args.te_langs,
+        reduced=args.reduced
     )
     return dataset
main.py
@ -185,6 +185,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--max_length", type=int, default=128)
|
||||
parser.add_argument("--patience", type=int, default=5)
|
||||
parser.add_argument("--evaluate_step", type=int, default=10)
|
||||
parser.add_argument("--reduced", action="store_true", help="run on reduced set of documents")
|
||||
# logging
|
||||
parser.add_argument("--wandb", action="store_true")
|
||||
|
||||
|
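Taken together with the get_dataset change above, the new flag flows from the command line into the loader roughly as sketched below; --tr_langs and --te_langs are assumed to be defined elsewhere in main.py, and the --datadir option is a placeholder:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--datadir", default="~/datasets/rai")   # placeholder option
parser.add_argument("--tr_langs", nargs="+", default=None)   # assumed to exist already
parser.add_argument("--te_langs", nargs="+", default=None)   # assumed to exist already
parser.add_argument("--reduced", action="store_true", help="run on reduced set of documents")
args = parser.parse_args()

# get_dataset (patched above; its import path is not shown in this diff) forwards
# args.reduced to SimpleGfunDataset, which then reads the *.small.csv splits.
dataset = get_dataset(args.datadir, args)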