gfun_multimodal/dataManager/multiNewsDataset.py

import os
import sys
sys.path.append(os.getcwd())
import re
from os import listdir
from os.path import isdir, join
import requests
from bs4 import BeautifulSoup
from PIL import Image
from sklearn.preprocessing import MultiLabelBinarizer

# TODO: labels must be aligned between languages
# TODO: remove copyright and also tags (doc.split("More about:")[0])
# TODO: this should be an instance of an abstract MultimodalMultilingualDataset

def get_label_binarizer(cats):
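    """Fit a MultiLabelBinarizer on the full set of labels so that every
    language's label lists are binarized against the same label space."""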
    mlb = MultiLabelBinarizer()
    mlb.fit([cats])
    return mlb

class MultiNewsDataset:
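    """Multilingual, multimodal MultiNews dataset: loads one MultiModalDataset
    per language sub-folder of data_dir and exposes per-language training data."""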

    def __init__(self, data_dir, excluded_langs=None, debug=False):
        self.debug = debug
        self.data_dir = data_dir
        self.dataset_langs = self.get_langs()
        self.excluded_langs = excluded_langs if excluded_langs is not None else []
        self.lang_multiModalDataset = {}
        print(
            f"[{'DEBUG MODE: ' if debug else ''}Loaded MultiNewsDataset - langs: {[l for l in self.dataset_langs if l not in self.excluded_langs]}]"
        )
        self.load_data()
        self.all_labels = self.get_labels()
        self.label_binarizer = get_label_binarizer(self.all_labels)
        self.print_stats()

    def load_data(self):
        for lang in self.dataset_langs:
            if lang not in self.excluded_langs:
                self.lang_multiModalDataset[lang] = MultiModalDataset(
                    lang, join(self.data_dir, lang)
                )

    def langs(self):
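        """Return the active languages: dataset languages minus the excluded ones."""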
        return [l for l in self.dataset_langs if l not in self.excluded_langs]

    def get_langs(self):
        if self.debug:
            return ["it", "en"]
        return tuple(sorted(listdir(self.data_dir)))

    def print_stats(self):
        print("[MultiNewsDataset stats]")
        total_docs = 0
        for lang in self.dataset_langs:
            if lang not in self.excluded_langs:
                _len = len(self.lang_multiModalDataset[lang].data)
                total_docs += _len
                print(
                    f" - {lang} docs: {_len}\t- labels: {self._count_lang_labels(self.lang_multiModalDataset[lang].labels)}"
                )
        print(f" - total docs: {total_docs}\n")

    def _count_lang_labels(self, labels):
        lang_labels = set()
        for l in labels:
            lang_labels.update(l)
        return len(lang_labels)

    def export_to_torch_dataset(self, tokenizer_id):
        raise NotImplementedError

    def save_to_disk(self):
        raise NotImplementedError

    def training(self):
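        """Return (lXtr, lYtr): lXtr maps each language to its list of cleaned
        texts, lYtr to the corresponding multi-hot label matrix produced by the
        shared label binarizer."""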
        lXtr = {}
        lYtr = {}
        for lang, dataset in self.lang_multiModalDataset.items():
            # each data entry is (fname, clean_text, raw_html, img)
            lXtr[lang] = [clean_text for _, clean_text, _, _ in dataset.data]
            lYtr[lang] = self.label_binarizer.transform(dataset.labels)
        return lXtr, lYtr

    def testing(self):
        raise NotImplementedError

    def get_labels(self):
        all_labels = set()
        for lang, data in self.lang_multiModalDataset.items():
            for label in data.labels:
                all_labels.update(label)
        return all_labels

class MultiModalDataset:
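    """Single-language slice of MultiNews: each sub-folder of data_dir holds one
    article (its HTML text, an index.html page, and a cached main image)."""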

    def __init__(self, lang, data_dir):
        self.lang = lang
        self.data_dir = data_dir
        # tag links of the form <a rel="tag" href="/tag/...">label</a>
        self.re_labels = re.compile(r"<a rel=\"tag\" href=\"\/tag\/.+?\/\">(.+?)<\/a>")
        # HTML tags and character entities
        self.re_cleaner = re.compile(r"<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});")
        self.re_white = re.compile(r" +")
        self.data, self.labels = self.get_docs()

    def get_imgs(self):
        raise NotImplementedError

    def get_labels(self):
        raise NotImplementedError

    def get_ids(self):
        raise NotImplementedError

    def get_docs(self):
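        """Walk every article folder: read the HTML text, download and cache the
        main image if no .jpg is present yet, and collect per-document labels.
        Returns (data, labels) with data entries (fname, clean_text, raw_html, img)."""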
        # FIXME: this is a mess
        data = []
        labels = []
        news_folders = listdir(self.data_dir)
        for news_folder in news_folders:
            if isdir(join(self.data_dir, news_folder)):
                fname_doc = f"text.{news_folder.split('.')[-1]}"
                with open(join(self.data_dir, news_folder, fname_doc)) as f:
                    html_doc = f.read()
                index_path = join(self.data_dir, news_folder, "index.html")
                if not any(
                    fname.endswith(".jpg")
                    for fname in listdir(join(self.data_dir, news_folder))
                ):
                    _, img_bytes = self.get_images(index_path)
                    self.save_img(join(self.data_dir, news_folder, "img.jpg"), img_bytes)
                img = Image.open(join(self.data_dir, news_folder, "img.jpg"))
                clean_doc, doc_labels = self.preprocess_html(html_doc)
                data.append((fname_doc, clean_doc, html_doc, img))
                labels.append(doc_labels)
        return data, labels

    def save_img(self, path, img):
        with open(path, "wb") as f:
            f.write(img)

    def get_images(self, index_path):
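        """Parse index.html and fetch the main article image over HTTP.
        Returns the <img> tag and the raw bytes of the downloaded image."""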
        with open(index_path) as f:
            img_tags = BeautifulSoup(f, "html.parser").find_all("img")
        # TODO: forcing to take the first article image (i.e. index 1 should be
        # the main image)
        img_tag = img_tags[1]
        content = requests.get(img_tag["src"]).content
        return img_tag, content

    def preprocess_html(self, html_doc):
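        """Extract tag labels from the raw HTML, then strip markup and
        normalize whitespace; returns (cleaned_text, labels)."""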
        # TODO: this could be replaced by a BeautifulSoup call or something similar
        labels = self._extract_labels(html_doc)
        cleaned = self._clean_up_str(self._remove_html_tags(html_doc))
        return cleaned, labels

    def _extract_labels(self, data):
        return re.findall(self.re_labels, data)

    def _remove_html_tags(self, data):
        return re.sub(self.re_cleaner, "", data)

    def _clean_up_str(self, doc):
        # replace newlines and tabs first, then collapse runs of spaces
        doc = doc.replace("\n", " ")
        doc = doc.replace("\t", " ")
        doc = re.sub(self.re_white, " ", doc)
        return doc.strip()


if __name__ == "__main__":
    from os.path import expanduser

    _dataset_path_hardcoded = "~/datasets/MultiNews/20110730/"
    dataset = MultiNewsDataset(expanduser(_dataset_path_hardcoded), debug=True)
    lXtr, lYtr = dataset.training()
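    # a minimal sanity check on the returned structures (illustrative addition,
    # assumes the debug languages were loaded successfully)
    for lang in dataset.langs():
        print(f"{lang}: {len(lXtr[lang])} docs, label matrix shape {lYtr[lang].shape}")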