from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader


class TransformerGen:
    """Base class for all transformers.

    It implements the basic methods for the creation of the datasets and
    dataloaders, and the train-val split method. It is designed to be used
    with MultilingualDataset, i.e. with data in the form of dictionaries
    ``{lang: data}``.
    """

    def __init__(self):
        # Maps a split name (e.g. "train") to the dataset most recently built
        # for it by ``build_dataloader``.
        self.datasets = {}

    def build_dataloader(
        self,
        lX,
        lY,
        torchDataset,
        processor_fn,
        batch_size,
        split="train",
        shuffle=False,
    ):
        """Build a ``DataLoader`` over per-language processed data.

        Args:
            lX: Dictionary ``{lang: data}`` holding the raw inputs.
            lY: Per-language targets, passed through to ``torchDataset``
                unchanged.
            torchDataset: Callable invoked as
                ``torchDataset(l_tokenized, lY, split=split)``; its result is
                wrapped in the returned ``DataLoader``.
            processor_fn: Callable applied once to each language's data.
            batch_size: Batch size forwarded to the ``DataLoader``.
            split: Name under which the dataset is stored in
                ``self.datasets`` (also forwarded to ``torchDataset``).
            shuffle: Whether the ``DataLoader`` shuffles the dataset.

        Returns:
            A ``DataLoader`` over the newly built dataset.
        """
        l_tokenized = {lang: processor_fn(data) for lang, data in lX.items()}
        # Keep a handle on the dataset; any previous one stored under the same
        # split name is replaced.
        self.datasets[split] = torchDataset(l_tokenized, lY, split=split)
        return DataLoader(self.datasets[split], batch_size=batch_size, shuffle=shuffle)

    def get_train_val_data(self, lX, lY, split=0.2, seed=42, shuffle=False):
        """Split every language's data into training and validation parts.

        Each language is split independently with
        ``sklearn.model_selection.train_test_split``.

        Args:
            lX: Dictionary ``{lang: data}`` holding the inputs.
            lY: Dictionary ``{lang: targets}``; it must contain every language
                found in ``lX``.
            split: Forwarded as ``test_size``, i.e. the size of the validation
                part.
            seed: Forwarded as ``random_state``. It only influences the result
                when ``shuffle`` is true.
            shuffle: Whether to shuffle before splitting. The default
                (``False``) keeps the original order, so the split is
                sequential and ``seed`` has no effect.

        Returns:
            A tuple ``(tr_lX, tr_lY, val_lX, val_lY)`` of ``{lang: data}``
            dictionaries.

        Raises:
            KeyError: If a language of ``lX`` is missing from ``lY``.
        """
        tr_lX, tr_lY, val_lX, val_lY = {}, {}, {}, {}
        for lang in lX:
            tr_X, val_X, tr_Y, val_Y = train_test_split(
                lX[lang],
                lY[lang],
                test_size=split,
                random_state=seed,
                shuffle=shuffle,
            )
            tr_lX[lang] = tr_X
            tr_lY[lang] = tr_Y
            val_lX[lang] = val_X
            val_lY[lang] = val_Y
        return tr_lX, tr_lY, val_lX, val_lY