# gfun_multimodal/dataManager/multiNewsDataset.py
import re
from os import listdir
from os.path import isdir, join
from dataManager.torchDataset import TorchMultiNewsDataset
# TODO: labels must be aligned between languages
# TODO: remove copyright and also tags (doc.split("More about:")[0])
# TODO: define fn to represent the dataset as a torch Dataset
# TODO: this should be a instance of a abstract MultimodalMultilingualDataset
class MultiNewsDataset:
    """Multilingual (multimodal) news dataset.

    Discovers one sub-directory per language under ``data_dir`` and loads a
    ``MultiModalDataset`` for every language that is not explicitly excluded.
    Loading and a stats printout happen eagerly at construction time.
    """

    def __init__(self, data_dir, excluded_langs=None, debug=False):
        """
        :param data_dir: root directory containing one folder per language code.
        :param excluded_langs: iterable of language codes to skip
            (``None`` means exclude nothing; avoids the mutable-default bug
            the original ``excluded_langs=[]`` signature had).
        :param debug: if True, restrict the dataset to ["it", "en"].
        """
        self.debug = debug
        self.data_dir = data_dir
        self.langs = self.get_langs()
        self.excluded_langs = [] if excluded_langs is None else excluded_langs
        self.lang_multiModalDataset = {}
        print(
            f"[{'DEBUG MODE: ' if debug else ''}Loaded MultiNewsDataset - langs: {self.langs}]"
        )
        self.load_data()
        self.print_stats()

    def load_data(self):
        """Build one MultiModalDataset per non-excluded language."""
        for lang in self.langs:
            if lang not in self.excluded_langs:
                self.lang_multiModalDataset[lang] = MultiModalDataset(
                    lang, join(self.data_dir, lang)
                )

    def get_langs(self):
        """Return the language codes found as sub-folders of ``data_dir``.

        In debug mode a fixed two-language subset is returned and the
        filesystem is not touched. Results are sorted for determinism.
        """
        if self.debug:
            return ["it", "en"]
        # listdir is imported at module level; no need for a local re-import.
        return tuple(sorted(listdir(self.data_dir)))

    def print_stats(self):
        """Print per-language document/label counts and the overall total."""
        print("[MultiNewsDataset stats]")
        total_docs = 0
        # Iterate only the languages that were actually loaded: iterating
        # self.langs would raise KeyError for any excluded language, since
        # load_data never created an entry for it.
        for lang, lang_dataset in self.lang_multiModalDataset.items():
            _len = len(lang_dataset.data)
            total_docs += _len
            print(
                f" - {lang} docs: {_len}\t- labels: {self._count_lang_labels(lang_dataset.data)}"
            )
        print(f" - total docs: {total_docs}")

    def _count_lang_labels(self, data):
        """Return the number of distinct labels across all samples.

        Each sample is a tuple whose LAST element is its label list
        (see MultiModalDataset.get_docs).
        """
        lang_labels = set()
        for sample in data:
            lang_labels.update(sample[-1])
        return len(lang_labels)

    def export_to_torch_dataset(self, tokenizer_id):
        # TODO: wrap each per-language MultiModalDataset into a
        # TorchMultiNewsDataset (lang, docs, ids, imgs, labels, tokenizer_id).
        raise NotImplementedError

    def save_to_disk(self):
        raise NotImplementedError
class MultiModalDataset:
    """News documents (HTML + image placeholder) for a single language.

    Each entry of ``self.data`` is a tuple
    ``(fname_doc, clean_doc, html_doc, img, labels)`` where ``labels`` is the
    list of tag strings extracted from the raw HTML.
    """

    def __init__(self, lang, data_dir):
        """
        :param lang: language code of this partition (e.g. "en").
        :param data_dir: directory containing one sub-folder per news item;
            a folder named ``<id>.<lang>`` must contain a ``text.<lang>`` file.
        """
        self.lang = lang
        self.data_dir = data_dir
        # Tag anchors of the form <a rel="tag" href="/tag/.../">LABEL</a>.
        self.re_labels = re.compile(r"<a rel=\"tag\" href=\"\/tag\/.+?\/\">(.+?)<\/a>")
        # Strips HTML tags and (numeric or named) character entities.
        self.re_cleaner = re.compile(r"<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});")
        self.re_white = re.compile(r" +")
        self.data = self.get_docs()

    # NOTE: the original class defined get_docs twice; the first definition
    # (raising NotImplementedError) was dead code shadowed by the real one
    # below, and has been removed.

    def get_docs(self):
        """Read every news folder under ``data_dir`` and return the sample list.

        :return: list of (fname_doc, clean_doc, html_doc, img, labels) tuples.
        """
        data = []
        for doc_folder in listdir(self.data_dir):
            if not isdir(join(self.data_dir, doc_folder)):
                continue
            # Folder names look like "<id>.<lang>"; the text file inside is
            # named "text.<lang>".
            fname_doc = f"text.{doc_folder.split('.')[-1]}"
            with open(join(self.data_dir, doc_folder, fname_doc), encoding="utf-8") as f:
                html_doc = f.read()
            img = self.get_image()
            clean_doc, labels = self.preprocess_html(html_doc)
            data.append((fname_doc, clean_doc, html_doc, img, labels))
        return data

    def get_imgs(self):
        raise NotImplementedError

    def get_labels(self):
        raise NotImplementedError

    def get_ids(self):
        raise NotImplementedError

    def preprocess_html(self, html_doc):
        """Return (clean_text, labels) extracted from a raw HTML document."""
        labels = self._extract_labels(html_doc)
        cleaned = self._clean_up_str(self._remove_html_tags(html_doc))
        return cleaned, labels

    def _extract_labels(self, data):
        """Return the list of tag labels found in the raw HTML."""
        return self.re_labels.findall(data)

    def _remove_html_tags(self, data):
        """Strip HTML tags and character entities from the document."""
        return self.re_cleaner.sub("", data)

    def _clean_up_str(self, doc):
        """Normalize whitespace: newlines/tabs become spaces, runs collapse.

        Converting '\\n'/'\\t' to spaces BEFORE collapsing avoids the doubled
        spaces the original produced (it collapsed first, then substituted).
        """
        doc = doc.replace("\n", " ").replace("\t", " ")
        doc = self.re_white.sub(" ", doc)
        return doc.strip()

    def get_image(self):
        # TODO: implement image loading; explicit None until then.
        return None
if __name__ == "__main__":
    from os.path import expanduser

    # Smoke-test entry point: load the dataset from a hard-coded local path
    # in debug mode (restricted to the "it"/"en" partitions).
    DATASET_PATH = "~/datasets/MultiNews/20110730/"
    dataset = MultiNewsDataset(expanduser(DATASET_PATH), debug=True)