import re
from os import listdir
from os.path import isdir, join

# Only referenced by the commented-out sketch in
# MultiNewsDataset.export_to_torch_dataset; kept so that sketch stays valid.
from dataManager.torchDataset import TorchMultiNewsDataset  # noqa: F401

# TODO: labels must be aligned between languages
# TODO: remove copyright and also tags (doc.split("More about:")[0])
# TODO: define fn to represent the dataset as a torch Dataset
# TODO: this should be an instance of an abstract MultimodalMultilingualDataset


class MultiNewsDataset:
    """Collection of per-language ``MultiModalDataset`` objects.

    ``data_dir`` is expected to contain one sub-folder per language; each
    sub-folder is handed to ``MultiModalDataset``. Constructing an instance
    immediately loads every non-excluded language and prints summary stats.
    """

    def __init__(self, data_dir, excluded_langs=None, debug=False):
        """Discover the languages, load them and print statistics.

        Args:
            data_dir: Root folder holding one sub-folder per language.
            excluded_langs: Language codes to skip while loading. ``None``
                (the default) means "exclude nothing".
            debug: When true, the language list is fixed to ``["it", "en"]``
                instead of being read from ``data_dir``.
        """
        self.debug = debug
        self.data_dir = data_dir
        self.langs = self.get_langs()
        # A fresh list per instance avoids the shared-mutable-default pitfall.
        self.excluded_langs = [] if excluded_langs is None else excluded_langs
        self.lang_multiModalDataset = {}
        print(
            f"[{'DEBUG MODE: ' if debug else ''}Loaded MultiNewsDataset - langs: {self.langs}]"
        )
        self.load_data()
        self.print_stats()

    def load_data(self):
        """Build a ``MultiModalDataset`` for every language not excluded."""
        for lang in self.langs:
            if lang in self.excluded_langs:
                continue
            self.lang_multiModalDataset[lang] = MultiModalDataset(
                lang, join(self.data_dir, lang)
            )

    def get_langs(self):
        """Return the language codes available for this dataset.

        In debug mode the fixed list ``["it", "en"]`` is returned. Otherwise
        the sorted names of the sub-directories of ``data_dir`` are returned
        as a tuple; plain files are skipped because each language is later
        opened as a directory.
        """
        if self.debug:
            return ["it", "en"]

        return tuple(
            sorted(
                entry
                for entry in listdir(self.data_dir)
                if isdir(join(self.data_dir, entry))
            )
        )

    def print_stats(self):
        """Print per-language document/label counts and the overall total."""
        print("[MultiNewsDataset stats]")
        total_docs = 0
        # Iterate over what was actually loaded: excluded languages have no
        # entry in the mapping, so indexing it by ``self.langs`` would raise
        # KeyError as soon as a language is excluded.
        for lang, dataset in self.lang_multiModalDataset.items():
            n_docs = len(dataset.data)
            n_labels = self._count_lang_labels(dataset.data)
            total_docs += n_docs
            print(f" - {lang} docs: {n_docs}\t- labels: {n_labels}")
        print(f" - total docs: {total_docs}")

    def _count_lang_labels(self, data):
        """Return the number of distinct labels across ``data``.

        Each sample is a sequence whose last element is an iterable of labels.
        """
        return len({label for sample in data for label in sample[-1]})

    def export_to_torch_dataset(self, tokenizer_id):
        """Export every language as a torch dataset (not implemented yet)."""
        # Sketch of the intended implementation (note: iterate ``.items()``,
        # not ``.keys()``, to get (lang, dataset) pairs):
        #
        # torch_datasets = []
        # for lang, multimodal_dataset in self.lang_multiModalDataset.items():
        #     dataset = TorchMultiNewsDataset(
        #         lang=lang,
        #         data=multimodal_dataset.get_docs(),
        #         ids=multimodal_dataset.get_ids(),
        #         imgs=multimodal_dataset.get_imgs(),
        #         labels=multimodal_dataset.get_labels(),
        #         tokenizer_id=tokenizer_id,
        #     )
        #     torch_datasets.append(dataset)
        raise NotImplementedError

    def save_to_disk(self):
        """Persist the dataset to disk (not implemented yet)."""
        raise NotImplementedError


class MultiModalDataset:
    """News documents of a single language, read eagerly from ``data_dir``.

    After construction ``self.data`` is a list of tuples
    ``(fname_doc, clean_doc, html_doc, img, labels)``.
    """

    def __init__(self, lang, data_dir):
        """Compile the cleaning patterns and load all documents.

        Args:
            lang: Language code this dataset belongs to.
            data_dir: Folder containing one sub-folder per news item.
        """
        self.lang = lang
        self.data_dir = data_dir
        # Captures the text preceding each closing ``</a>``.
        # NOTE(review): the pattern has no opening-tag part, so it captures
        # everything on the line before ``</a>`` - looks truncated; confirm.
        self.re_labels = re.compile(r"(.+?)<\/a>")
        # Matches HTML tags and HTML entities (named, decimal, hexadecimal).
        self.re_cleaner = re.compile(r"<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});")
        # Matches runs of one or more spaces.
        self.re_white = re.compile(r" +")
        self.data = self.get_docs()

    def get_imgs(self):
        """Return the images of all documents (not implemented yet)."""
        raise NotImplementedError

    def get_labels(self):
        """Return the labels of all documents (not implemented yet)."""
        raise NotImplementedError

    def get_ids(self):
        """Return the ids of all documents (not implemented yet)."""
        raise NotImplementedError

    def get_docs(self):
        """Read and preprocess every news folder under ``data_dir``.

        For each sub-directory ``<name>`` the file ``text.<ext>`` is read,
        where ``<ext>`` is the part of ``<name>`` after its last dot.
        Non-directory entries are ignored.

        Returns:
            A list of ``(fname_doc, clean_doc, html_doc, img, labels)`` tuples.
        """
        data = []
        # ``listdir`` returns entries in arbitrary order; sorting makes the
        # resulting sample order reproducible across runs and machines.
        for folder_name in sorted(listdir(self.data_dir)):
            if not isdir(join(self.data_dir, folder_name)):
                continue
            fname_doc = f"text.{folder_name.split('.')[-1]}"
            # NOTE(review): no explicit encoding, so decoding depends on the
            # platform default - confirm the corpus encoding and pin it.
            with open(join(self.data_dir, folder_name, fname_doc)) as f:
                html_doc = f.read()
            img = self.get_image()
            clean_doc, labels = self.preprocess_html(html_doc)
            data.append((fname_doc, clean_doc, html_doc, img, labels))
        return data

    def preprocess_html(self, html_doc):
        """Return ``(cleaned_text, labels)`` extracted from ``html_doc``."""
        labels = self._extract_labels(html_doc)
        cleaned = self._clean_up_str(self._remove_html_tags(html_doc))
        return cleaned, labels

    def _extract_labels(self, data):
        """Return every ``re_labels`` capture found in ``data``."""
        return self.re_labels.findall(data)

    def _remove_html_tags(self, data):
        """Return ``data`` with HTML tags and entities removed."""
        return self.re_cleaner.sub("", data)

    def _clean_up_str(self, doc):
        """Normalise whitespace in ``doc``.

        Newlines and tabs become spaces, runs of spaces are collapsed to a
        single space and surrounding whitespace is stripped.
        """
        # Convert newlines/tabs *before* collapsing; doing it afterwards
        # re-introduces runs of spaces (e.g. "a\n\nb" -> "a  b").
        doc = doc.replace("\n", " ").replace("\t", " ")
        doc = self.re_white.sub(" ", doc)
        return doc.strip()

    def get_image(self):
        """Return the image of a document; currently always ``None``."""
        # TODO: implement
        return None


if __name__ == "__main__":
    from os.path import expanduser

    _dataset_path_hardcoded = "~/datasets/MultiNews/20110730/"
    dataset = MultiNewsDataset(expanduser(_dataset_path_hardcoded), debug=True)