143 lines
4.6 KiB
Python
143 lines
4.6 KiB
Python
|
import re
|
||
|
from os import listdir
|
||
|
from os.path import isdir, join
|
||
|
|
||
|
from dataManager.torchDataset import TorchMultiNewsDataset
|
||
|
|
||
|
# TODO: labels must be aligned between languages
|
||
|
# TODO: remove copyright and also tags (doc.split("More about:")[0])
|
||
|
# TODO: define fn to represent the dataset as a torch Dataset
|
||
|
# TODO: this should be an instance of an abstract MultimodalMultilingualDataset
|
||
|
|
||
|
|
||
|
class MultiNewsDataset:
    """Multilingual news corpus holding one ``MultiModalDataset`` per language.

    Each sub-directory of ``data_dir`` is treated as a language. Languages
    listed in ``excluded_langs`` are still reported in ``self.langs`` but are
    not loaded into ``self.lang_multiModalDataset``.
    """

    def __init__(self, data_dir, excluded_langs=None, debug=False):
        """Discover the languages, load their documents and print a summary.

        Args:
            data_dir: Root directory containing one sub-directory per language.
            excluded_langs: Languages to skip while loading. ``None`` (the
                default) excludes nothing.
            debug: When true, the language list is fixed to ``["it", "en"]``
                instead of being read from ``data_dir``.
        """
        self.debug = debug
        self.data_dir = data_dir
        self.langs = self.get_langs()
        # Build the empty list here: a mutable default argument would be
        # shared by every instance created without an explicit value.
        self.excluded_langs = [] if excluded_langs is None else excluded_langs
        self.lang_multiModalDataset = {}
        print(
            f"[{'DEBUG MODE: ' if debug else ''}Loaded MultiNewsDataset - langs: {self.langs}]"
        )
        self.load_data()
        self.print_stats()

    def load_data(self):
        """Create a ``MultiModalDataset`` for every language not excluded."""
        for lang in self.langs:
            if lang in self.excluded_langs:
                continue
            self.lang_multiModalDataset[lang] = MultiModalDataset(
                lang, join(self.data_dir, lang)
            )

    def get_langs(self):
        """Return the languages available for this dataset.

        Returns:
            ``["it", "en"]`` in debug mode; otherwise a sorted tuple with the
            names of the sub-directories of ``data_dir``.
        """
        if self.debug:
            return ["it", "en"]

        # Only directories can be language folders; stray files in the root
        # would otherwise make the per-language loader fail on ``listdir``.
        return tuple(
            sorted(
                entry
                for entry in listdir(self.data_dir)
                if isdir(join(self.data_dir, entry))
            )
        )

    def print_stats(self):
        """Print the number of documents and distinct labels per loaded language."""
        print("[MultiNewsDataset stats]")
        total_docs = 0
        # Iterate over what was actually loaded: excluded languages have no
        # entry in the mapping and must not be looked up.
        for lang, dataset in self.lang_multiModalDataset.items():
            n_docs = len(dataset.data)
            total_docs += n_docs
            print(
                f" - {lang} docs: {n_docs}\t- labels: {self._count_lang_labels(dataset.data)}"
            )
        print(f" - total docs: {total_docs}")

    def _count_lang_labels(self, data):
        """Return how many distinct labels occur in ``data``.

        The last element of every sample is read as that sample's iterable
        of labels.
        """
        lang_labels = set()
        for sample in data:
            lang_labels.update(sample[-1])
        return len(lang_labels)

    def export_to_torch_dataset(self, tokenizer_id):
        """Export the corpus as torch datasets (not implemented yet).

        Raises:
            NotImplementedError: Always.
        """
        # TODO: build one TorchMultiNewsDataset per language from the docs,
        # ids, imgs and labels of each MultiModalDataset, using tokenizer_id.
        raise NotImplementedError

    def save_to_disk(self):
        """Persist the dataset to disk (not implemented yet).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
|
||
|
|
||
|
|
||
|
class MultiModalDataset:
    """Documents of a single language, read from one folder per news item.

    ``self.data`` is filled at construction time with one tuple per news
    folder: ``(fname_doc, clean_doc, html_doc, img, labels)``.
    """

    # Compiled once for the whole class; still reachable through ``self``.
    # Matches tag links and captures the visible label text.
    re_labels = re.compile(r"<a rel=\"tag\" href=\"\/tag\/.+?\/\">(.+?)<\/a>")
    # Matches HTML tags and character entities so they can be stripped.
    re_cleaner = re.compile(r"<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});")
    # Matches runs of spaces to be collapsed into a single one.
    re_white = re.compile(r" +")

    def __init__(self, lang, data_dir):
        """Store the language settings and load every document.

        Args:
            lang: Language identifier of this dataset.
            data_dir: Directory containing one sub-directory per news item.
        """
        self.lang = lang
        self.data_dir = data_dir
        self.data = self.get_docs()

    def get_imgs(self):
        """Return the images of the dataset (not implemented yet)."""
        raise NotImplementedError

    def get_labels(self):
        """Return the labels of the dataset (not implemented yet)."""
        raise NotImplementedError

    def get_ids(self):
        """Return the document ids of the dataset (not implemented yet)."""
        raise NotImplementedError

    def get_docs(self):
        """Read and preprocess the document stored in each news folder.

        Every sub-directory of ``data_dir`` is expected to contain a file
        named ``text.<suffix>``, where ``<suffix>`` is the part of the folder
        name after its last dot. Entries that are not directories are skipped.

        Returns:
            A list of ``(fname_doc, clean_doc, html_doc, img, labels)`` tuples.
        """
        data = []
        for news_folder in listdir(self.data_dir):
            folder_path = join(self.data_dir, news_folder)
            if not isdir(folder_path):
                continue
            fname_doc = f"text.{news_folder.split('.')[-1]}"
            # NOTE(review): the file is decoded with the platform default
            # encoding - confirm the corpus encoding and pass it explicitly.
            with open(join(folder_path, fname_doc)) as f:
                html_doc = f.read()
            img = self.get_image()
            clean_doc, labels = self.preprocess_html(html_doc)
            # NOTE(review): fname_doc is the same for every folder sharing a
            # suffix - verify whether the folder name was meant as the id.
            data.append((fname_doc, clean_doc, html_doc, img, labels))
        return data

    def preprocess_html(self, html_doc):
        """Return ``(cleaned_text, labels)`` extracted from ``html_doc``."""
        # Labels live inside the markup, so extract them before stripping it.
        labels = self._extract_labels(html_doc)
        cleaned = self._clean_up_str(self._remove_html_tags(html_doc))
        return cleaned, labels

    def _extract_labels(self, data):
        """Return the text of every tag link found in ``data``."""
        return self.re_labels.findall(data)

    def _remove_html_tags(self, data):
        """Return ``data`` with HTML tags and character entities removed."""
        return self.re_cleaner.sub("", data)

    def _clean_up_str(self, doc):
        """Normalise whitespace: no newlines or tabs, single spaces, trimmed."""
        # Turn newlines and tabs into spaces first, so that the runs of
        # spaces they create are collapsed together with the existing ones.
        doc = doc.replace("\n", " ").replace("\t", " ")
        doc = self.re_white.sub(" ", doc)
        return doc.strip()

    def get_image(self):
        """Return the image of a news item; currently always ``None``."""
        # TODO: implement
        return None
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    from os.path import expanduser

    # Manual smoke test: load the corpus in debug mode from a fixed location
    # under the user's home directory.
    _dataset_path_hardcoded = "~/datasets/MultiNews/20110730/"
    dataset = MultiNewsDataset(
        data_dir=expanduser(_dataset_path_hardcoded),
        debug=True,
    )
|