sampling GLAMI-1M dataset

Andrea Pedrotti 2023-03-16 18:10:05 +01:00
parent ee38bcda10
commit ee2a9481de
4 changed files with 13 additions and 37 deletions


@@ -108,30 +108,13 @@ class gFunDataset:
         return dataset, labels, data_langs

     def _load_glami(self, dataset_dir, nrows):
-        # TODO: a better way to get a stratified sampling of the dataset (see: groupby + sample)
-        def _balanced_sample(data, n, remainder=0):
-            import pandas as pd
-            langs = sorted(data.geo.unique().tolist())
-            dict_n = {lang: n for lang in langs}
-            dict_n[langs[0]] += remainder
-            sampled = []
-            for lang in langs:
-                sampled.append(data[data.geo == lang].sample(n=dict_n[lang]))
-            return pd.concat(sampled, axis=0)
-        # TODO: set this sampling as determinsitic/dependeing on the seed
-        lang_nrows = (
-            nrows // 13 if self.data_langs is None else nrows // len(self.data_langs)
-        )  # GLAMI 1-M has 13 languages
-        remainder = (
-            nrows % 13 if self.data_langs is None else nrows % len(self.data_langs)
-        )
-        train_split = get_dataframe("train", dataset_dir=dataset_dir)
-        train_split = _balanced_sample(train_split, lang_nrows, remainder=remainder)
+        train_split = get_dataframe("train", dataset_dir=dataset_dir).sample(n=nrows)
+        test_split = get_dataframe("test", dataset_dir=dataset_dir).sample(
+            n=int(nrows / 10)
+        )
+        gb_train = train_split.groupby("geo")
+        gb_test = test_split.groupby("geo")
         if self.data_langs is None:
             data_langs = sorted(train_split.geo.unique().tolist())
@@ -139,14 +122,6 @@ class gFunDataset:
         if self.labels is None:
             labels = train_split.category_name.unique().tolist()
-        # TODO: atm test data should contain same languages as train data
-        test_split = get_dataframe("test", dataset_dir=dataset_dir)
-        # TODO: atm we're using 1:1 train-test
-        test_split = _balanced_sample(test_split, lang_nrows, remainder=remainder)
-        gb_train = train_split.groupby("geo")
-        gb_test = test_split.groupby("geo")

         def _format_glami(data_df):
             text = (data_df.name + " " + data_df.description).tolist()
             image = data_df.image_file.tolist()
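
Note on this change: the removed _balanced_sample helper drew an equal per-language quota (assigning the remainder to the first language), whereas the new code takes a single unseeded random sample (n=nrows for train and nrows/10 for test, replacing the old 1:1 train-test split) and only groups by "geo" afterwards. The groupby + sample approach hinted at in the removed TODO, made deterministic as the other TODO asked, could look roughly like the sketch below (the stratified_sample helper is hypothetical, not part of this commit):

    import pandas as pd

    def stratified_sample(df: pd.DataFrame, nrows: int, seed: int = 42) -> pd.DataFrame:
        # Hypothetical sketch: draw an (almost) equal share of rows per language.
        # groupby("geo") partitions GLAMI-1M by language; random_state makes the
        # draw reproducible, which the removed "deterministic" TODO asked for.
        per_lang = nrows // df.geo.nunique()
        return (
            df.groupby("geo", group_keys=False)
            .apply(lambda g: g.sample(n=min(per_lang, len(g)), random_state=seed))
            .reset_index(drop=True)
        )

Called as stratified_sample(get_dataframe("train", dataset_dir=dataset_dir), nrows), it would keep the per-language balance of the removed helper while staying deterministic across runs.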

@@ -340,7 +340,7 @@ class EarlyStopping:
         self.experiment_name = experiment_name

     def __call__(self, validation, model, epoch):
-        if validation > self.best_score:
+        if validation >= self.best_score:
             if self.verbose:
                 print(
                     f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"

@@ -100,7 +100,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             return "bert-base-uncased"
         elif "mbert" == model_name:
             return "bert-base-multilingual-uncased"
-        elif "xlm" == model_name:
+        elif "xlm-roberta" == model_name:
             return "xlm-roberta-base"
         elif "mt5" == model_name:
             return "google/mt5-small"
@@ -270,8 +270,8 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         elif "bert" in model_name:
             if "multilingual" in model_name:
                 return "mbert"
-        elif "xlm" in model_name:
-            return "xlm"
+        elif "xlm-roberta" in model_name:
+            return "xlm-roberta"
         else:
             return model_name
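
The rename matters on both sides of the mapping: as a CLI alias, "xlm" was ambiguous (on the Hugging Face hub it also matches the original XLM family, e.g. xlm-mlm-100-1280), and in the reverse branch the substring test "xlm" in model_name fires for any checkpoint containing it. One way to keep the two chains from drifting apart again is a single alias table used in both directions; a hypothetical refactor sketch (ALIAS_TO_CHECKPOINT, checkpoint_of, and alias_of are made-up names, not in this commit):

    # Hypothetical refactor: one alias table instead of two if/elif chains.
    ALIAS_TO_CHECKPOINT = {
        "bert": "bert-base-uncased",
        "mbert": "bert-base-multilingual-uncased",
        "xlm-roberta": "xlm-roberta-base",
        "mt5": "google/mt5-small",
    }

    def checkpoint_of(alias: str) -> str:
        return ALIAS_TO_CHECKPOINT[alias]

    def alias_of(checkpoint: str) -> str:
        if "bert" in checkpoint and "multilingual" in checkpoint:
            return "mbert"  # special case kept from the original chain
        # Longest alias first: "xlm-roberta" must win over "bert",
        # which is itself a substring of "roberta".
        for alias in sorted(ALIAS_TO_CHECKPOINT, key=len, reverse=True):
            if alias in checkpoint:
                return alias
        return checkpoint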

@@ -13,6 +13,7 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
 """
 TODO:
     - Transformers VGFs:
+        - scheduler with warmup and cosine
         - freeze params method
     - General:
         [!] zero-shot setup
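
For the new "scheduler with warmup and cosine" TODO item, Hugging Face transformers already ships a helper; a minimal sketch, with the model, optimizer, and step counts stubbed in as assumptions:

    import torch
    from transformers import get_cosine_schedule_with_warmup

    model = torch.nn.Linear(8, 2)  # stand-in for the actual VGF model
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    num_training_steps = 100 * 50  # epochs * batches per epoch (assumed figures)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * num_training_steps),  # 10% warmup (assumed)
        num_training_steps=num_training_steps,
    )
    # In the training loop, call scheduler.step() right after optimizer.step().
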
@@ -177,17 +178,17 @@ if __name__ == "__main__":
     parser.add_argument("--features", action="store_false")
     parser.add_argument("--aggfunc", type=str, default="mean")
     # transformer parameters ---------------
+    parser.add_argument("--epochs", type=int, default=100)
     parser.add_argument("--textual_trf_name", type=str, default="mbert")
     parser.add_argument("--batch_size", type=int, default=32)
     parser.add_argument("--eval_batch_size", type=int, default=128)
-    parser.add_argument("--epochs", type=int, default=100)
-    parser.add_argument("--textual_lr", type=float, default=1e-5)
-    parser.add_argument("--visual_lr", type=float, default=1e-5)
+    parser.add_argument("--textual_lr", type=float, default=1e-4)
     parser.add_argument("--max_length", type=int, default=128)
     parser.add_argument("--patience", type=int, default=5)
     parser.add_argument("--evaluate_step", type=int, default=10)
     # Visual Transformer parameters --------------
     parser.add_argument("--visual_trf_name", type=str, default="vit")
+    parser.add_argument("--visual_lr", type=float, default=1e-4)
     args = parser.parse_args()
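
With the regrouped arguments, each modality's learning rate now sits in its own section and both default to 1e-4 (raised from 1e-5). A quick self-contained check of the new defaults, reproducing only the touched arguments:

    import argparse

    # Only the arguments touched by this commit, with their new defaults.
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--textual_lr", type=float, default=1e-4)
    parser.add_argument("--visual_lr", type=float, default=1e-4)

    args = parser.parse_args([])  # empty argv, so pure defaults
    assert args.textual_lr == args.visual_lr == 1e-4
    print(args)  # Namespace(epochs=100, textual_lr=0.0001, visual_lr=0.0001)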