MultiNewsDataset download/save image fn + class for Visual View Generating Function
This commit is contained in:
parent
19e4f294db
commit
8325262972
|
@ -1,15 +1,28 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.append(os.getcwd())
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from os import listdir
|
from os import listdir
|
||||||
from os.path import isdir, join
|
from os.path import isdir, join
|
||||||
|
|
||||||
from dataManager.torchDataset import TorchMultiNewsDataset
|
import requests
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
from PIL import Image
|
||||||
|
from sklearn.preprocessing import MultiLabelBinarizer
|
||||||
|
|
||||||
# TODO: labels must be aligned between languages
|
# TODO: labels must be aligned between languages
|
||||||
# TODO: remove copyright and also tags (doc.split("More about:")[0])
|
# TODO: remove copyright and also tags (doc.split("More about:")[0])
|
||||||
# TODO: define fn to represent the dataset as a torch Dataset
|
|
||||||
# TODO: this should be a instance of a abstract MultimodalMultilingualDataset
|
# TODO: this should be a instance of a abstract MultimodalMultilingualDataset
|
||||||
|
|
||||||
|
|
||||||
|
def get_label_binarizer(cats):
|
||||||
|
mlb = MultiLabelBinarizer()
|
||||||
|
mlb.fit([cats])
|
||||||
|
return mlb
|
||||||
|
|
||||||
|
|
||||||
class MultiNewsDataset:
|
class MultiNewsDataset:
|
||||||
def __init__(self, data_dir, excluded_langs=[], debug=False):
|
def __init__(self, data_dir, excluded_langs=[], debug=False):
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
|
@ -21,6 +34,8 @@ class MultiNewsDataset:
|
||||||
f"[{'DEBUG MODE: ' if debug else ''}Loaded MultiNewsDataset - langs: {self.langs}]"
|
f"[{'DEBUG MODE: ' if debug else ''}Loaded MultiNewsDataset - langs: {self.langs}]"
|
||||||
)
|
)
|
||||||
self.load_data()
|
self.load_data()
|
||||||
|
self.all_labels = self.get_labels()
|
||||||
|
self.label_binarizer = get_label_binarizer(self.all_labels)
|
||||||
self.print_stats()
|
self.print_stats()
|
||||||
|
|
||||||
def load_data(self):
|
def load_data(self):
|
||||||
|
@ -34,47 +49,52 @@ class MultiNewsDataset:
|
||||||
from os import listdir
|
from os import listdir
|
||||||
|
|
||||||
if self.debug:
|
if self.debug:
|
||||||
return ["it", "en"]
|
return ["it"]
|
||||||
|
|
||||||
return tuple(sorted([folder for folder in listdir(self.data_dir)]))
|
return tuple(sorted([folder for folder in listdir(self.data_dir)]))
|
||||||
|
|
||||||
def print_stats(self):
|
def print_stats(self):
|
||||||
print(f"[MultiNewsDataset stats]")
|
print(f"[MultiNewsDataset stats]")
|
||||||
# print(f" - langs: {self.langs}")
|
|
||||||
total_docs = 0
|
total_docs = 0
|
||||||
for lang in self.langs:
|
for lang in self.langs:
|
||||||
_len = len(self.lang_multiModalDataset[lang].data)
|
_len = len(self.lang_multiModalDataset[lang].data)
|
||||||
total_docs += _len
|
total_docs += _len
|
||||||
print(
|
print(
|
||||||
f" - {lang} docs: {_len}\t- labels: {self._count_lang_labels(self.lang_multiModalDataset[lang].data)}"
|
f" - {lang} docs: {_len}\t- labels: {self._count_lang_labels(self.lang_multiModalDataset[lang].labels)}"
|
||||||
)
|
)
|
||||||
print(f" - total docs: {total_docs}")
|
print(f" - total docs: {total_docs}")
|
||||||
|
|
||||||
def _count_lang_labels(self, data):
|
def _count_lang_labels(self, labels):
|
||||||
lang_labels = set()
|
lang_labels = set()
|
||||||
for sample in data:
|
for l in labels:
|
||||||
lang_labels.update(sample[-1])
|
lang_labels.update(l[-1])
|
||||||
return len(lang_labels)
|
return len(lang_labels)
|
||||||
|
|
||||||
def export_to_torch_dataset(self, tokenizer_id):
|
def export_to_torch_dataset(self, tokenizer_id):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
# torch_datasets = []
|
|
||||||
# for lang, multimodal_dataset in self.lang_multiModalDataset.keys():
|
|
||||||
# dataset = TorchMultiNewsDataset(
|
|
||||||
# lang=lang,
|
|
||||||
# data=multimodal_dataset.get_docs(),
|
|
||||||
# ids=multimodal_dataset.get_ids(),
|
|
||||||
# imgs=multimodal_dataset.get_imgs(),
|
|
||||||
# labels=multimodal_dataset.get_labels(),
|
|
||||||
# tokenizer_id=tokenizer_id,
|
|
||||||
# )
|
|
||||||
# torch_datasets.append(dataset)
|
|
||||||
|
|
||||||
# raise NotImplementedError
|
|
||||||
|
|
||||||
def save_to_disk(self):
|
def save_to_disk(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def training(self):
|
||||||
|
lXtr = {}
|
||||||
|
lYtr = {}
|
||||||
|
for lang, data in self.lang_multiModalDataset.items():
|
||||||
|
lXtr[lang] = data.data
|
||||||
|
lYtr[lang] = self.label_binarizer.transform(data.labels)
|
||||||
|
|
||||||
|
return lXtr, lYtr
|
||||||
|
|
||||||
|
def testing(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_labels(self):
|
||||||
|
all_labels = set()
|
||||||
|
for lang, data in self.lang_multiModalDataset.items():
|
||||||
|
for label in data.labels:
|
||||||
|
all_labels.update(label)
|
||||||
|
return all_labels
|
||||||
|
|
||||||
|
|
||||||
class MultiModalDataset:
|
class MultiModalDataset:
|
||||||
def __init__(self, lang, data_dir):
|
def __init__(self, lang, data_dir):
|
||||||
|
@ -83,10 +103,7 @@ class MultiModalDataset:
|
||||||
self.re_labels = re.compile(r"<a rel=\"tag\" href=\"\/tag\/.+?\/\">(.+?)<\/a>")
|
self.re_labels = re.compile(r"<a rel=\"tag\" href=\"\/tag\/.+?\/\">(.+?)<\/a>")
|
||||||
self.re_cleaner = re.compile(r"<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});")
|
self.re_cleaner = re.compile(r"<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});")
|
||||||
self.re_white = re.compile(r" +")
|
self.re_white = re.compile(r" +")
|
||||||
self.data = self.get_docs()
|
self.data, self.labels = self.get_docs()
|
||||||
|
|
||||||
def get_docs(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def get_imgs(self):
|
def get_imgs(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -98,19 +115,39 @@ class MultiModalDataset:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_docs(self):
|
def get_docs(self):
|
||||||
|
# FIXME: this is a mess
|
||||||
data = []
|
data = []
|
||||||
|
labels = []
|
||||||
news_folder = [doc_folder for doc_folder in listdir(self.data_dir)]
|
news_folder = [doc_folder for doc_folder in listdir(self.data_dir)]
|
||||||
for news_folder in news_folder:
|
for news_folder in news_folder:
|
||||||
if isdir(join(self.data_dir, news_folder)):
|
if isdir(join(self.data_dir, news_folder)):
|
||||||
fname_doc = f"text.{news_folder.split('.')[-1]}"
|
fname_doc = f"text.{news_folder.split('.')[-1]}"
|
||||||
with open(join(self.data_dir, news_folder, fname_doc)) as f:
|
with open(join(self.data_dir, news_folder, fname_doc)) as f:
|
||||||
html_doc = f.read()
|
html_doc = f.read()
|
||||||
img = self.get_image()
|
index_path = join(self.data_dir, news_folder, "index.html")
|
||||||
clean_doc, labels = self.preprocess_html(html_doc)
|
if ".jpg" not in listdir(join(self.data_dir, news_folder)):
|
||||||
data.append((fname_doc, clean_doc, html_doc, img, labels))
|
img_link, img = self.get_images(index_path)
|
||||||
return data
|
self.save_img(join(self.data_dir, news_folder, "img.jpg"), img)
|
||||||
|
else:
|
||||||
|
img = Image.open(join(self.data_dir, news_folder, "img.jpg"))
|
||||||
|
clean_doc, doc_labels = self.preprocess_html(html_doc)
|
||||||
|
data.append((fname_doc, clean_doc, html_doc, img))
|
||||||
|
labels.append(doc_labels)
|
||||||
|
return data, labels
|
||||||
|
|
||||||
|
def save_img(self, path, img):
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
f.write(img)
|
||||||
|
|
||||||
|
def get_images(self, index_path):
|
||||||
|
imgs = BeautifulSoup(open(index_path), "html.parser").findAll("img")
|
||||||
|
imgs = imgs[1]
|
||||||
|
# TODO: forcing to take the first image (i.e. index 1 should be the main image)
|
||||||
|
content = requests.get(imgs["src"]).content
|
||||||
|
return imgs, content
|
||||||
|
|
||||||
def preprocess_html(self, html_doc):
|
def preprocess_html(self, html_doc):
|
||||||
|
# TODO: this could be replaced by BeautifulSoup call or something similar
|
||||||
labels = self._extract_labels(html_doc)
|
labels = self._extract_labels(html_doc)
|
||||||
cleaned = self._clean_up_str(self._remove_html_tags(html_doc))
|
cleaned = self._clean_up_str(self._remove_html_tags(html_doc))
|
||||||
return cleaned, labels
|
return cleaned, labels
|
||||||
|
@ -130,13 +167,11 @@ class MultiModalDataset:
|
||||||
doc = doc.replace("\t", " ")
|
doc = doc.replace("\t", " ")
|
||||||
return doc
|
return doc
|
||||||
|
|
||||||
def get_image(self):
|
|
||||||
# TODO: implement
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from os.path import expanduser
|
from os.path import expanduser
|
||||||
|
|
||||||
_dataset_path_hardcoded = "~/datasets/MultiNews/20110730/"
|
_dataset_path_hardcoded = "~/datasets/MultiNews/20110730/"
|
||||||
dataset = MultiNewsDataset(expanduser(_dataset_path_hardcoded), debug=True)
|
dataset = MultiNewsDataset(expanduser(_dataset_path_hardcoded), debug=True)
|
||||||
|
lXtr, lYtr = dataset.training()
|
||||||
|
exit()
|
||||||
|
|
|
@ -9,7 +9,7 @@ class VanillaFunGen(ViewGen):
|
||||||
Sebastiani in DOI: https://doi.org/10.1145/3326065
|
Sebastiani in DOI: https://doi.org/10.1145/3326065
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, base_learner, first_tier_parameters=None, n_jobs=-1):
|
def __init__(self, base_learner, n_jobs=-1):
|
||||||
"""
|
"""
|
||||||
Init Posterior Probabilities embedder (i.e., VanillaFunGen)
|
Init Posterior Probabilities embedder (i.e., VanillaFunGen)
|
||||||
:param base_learner: naive monolingual learners to be deployed as first-tier
|
:param base_learner: naive monolingual learners to be deployed as first-tier
|
||||||
|
@ -19,7 +19,6 @@ class VanillaFunGen(ViewGen):
|
||||||
"""
|
"""
|
||||||
print("- init VanillaFun View Generating Function")
|
print("- init VanillaFun View Generating Function")
|
||||||
self.learners = base_learner
|
self.learners = base_learner
|
||||||
self.first_tier_parameters = first_tier_parameters
|
|
||||||
self.n_jobs = n_jobs
|
self.n_jobs = n_jobs
|
||||||
self.doc_projector = NaivePolylingualClassifier(
|
self.doc_projector = NaivePolylingualClassifier(
|
||||||
base_learner=self.learners,
|
base_learner=self.learners,
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
from vgfs.viewGen import ViewGen
|
||||||
|
|
||||||
|
|
||||||
|
class VisualGen(ViewGen):
|
||||||
|
def fit():
|
||||||
|
raise NotImplemented
|
||||||
|
|
||||||
|
def transform(self, lX):
|
||||||
|
return super().transform(lX)
|
||||||
|
|
||||||
|
def fit_transform(self, lX, lY):
|
||||||
|
return super().fit_transform(lX, lY)
|
||||||
|
|
||||||
|
def save_vgf(self, model_id):
|
||||||
|
return super().save_vgf(model_id)
|
||||||
|
|
||||||
|
def save_vgf(self, model_id):
|
||||||
|
return super().save_vgf(model_id)
|
Loading…
Reference in New Issue