MultiNewsDataset download/save image fn + class for Visual View Generating Function

Andrea Pedrotti 2023-02-08 18:11:53 +01:00
parent 19e4f294db
commit 8325262972
3 changed files with 87 additions and 35 deletions

View File

@@ -1,15 +1,28 @@
import os
import sys
sys.path.append(os.getcwd())
import re
from os import listdir
from os.path import isdir, join
from dataManager.torchDataset import TorchMultiNewsDataset
import requests
from bs4 import BeautifulSoup
from PIL import Image
from sklearn.preprocessing import MultiLabelBinarizer
# TODO: labels must be aligned between languages
# TODO: remove copyright and also tags (doc.split("More about:")[0])
# TODO: define fn to represent the dataset as a torch Dataset
# TODO: this should be an instance of an abstract MultimodalMultilingualDataset
def get_label_binarizer(cats):
mlb = MultiLabelBinarizer()
mlb.fit([cats])  # wrapped in a list: fit expects an iterable of label collections, so one "sample" holding every label registers all classes
return mlb
class MultiNewsDataset:
def __init__(self, data_dir, excluded_langs=[], debug=False):
self.debug = debug
@@ -21,6 +34,8 @@ class MultiNewsDataset:
f"[{'DEBUG MODE: ' if debug else ''}Loaded MultiNewsDataset - langs: {self.langs}]"
)
self.load_data()
self.all_labels = self.get_labels()
self.label_binarizer = get_label_binarizer(self.all_labels)
self.print_stats()
def load_data(self):
@@ -34,47 +49,52 @@ class MultiNewsDataset:
from os import listdir
if self.debug:
return ["it", "en"]
return ["it"]
return tuple(sorted(listdir(self.data_dir)))
def print_stats(self):
print(f"[MultiNewsDataset stats]")
# print(f" - langs: {self.langs}")
total_docs = 0
for lang in self.langs:
_len = len(self.lang_multiModalDataset[lang].data)
total_docs += _len
print(
f" - {lang} docs: {_len}\t- labels: {self._count_lang_labels(self.lang_multiModalDataset[lang].data)}"
f" - {lang} docs: {_len}\t- labels: {self._count_lang_labels(self.lang_multiModalDataset[lang].labels)}"
)
print(f" - total docs: {total_docs}")
def _count_lang_labels(self, data):
def _count_lang_labels(self, labels):
lang_labels = set()
for sample in data:
lang_labels.update(sample[-1])
for l in labels:
lang_labels.update(l)  # l is the full list of labels for one document
return len(lang_labels)
def export_to_torch_dataset(self, tokenizer_id):
raise NotImplementedError
# torch_datasets = []
# for lang, multimodal_dataset in self.lang_multiModalDataset.keys():
# dataset = TorchMultiNewsDataset(
# lang=lang,
# data=multimodal_dataset.get_docs(),
# ids=multimodal_dataset.get_ids(),
# imgs=multimodal_dataset.get_imgs(),
# labels=multimodal_dataset.get_labels(),
# tokenizer_id=tokenizer_id,
# )
# torch_datasets.append(dataset)
# raise NotImplementedError
def save_to_disk(self):
raise NotImplementedError
def training(self):
lXtr = {}
lYtr = {}
for lang, data in self.lang_multiModalDataset.items():
lXtr[lang] = data.data
lYtr[lang] = self.label_binarizer.transform(data.labels)
return lXtr, lYtr
def testing(self):
raise NotImplementedError
def get_labels(self):
all_labels = set()
for lang, data in self.lang_multiModalDataset.items():
for label in data.labels:
all_labels.update(label)
return all_labels
class MultiModalDataset:
def __init__(self, lang, data_dir):
@@ -83,10 +103,7 @@ class MultiModalDataset:
self.re_labels = re.compile(r"<a rel=\"tag\" href=\"\/tag\/.+?\/\">(.+?)<\/a>")
self.re_cleaner = re.compile(r"<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});")
self.re_white = re.compile(r" +")
self.data = self.get_docs()
def get_docs(self):
raise NotImplementedError
self.data, self.labels = self.get_docs()
def get_imgs(self):
raise NotImplementedError
@@ -98,19 +115,39 @@ class MultiModalDataset:
raise NotImplementedError
def get_docs(self):
# FIXME: this is a mess
data = []
labels = []
news_folders = listdir(self.data_dir)
for news_folder in news_folders:
if isdir(join(self.data_dir, news_folder)):
fname_doc = f"text.{news_folder.split('.')[-1]}"
with open(join(self.data_dir, news_folder, fname_doc)) as f:
html_doc = f.read()
img = self.get_image()
clean_doc, labels = self.preprocess_html(html_doc)
data.append((fname_doc, clean_doc, html_doc, img, labels))
return data
index_path = join(self.data_dir, news_folder, "index.html")
if ".jpg" not in listdir(join(self.data_dir, news_folder)):
img_link, img = self.get_images(index_path)
self.save_img(join(self.data_dir, news_folder, "img.jpg"), img)
else:
img = Image.open(join(self.data_dir, news_folder, "img.jpg"))
clean_doc, doc_labels = self.preprocess_html(html_doc)
data.append((fname_doc, clean_doc, html_doc, img))
labels.append(doc_labels)
return data, labels
def save_img(self, path, img):
with open(path, "wb") as f:
f.write(img)
def get_images(self, index_path):
with open(index_path) as f:
imgs = BeautifulSoup(f, "html.parser").find_all("img")
# TODO: forcing the second <img> tag (index 1), assumed to be the main article image
img = imgs[1]
content = requests.get(img["src"], timeout=10).content
return img, content
def preprocess_html(self, html_doc):
# TODO: this could be replaced by BeautifulSoup call or something similar
labels = self._extract_labels(html_doc)
cleaned = self._clean_up_str(self._remove_html_tags(html_doc))
return cleaned, labels
@@ -130,13 +167,11 @@ class MultiModalDataset:
doc = doc.replace("\t", " ")
return doc
def get_image(self):
# TODO: implement
pass
if __name__ == "__main__":
from os.path import expanduser
_dataset_path_hardcoded = "~/datasets/MultiNews/20110730/"
dataset = MultiNewsDataset(expanduser(_dataset_path_hardcoded), debug=True)
lXtr, lYtr = dataset.training()
exit()
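
For reference only (not part of this diff): the label handling above hinges on fitting a single MultiLabelBinarizer on the union of labels from every language, so each per-language target matrix produced by training() shares the same column order. A minimal sketch with made-up labels (the real ones come from MultiNewsDataset.get_labels()):

from sklearn.preprocessing import MultiLabelBinarizer

all_labels = {"economy", "politics", "sports"}     # hypothetical union of labels
docs_labels_en = [["politics"], ["politics", "economy"]]
docs_labels_it = [["sports"]]

mlb = MultiLabelBinarizer()
mlb.fit([all_labels])  # one "sample" containing every label registers all classes

Y_en = mlb.transform(docs_labels_en)  # shape (2, 3), columns ordered as mlb.classes_
Y_it = mlb.transform(docs_labels_it)  # shape (1, 3), same column order across languages
print(mlb.classes_)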

View File

@@ -9,7 +9,7 @@ class VanillaFunGen(ViewGen):
Sebastiani in DOI: https://doi.org/10.1145/3326065
"""
def __init__(self, base_learner, first_tier_parameters=None, n_jobs=-1):
def __init__(self, base_learner, n_jobs=-1):
"""
Init Posterior Probabilities embedder (i.e., VanillaFunGen)
:param base_learner: naive monolingual learners to be deployed as first-tier
@@ -19,7 +19,6 @@
"""
print("- init VanillaFun View Generating Function")
self.learners = base_learner
self.first_tier_parameters = first_tier_parameters
self.n_jobs = n_jobs
self.doc_projector = NaivePolylingualClassifier(
base_learner=self.learners,
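
For intuition only (this snippet is not code from the repo): the "posterior probabilities" view described in the docstring amounts to training a monolingual first-tier classifier per language and using its class-probability outputs as the document representation, roughly:

from sklearn.linear_model import LogisticRegression
import numpy as np

# toy monolingual data: 5 docs with 10 features, 3 classes (made up)
X = np.random.RandomState(0).rand(5, 10)
y = np.array([0, 1, 2, 1, 0])

first_tier = LogisticRegression(max_iter=1000).fit(X, y)
posteriors = first_tier.predict_proba(X)  # shape (5, 3): documents embedded in label space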

18
gfun/vgfs/visualGen.py Normal file
View File

@@ -0,0 +1,18 @@
from vgfs.viewGen import ViewGen
class VisualGen(ViewGen):
def fit(self, lX, lY):
# TODO: visual feature extraction not implemented yet
raise NotImplementedError
def transform(self, lX):
return super().transform(lX)
def fit_transform(self, lX, lY):
return super().fit_transform(lX, lY)
def save_vgf(self, model_id):
return super().save_vgf(model_id)
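
Purely as a sketch of where this stub could go (not what the commit implements): the images collected by MultiModalDataset are PIL Images, so a VisualGen.transform could map each language's images to dense vectors with a pretrained backbone, e.g. assuming torchvision >= 0.13:

import torch
from torchvision import models

# illustrative image-embedding helper, not part of the repository
weights = models.ResNet18_Weights.DEFAULT
backbone = models.resnet18(weights=weights)
backbone.fc = torch.nn.Identity()  # drop the classifier head, keep 512-d features
backbone.eval()
preprocess = weights.transforms()

def embed_images(pil_images):
    batch = torch.stack([preprocess(img.convert("RGB")) for img in pil_images])
    with torch.no_grad():
        return backbone(batch)  # tensor of shape (n_images, 512)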