diff --git a/src/dataset_builder.py b/src/dataset_builder.py
index 3f6732c..9af7b3f 100644
--- a/src/dataset_builder.py
+++ b/src/dataset_builder.py
@@ -11,6 +11,8 @@ import numpy as np
 from sklearn.model_selection import train_test_split
 from scipy.sparse import issparse
 import itertools
+import re
+from tqdm import tqdm
 
 
 class MultilingualDataset:
@@ -73,10 +75,14 @@ class MultilingualDataset:
         return self.lXte(), self.lYte()
 
     def lXtr(self):
-        return {lang:Xtr for (lang, ((Xtr,_,_),_)) in self.multiling_dataset.items() if lang in self.langs()}
+        return {lang: Xtr for (lang, ((Xtr, _, _), _)) in self.multiling_dataset.items() if
+                lang in self.langs()}
+        # NOTE: to serve number-masked training text, map Xtr through self.mask_numbers(Xtr) here.
 
     def lXte(self):
-        return {lang:Xte for (lang, (_,(Xte,_,_))) in self.multiling_dataset.items() if lang in self.langs()}
+        return {lang: Xte for (lang, (_, (Xte, _, _))) in self.multiling_dataset.items() if
+                lang in self.langs()}
+        # NOTE: to serve number-masked test text, map Xte through self.mask_numbers(Xte) here.
 
     def lYtr(self):
         return {lang:self.cat_view(Ytr) for (lang, ((_,Ytr,_),_)) in self.multiling_dataset.items() if lang in self.langs()}
@@ -129,6 +135,19 @@ class MultilingualDataset:
     def set_labels(self, labels):
         self.labels = labels
 
+    def mask_numbers(self, data, number_mask='numbermask'):
+        """Replace the numeric tokens found in each text of ``data`` with ``number_mask``.
+
+        A numeric token begins with a digit at a word boundary and may continue
+        with digits, dots, commas and hyphens (e.g. ``1,000.50`` or ``2019-2020``).
+
+        :param data: iterable of strings; a tqdm progress bar is shown while iterating
+        :param number_mask: replacement string substituted for every match
+        :return: a new list holding the masked texts, in input order
+        """
+        mask = re.compile(r'\b[0-9][0-9.,-]*\b')
+        return [mask.sub(number_mask, text) for text in tqdm(data, desc='masking numbers')]
+
 
 # ----------------------------------------------------------------------------------------------------------------------
 # Helpers