sampling GLAMI1-M dataset
This commit is contained in:
parent
ee38bcda10
commit
ee2a9481de
|
|
@ -108,30 +108,13 @@ class gFunDataset:
|
||||||
return dataset, labels, data_langs
|
return dataset, labels, data_langs
|
||||||
|
|
||||||
def _load_glami(self, dataset_dir, nrows):
|
def _load_glami(self, dataset_dir, nrows):
|
||||||
# TODO: a better way to get a stratified sampling of the dataset (see: groupby + sample)
|
train_split = get_dataframe("train", dataset_dir=dataset_dir).sample(n=nrows)
|
||||||
def _balanced_sample(data, n, remainder=0):
|
test_split = get_dataframe("test", dataset_dir=dataset_dir).sample(
|
||||||
import pandas as pd
|
n=int(nrows / 10)
|
||||||
|
|
||||||
langs = sorted(data.geo.unique().tolist())
|
|
||||||
dict_n = {lang: n for lang in langs}
|
|
||||||
dict_n[langs[0]] += remainder
|
|
||||||
|
|
||||||
sampled = []
|
|
||||||
for lang in langs:
|
|
||||||
sampled.append(data[data.geo == lang].sample(n=dict_n[lang]))
|
|
||||||
|
|
||||||
return pd.concat(sampled, axis=0)
|
|
||||||
|
|
||||||
# TODO: make this sampling deterministic (i.e., dependent on the seed)
|
|
||||||
lang_nrows = (
|
|
||||||
nrows // 13 if self.data_langs is None else nrows // len(self.data_langs)
|
|
||||||
) # GLAMI 1-M has 13 languages
|
|
||||||
remainder = (
|
|
||||||
nrows % 13 if self.data_langs is None else nrows % len(self.data_langs)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
train_split = get_dataframe("train", dataset_dir=dataset_dir)
|
gb_train = train_split.groupby("geo")
|
||||||
train_split = _balanced_sample(train_split, lang_nrows, remainder=remainder)
|
gb_test = test_split.groupby("geo")
|
||||||
|
|
||||||
if self.data_langs is None:
|
if self.data_langs is None:
|
||||||
data_langs = sorted(train_split.geo.unique().tolist())
|
data_langs = sorted(train_split.geo.unique().tolist())
|
||||||
|
|
@ -139,14 +122,6 @@ class gFunDataset:
|
||||||
if self.labels is None:
|
if self.labels is None:
|
||||||
labels = train_split.category_name.unique().tolist()
|
labels = train_split.category_name.unique().tolist()
|
||||||
|
|
||||||
# TODO: atm test data should contain same languages as train data
|
|
||||||
test_split = get_dataframe("test", dataset_dir=dataset_dir)
|
|
||||||
# TODO: atm we're using 1:1 train-test
|
|
||||||
test_split = _balanced_sample(test_split, lang_nrows, remainder=remainder)
|
|
||||||
|
|
||||||
gb_train = train_split.groupby("geo")
|
|
||||||
gb_test = test_split.groupby("geo")
|
|
||||||
|
|
||||||
def _format_glami(data_df):
|
def _format_glami(data_df):
|
||||||
text = (data_df.name + " " + data_df.description).tolist()
|
text = (data_df.name + " " + data_df.description).tolist()
|
||||||
image = data_df.image_file.tolist()
|
image = data_df.image_file.tolist()
|
||||||
|
|
|
||||||
|
|
@ -340,7 +340,7 @@ class EarlyStopping:
|
||||||
self.experiment_name = experiment_name
|
self.experiment_name = experiment_name
|
||||||
|
|
||||||
def __call__(self, validation, model, epoch):
|
def __call__(self, validation, model, epoch):
|
||||||
if validation > self.best_score:
|
if validation >= self.best_score:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(
|
print(
|
||||||
f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
|
f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
|
||||||
|
|
|
||||||
|
|
@ -100,7 +100,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
||||||
return "bert-base-uncased"
|
return "bert-base-uncased"
|
||||||
elif "mbert" == model_name:
|
elif "mbert" == model_name:
|
||||||
return "bert-base-multilingual-uncased"
|
return "bert-base-multilingual-uncased"
|
||||||
elif "xlm" == model_name:
|
elif "xlm-roberta" == model_name:
|
||||||
return "xlm-roberta-base"
|
return "xlm-roberta-base"
|
||||||
elif "mt5" == model_name:
|
elif "mt5" == model_name:
|
||||||
return "google/mt5-small"
|
return "google/mt5-small"
|
||||||
|
|
@ -270,8 +270,8 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
||||||
elif "bert" in model_name:
|
elif "bert" in model_name:
|
||||||
if "multilingual" in model_name:
|
if "multilingual" in model_name:
|
||||||
return "mbert"
|
return "mbert"
|
||||||
elif "xlm" in model_name:
|
elif "xlm-roberta" in model_name:
|
||||||
return "xlm"
|
return "xlm-roberta"
|
||||||
else:
|
else:
|
||||||
return model_name
|
return model_name
|
||||||
|
|
||||||
|
|
|
||||||
7
main.py
7
main.py
|
|
@ -13,6 +13,7 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
|
||||||
"""
|
"""
|
||||||
TODO:
|
TODO:
|
||||||
- Transformers VGFs:
|
- Transformers VGFs:
|
||||||
|
- scheduler with warmup and cosine
|
||||||
- freeze params method
|
- freeze params method
|
||||||
- General:
|
- General:
|
||||||
[!] zero-shot setup
|
[!] zero-shot setup
|
||||||
|
|
@ -177,17 +178,17 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--features", action="store_false")
|
parser.add_argument("--features", action="store_false")
|
||||||
parser.add_argument("--aggfunc", type=str, default="mean")
|
parser.add_argument("--aggfunc", type=str, default="mean")
|
||||||
# transformer parameters ---------------
|
# transformer parameters ---------------
|
||||||
|
parser.add_argument("--epochs", type=int, default=100)
|
||||||
parser.add_argument("--textual_trf_name", type=str, default="mbert")
|
parser.add_argument("--textual_trf_name", type=str, default="mbert")
|
||||||
parser.add_argument("--batch_size", type=int, default=32)
|
parser.add_argument("--batch_size", type=int, default=32)
|
||||||
parser.add_argument("--eval_batch_size", type=int, default=128)
|
parser.add_argument("--eval_batch_size", type=int, default=128)
|
||||||
parser.add_argument("--epochs", type=int, default=100)
|
parser.add_argument("--textual_lr", type=float, default=1e-4)
|
||||||
parser.add_argument("--textual_lr", type=float, default=1e-5)
|
|
||||||
parser.add_argument("--visual_lr", type=float, default=1e-5)
|
|
||||||
parser.add_argument("--max_length", type=int, default=128)
|
parser.add_argument("--max_length", type=int, default=128)
|
||||||
parser.add_argument("--patience", type=int, default=5)
|
parser.add_argument("--patience", type=int, default=5)
|
||||||
parser.add_argument("--evaluate_step", type=int, default=10)
|
parser.add_argument("--evaluate_step", type=int, default=10)
|
||||||
# Visual Transformer parameters --------------
|
# Visual Transformer parameters --------------
|
||||||
parser.add_argument("--visual_trf_name", type=str, default="vit")
|
parser.add_argument("--visual_trf_name", type=str, default="vit")
|
||||||
|
parser.add_argument("--visual_lr", type=float, default=1e-4)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue