67 lines
1.8 KiB
Python
67 lines
1.8 KiB
Python
import torch
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
class MultilingualDatasetTorch(Dataset):
    """Map-style dataset that concatenates tokenized text from several languages.

    The ``input_ids`` of every language in ``lX`` are stacked (in ``lX``'s key
    order) into one tensor, and every row is tagged with the language key it
    came from.

    Args:
        lX: dict mapping a language key to an object exposing ``input_ids``
            (presumably a tokenizer ``BatchEncoding`` holding 2-D token ids with
            a common sequence length across languages -- TODO confirm).
        lY: dict mapping the same language keys to per-sample label rows
            (anything ``torch.Tensor`` accepts). Not read when
            ``split == "whole"``.
        split: split name; ``"whole"`` means no labels are built and
            ``__getitem__`` returns ``(x, lang)`` instead of ``(x, y, lang)``.
    """

    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        """(Re)build ``X``, ``Y`` (unless ``split == "whole"``) and ``langs``.

        Returns:
            ``self``, so the call can be chained.

        Raises:
            KeyError: if ``lY`` lacks a language key present in ``lX``
                (only when labels are built).
        """
        self.X = torch.vstack([data.input_ids for data in self.lX.values()])
        if self.split != "whole":
            # Walk the labels in lX's key order so row i of Y always belongs to
            # row i of X, even if lY was populated in a different order.
            self.Y = torch.vstack(
                [torch.Tensor(self.lY[lang]) for lang in self.lX]
            )
        # One language tag per sample, in the same order as the vstack above.
        self.langs = [
            lang
            for lang, data in self.lX.items()
            for _ in range(len(data.input_ids))
        ]
        return self

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        """Return ``(x, lang)`` for ``split == "whole"``, else ``(x, y, lang)``."""
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]
|
|
|
|
|
|
class MultimodalDatasetTorch(Dataset):
    """Map-style dataset that concatenates per-language tensors (e.g. images).

    The tensors in ``lX`` are stacked (in ``lX``'s key order) into one tensor,
    and every row is tagged with the language key it came from.

    Args:
        lX: dict mapping a language key to a tensor whose first dimension
            indexes samples (presumably image features -- TODO confirm shape).
        lY: dict mapping the same language keys to per-sample label rows
            (anything ``torch.Tensor`` accepts). Not read when
            ``split == "whole"``.
        split: split name; ``"whole"`` means no labels are built and
            ``__getitem__`` returns ``(x, lang)`` instead of ``(x, y, lang)``.
    """

    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        """(Re)build ``X``, ``Y`` (unless ``split == "whole"``) and ``langs``.

        Returns:
            ``self``, matching ``MultilingualDatasetTorch.init`` so the call can
            be chained.

        Raises:
            KeyError: if ``lY`` lacks a language key present in ``lX``
                (only when labels are built).
        """
        self.X = torch.vstack(list(self.lX.values()))
        if self.split != "whole":
            # Walk the labels in lX's key order so row i of Y always belongs to
            # row i of X, even if lY was populated in a different order.
            self.Y = torch.vstack(
                [torch.Tensor(self.lY[lang]) for lang in self.lX]
            )
        # One language tag per sample, in the same order as the vstack above.
        self.langs = [
            lang for lang, data in self.lX.items() for _ in range(len(data))
        ]
        return self

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        """Return ``(x, lang)`` for ``split == "whole"``, else ``(x, y, lang)``."""
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]
|