MultiNewsDataset download/save image fn + class for Visual View Generating Function

This commit is contained in:
Andrea Pedrotti 2023-02-08 18:11:53 +01:00
parent 19e4f294db
commit 8325262972
3 changed files with 87 additions and 35 deletions

View File

@ -1,15 +1,28 @@
import os
import sys
sys.path.append(os.getcwd())
import re import re
from os import listdir from os import listdir
from os.path import isdir, join from os.path import isdir, join
from dataManager.torchDataset import TorchMultiNewsDataset import requests
from bs4 import BeautifulSoup
from PIL import Image
from sklearn.preprocessing import MultiLabelBinarizer
# TODO: labels must be aligned between languages # TODO: labels must be aligned between languages
# TODO: remove copyright and also tags (doc.split("More about:")[0]) # TODO: remove copyright and also tags (doc.split("More about:")[0])
# TODO: define fn to represent the dataset as a torch Dataset
# TODO: this should be a instance of a abstract MultimodalMultilingualDataset # TODO: this should be a instance of a abstract MultimodalMultilingualDataset
def get_label_binarizer(cats):
    """Fit a MultiLabelBinarizer over the complete category set.

    :param cats: iterable with every label that can occur in the corpus.
    :return: a fitted sklearn MultiLabelBinarizer whose classes_ are `cats`.
    """
    binarizer = MultiLabelBinarizer()
    # fit expects an iterable of label-sets; wrap the whole vocabulary
    # as a single "sample" so every category is registered as a class
    binarizer.fit([cats])
    return binarizer
class MultiNewsDataset: class MultiNewsDataset:
def __init__(self, data_dir, excluded_langs=[], debug=False): def __init__(self, data_dir, excluded_langs=[], debug=False):
self.debug = debug self.debug = debug
@ -21,6 +34,8 @@ class MultiNewsDataset:
f"[{'DEBUG MODE: ' if debug else ''}Loaded MultiNewsDataset - langs: {self.langs}]" f"[{'DEBUG MODE: ' if debug else ''}Loaded MultiNewsDataset - langs: {self.langs}]"
) )
self.load_data() self.load_data()
self.all_labels = self.get_labels()
self.label_binarizer = get_label_binarizer(self.all_labels)
self.print_stats() self.print_stats()
def load_data(self): def load_data(self):
@ -34,47 +49,52 @@ class MultiNewsDataset:
from os import listdir from os import listdir
if self.debug: if self.debug:
return ["it", "en"] return ["it"]
return tuple(sorted([folder for folder in listdir(self.data_dir)])) return tuple(sorted([folder for folder in listdir(self.data_dir)]))
def print_stats(self): def print_stats(self):
print(f"[MultiNewsDataset stats]") print(f"[MultiNewsDataset stats]")
# print(f" - langs: {self.langs}")
total_docs = 0 total_docs = 0
for lang in self.langs: for lang in self.langs:
_len = len(self.lang_multiModalDataset[lang].data) _len = len(self.lang_multiModalDataset[lang].data)
total_docs += _len total_docs += _len
print( print(
f" - {lang} docs: {_len}\t- labels: {self._count_lang_labels(self.lang_multiModalDataset[lang].data)}" f" - {lang} docs: {_len}\t- labels: {self._count_lang_labels(self.lang_multiModalDataset[lang].labels)}"
) )
print(f" - total docs: {total_docs}") print(f" - total docs: {total_docs}")
def _count_lang_labels(self, data): def _count_lang_labels(self, labels):
lang_labels = set() lang_labels = set()
for sample in data: for l in labels:
lang_labels.update(sample[-1]) lang_labels.update(l[-1])
return len(lang_labels) return len(lang_labels)
def export_to_torch_dataset(self, tokenizer_id): def export_to_torch_dataset(self, tokenizer_id):
raise NotImplementedError raise NotImplementedError
# torch_datasets = []
# for lang, multimodal_dataset in self.lang_multiModalDataset.keys():
# dataset = TorchMultiNewsDataset(
# lang=lang,
# data=multimodal_dataset.get_docs(),
# ids=multimodal_dataset.get_ids(),
# imgs=multimodal_dataset.get_imgs(),
# labels=multimodal_dataset.get_labels(),
# tokenizer_id=tokenizer_id,
# )
# torch_datasets.append(dataset)
# raise NotImplementedError
def save_to_disk(self): def save_to_disk(self):
raise NotImplementedError raise NotImplementedError
def training(self):
    """Build the training views.

    :return: (lXtr, lYtr) where lXtr maps language -> raw documents and
        lYtr maps language -> binarized label matrix produced by
        self.label_binarizer.
    """
    datasets = self.lang_multiModalDataset.items()
    lXtr = {lang: ds.data for lang, ds in datasets}
    lYtr = {
        lang: self.label_binarizer.transform(ds.labels)
        for lang, ds in datasets
    }
    return lXtr, lYtr
def testing(self):
    """Test split is not available for MultiNewsDataset yet."""
    raise NotImplementedError
def get_labels(self):
    """Collect the union of all document labels across every language.

    :return: a set with every label seen in any per-language dataset.
    """
    all_labels = set()
    # only the per-language datasets matter here; the language key was
    # previously bound but never used, so iterate the values directly
    for dataset in self.lang_multiModalDataset.values():
        for doc_labels in dataset.labels:
            all_labels.update(doc_labels)
    return all_labels
class MultiModalDataset: class MultiModalDataset:
def __init__(self, lang, data_dir): def __init__(self, lang, data_dir):
@ -83,10 +103,7 @@ class MultiModalDataset:
self.re_labels = re.compile(r"<a rel=\"tag\" href=\"\/tag\/.+?\/\">(.+?)<\/a>") self.re_labels = re.compile(r"<a rel=\"tag\" href=\"\/tag\/.+?\/\">(.+?)<\/a>")
self.re_cleaner = re.compile(r"<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});") self.re_cleaner = re.compile(r"<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});")
self.re_white = re.compile(r" +") self.re_white = re.compile(r" +")
self.data = self.get_docs() self.data, self.labels = self.get_docs()
def get_docs(self):
raise NotImplementedError
def get_imgs(self): def get_imgs(self):
raise NotImplementedError raise NotImplementedError
@ -98,19 +115,39 @@ class MultiModalDataset:
raise NotImplementedError raise NotImplementedError
def get_docs(self): def get_docs(self):
# FIXME: this is a mess
data = [] data = []
labels = []
news_folder = [doc_folder for doc_folder in listdir(self.data_dir)] news_folder = [doc_folder for doc_folder in listdir(self.data_dir)]
for news_folder in news_folder: for news_folder in news_folder:
if isdir(join(self.data_dir, news_folder)): if isdir(join(self.data_dir, news_folder)):
fname_doc = f"text.{news_folder.split('.')[-1]}" fname_doc = f"text.{news_folder.split('.')[-1]}"
with open(join(self.data_dir, news_folder, fname_doc)) as f: with open(join(self.data_dir, news_folder, fname_doc)) as f:
html_doc = f.read() html_doc = f.read()
img = self.get_image() index_path = join(self.data_dir, news_folder, "index.html")
clean_doc, labels = self.preprocess_html(html_doc) if ".jpg" not in listdir(join(self.data_dir, news_folder)):
data.append((fname_doc, clean_doc, html_doc, img, labels)) img_link, img = self.get_images(index_path)
return data self.save_img(join(self.data_dir, news_folder, "img.jpg"), img)
else:
img = Image.open(join(self.data_dir, news_folder, "img.jpg"))
clean_doc, doc_labels = self.preprocess_html(html_doc)
data.append((fname_doc, clean_doc, html_doc, img))
labels.append(doc_labels)
return data, labels
def save_img(self, path, img):
    """Persist raw image bytes `img` to the file at `path`."""
    with open(path, "wb") as out_file:
        out_file.write(img)
def get_images(self, index_path):
    """Parse the saved article page and download its main image.

    :param index_path: path to the article's stored index.html.
    :return: (img_tag, content) — the selected BeautifulSoup <img> element
        and the raw downloaded image bytes.
    :raises IndexError: if the page contains fewer than two <img> tags.
    """
    # open the HTML inside a context manager: the original leaked the
    # file handle by passing open(...) straight into BeautifulSoup
    with open(index_path) as html_file:
        img_tags = BeautifulSoup(html_file, "html.parser").findAll("img")
    # TODO: forcing to take the image at index 1 (assumed to be the main
    # article image; index 0 is presumably a site logo/banner — verify)
    main_img = img_tags[1]
    # timeout so a dead image host cannot hang the whole dataset build;
    # NOTE(review): consider checking response.status_code before saving,
    # otherwise an error page gets written out as img.jpg
    content = requests.get(main_img["src"], timeout=30).content
    return main_img, content
def preprocess_html(self, html_doc): def preprocess_html(self, html_doc):
# TODO: this could be replaced by BeautifulSoup call or something similar
labels = self._extract_labels(html_doc) labels = self._extract_labels(html_doc)
cleaned = self._clean_up_str(self._remove_html_tags(html_doc)) cleaned = self._clean_up_str(self._remove_html_tags(html_doc))
return cleaned, labels return cleaned, labels
@ -130,13 +167,11 @@ class MultiModalDataset:
doc = doc.replace("\t", " ") doc = doc.replace("\t", " ")
return doc return doc
def get_image(self):
# TODO: implement
pass
if __name__ == "__main__": if __name__ == "__main__":
from os.path import expanduser from os.path import expanduser
_dataset_path_hardcoded = "~/datasets/MultiNews/20110730/" _dataset_path_hardcoded = "~/datasets/MultiNews/20110730/"
dataset = MultiNewsDataset(expanduser(_dataset_path_hardcoded), debug=True) dataset = MultiNewsDataset(expanduser(_dataset_path_hardcoded), debug=True)
lXtr, lYtr = dataset.training()
exit()

View File

@ -9,7 +9,7 @@ class VanillaFunGen(ViewGen):
Sebastiani in DOI: https://doi.org/10.1145/3326065 Sebastiani in DOI: https://doi.org/10.1145/3326065
""" """
def __init__(self, base_learner, first_tier_parameters=None, n_jobs=-1): def __init__(self, base_learner, n_jobs=-1):
""" """
Init Posterior Probabilities embedder (i.e., VanillaFunGen) Init Posterior Probabilities embedder (i.e., VanillaFunGen)
:param base_learner: naive monolingual learners to be deployed as first-tier :param base_learner: naive monolingual learners to be deployed as first-tier
@ -19,7 +19,6 @@ class VanillaFunGen(ViewGen):
""" """
print("- init VanillaFun View Generating Function") print("- init VanillaFun View Generating Function")
self.learners = base_learner self.learners = base_learner
self.first_tier_parameters = first_tier_parameters
self.n_jobs = n_jobs self.n_jobs = n_jobs
self.doc_projector = NaivePolylingualClassifier( self.doc_projector = NaivePolylingualClassifier(
base_learner=self.learners, base_learner=self.learners,

18
gfun/vgfs/visualGen.py Normal file
View File

@ -0,0 +1,18 @@
from vgfs.viewGen import ViewGen
class VisualGen(ViewGen):
    """Visual View Generating Function (stub).

    Placeholder VGF for image-based document views; every method currently
    defers to the ViewGen base class or is not yet implemented.
    """

    def fit(self, lX, lY):
        # original was `def fit():` (no self — uncallable on an instance)
        # and raised the NotImplemented constant, which itself triggers a
        # TypeError in Python 3; raise the proper exception class instead
        raise NotImplementedError

    def transform(self, lX):
        return super().transform(lX)

    def fit_transform(self, lX, lY):
        return super().fit_transform(lX, lY)

    def save_vgf(self, model_id):
        # the original defined save_vgf twice; the duplicate silently
        # shadowed the first definition — keep a single one
        return super().save_vgf(model_id)