226 lines
7.8 KiB
Python
226 lines
7.8 KiB
Python
from zipfile import ZipFile
|
|
import xml.etree.ElementTree as ET
|
|
from data.languages import RCV2_LANGS_WITH_NLTK_STEMMING, RCV2_LANGS
|
|
from util.file import list_files
|
|
from sklearn.datasets import get_data_home
|
|
import gzip
|
|
from os.path import join, exists
|
|
from util.file import download_file_if_not_exists
|
|
import re
|
|
from collections import Counter
|
|
import numpy as np
|
|
import sys
|
|
|
|
"""
|
|
RCV2's Nomenclature:
|
|
ru = Russian
|
|
da = Danish
|
|
de = German
|
|
es = Spanish
|
|
lat = Latin-American Spanish (these documents are actually also labelled 'es' in the collection)
|
|
fr = French
|
|
it = Italian
|
|
nl = Dutch
|
|
pt = Portuguese
|
|
sv = Swedish
|
|
ja = Japanese
|
|
htw = Chinese
|
|
no = Norwegian
|
|
"""
|
|
|
|
# Location of the original RCV1 topic hierarchy file (fetched by fetch_topic_hierarchy).
RCV1_TOPICHIER_URL = 'http://www.ai.mit.edu/projects/jmlr/papers/volume5/lewis04a/a02-orig-topics-hierarchy/rcv1.topics.hier.orig'
# Base location of the pre-tokenized RCV1 files.
RCV1PROC_BASE_URL = 'http://www.ai.mit.edu/projects/jmlr/papers/volume5/lewis04a/a12-token-files'
# Information pages quoted in the error messages shown when the raw corpora are missing.
RCV1_BASE_URL = 'http://www.daviddlewis.com/resources/testcollections/rcv1/'
RCV2_BASE_URL = 'http://trec.nist.gov/data/reuters/reuters.html'
|
|
|
|
# File names of the four gzipped parts of the tokenized RCV1 test set (pt0 .. pt3).
rcv1_test_data_gz = ['lyrl2004_tokens_test_pt{}.dat.gz'.format(part_no) for part_no in range(4)]

# File name(s) of the gzipped tokenized RCV1 training set.
rcv1_train_data_gz = ['lyrl2004_tokens_train.dat.gz']

# File name of the gzipped document-to-topic relevance judgements.
rcv1_doc_cats_data_gz = 'rcv1-v2.topics.qrels.gz'
|
|
|
|
# Maps each RCV2 language code to the name of the directory holding its archives.
# Insertion order matters: fetch_RCV2 iterates the keys in this order by default.
RCV2_LANG_DIR = dict(
    ru='REUTE000',
    de='REUTE00A',
    fr='REUTE00B',
    sv='REUTE001',
    no='REUTE002',
    da='REUTE003',
    pt='REUTE004',
    it='REUTE005',
    es='REUTE006',
    lat='REUTE007',
    jp='REUTE008',
    htw='REUTE009',
    nl='REUTERS_',
)
|
|
|
|
|
|
class RCV_Document:
    """Plain value holder for one parsed Reuters corpus document.

    The constructor stores its arguments unchanged as public attributes:
    ``id``, ``date``, ``lang``, ``text`` and ``categories``.
    """

    def __init__(self, id, text, categories, date='', lang=None):
        # Assigned in the order id, date, lang, text, categories.
        self.id, self.date, self.lang = id, date, lang
        self.text, self.categories = text, categories
|
|
|
|
|
|
class ExpectedLanguageException(Exception):
    """Raised by parse_document when the expected language code is not found
    among the attribute values of the document's root element."""
|
|
class IDRangeException(Exception):
    """Raised by parse_document when the document's ``itemid`` falls outside
    the requested inclusive id range."""
|
|
|
|
|
|
# Module-level accumulator of per-document word counts (whitespace-separated tokens).
# parse_document appends to it; the fetch_RCV* functions reset it before reading
# and report summary statistics from it afterwards.
nwords = list()
|
|
|
|
def parse_document(xml_content, assert_lang=None, valid_id_range=None):
    """Parse the XML of one Reuters news item into an RCV_Document.

    Args:
        xml_content: XML source of the document, as accepted by
            ``xml.etree.ElementTree.fromstring``.
        assert_lang: optional language code that must appear among the
            attribute values of the root element. When it is ``'jp'``,
            documents tagged ``'ja'`` are accepted too.
        valid_id_range: optional pair ``(low, high)``; the integer value of
            the root ``itemid`` attribute must lie within it (inclusive).

    Returns:
        An RCV_Document whose text is the title, the headline (dropped when it
        is contained in the title) and the body paragraphs joined by newlines,
        with surrounding whitespace stripped.

    Raises:
        ExpectedLanguageException: if the language check fails.
        IDRangeException: if the document id is outside ``valid_id_range``.
        ValueError: if the document has no body text ('Empty document').

    Side effects:
        Appends the number of whitespace-separated tokens of the text to the
        module-level ``nwords`` list.
    """
    root = ET.fromstring(xml_content)

    if assert_lang:
        attr_values = root.attrib.values()
        # some documents are attributed to 'ja', others to 'jp'
        lang_matches = assert_lang in attr_values or (assert_lang == 'jp' and 'ja' in attr_values)
        if not lang_matches:
            raise ExpectedLanguageException('error: document of a different language')

    doc_id = root.attrib['itemid']
    if valid_id_range is not None:
        if not valid_id_range[0] <= int(doc_id) <= valid_id_range[1]:
            raise IDRangeException

    doc_categories = [cat.attrib['code'] for cat in
                      root.findall('.//metadata/codes[@class="bip:topics:1.0"]/code')]

    doc_date = root.attrib['date']

    # findtext() returns None when the element is missing (and '' when it is
    # empty), so a document without <title>/<headline> no longer raises
    # AttributeError; both cases collapse to the empty string.
    doc_title = root.findtext('.//title') or ''
    doc_headline = root.findtext('.//headline') or ''

    # An empty <p/> has text None, which would make str.join raise TypeError;
    # such paragraphs carry no content and are skipped.
    paragraphs = [p.text for p in root.findall('.//text/p') if p.text is not None]
    doc_body = '\n'.join(paragraphs)

    if not doc_body:
        raise ValueError('Empty document')

    # Avoid repeating a headline that is already part of the title.
    if doc_headline in doc_title:
        doc_headline = ''
    text = '\n'.join([doc_title, doc_headline, doc_body]).strip()

    # Record the document length for the statistics printed by the fetchers.
    nwords.append(len(text.split()))

    return RCV_Document(id=doc_id, text=text, categories=doc_categories, date=doc_date, lang=assert_lang)
|
|
|
|
|
|
def fetch_RCV1(data_path, split='all'):
    """Read RCV1 (English) documents from the zip archives in ``data_path``.

    Only entries returned by ``list_files(data_path)`` whose name starts with
    digits followed by ``.zip`` are opened. Every member of each archive is
    parsed with parse_document; documents whose id is outside the range of
    the requested split, or whose language is not 'en', are skipped silently,
    and documents that raise ValueError (e.g. an empty body) are skipped with
    a message. Reading stops as soon as the expected number of documents for
    the split has been read.

    Args:
        data_path: directory containing the RCV1 zip archives.
        split: one of 'train', 'test' or 'all'.

    Returns:
        A pair ``(documents, labels)``: the list of RCV_Document objects read
        and the list of distinct category codes found in them.

    Side effects:
        Resets the module-level ``nwords`` list, prints progress to stdout and,
        when at least one document was parsed, prints word-count statistics.
    """
    assert split in ['train', 'test', 'all'], 'split should be "train", "test", or "all"'

    request = []
    labels = set()
    read_documents = 0
    lang = 'en'

    training_documents = 23149
    test_documents = 781265

    # Inclusive itemid range and expected number of documents for each split.
    if split == 'all':
        split_range = (2286, 810596)
        expected = training_documents + test_documents
    elif split == 'train':
        split_range = (2286, 26150)
        expected = training_documents
    else:
        split_range = (26151, 810596)
        expected = test_documents

    global nwords
    nwords = []
    # Raw string: '\d' and '\.' are invalid escapes in a normal string literal.
    zip_name = re.compile(r'\d+\.zip')
    for part in list_files(data_path):
        if not zip_name.match(part):
            continue
        target_file = join(data_path, part)
        assert exists(target_file), \
            "You don't seem to have the file "+part+" in " + data_path + ", and the RCV1 corpus can not be downloaded"+\
            " w/o a formal permission. Please, refer to " + RCV1_BASE_URL + " for more information."
        # The context manager closes the archive even if parsing fails.
        with ZipFile(target_file) as zipfile:
            for xmlfile in zipfile.namelist():
                xmlcontent = zipfile.read(xmlfile)
                try:
                    doc = parse_document(xmlcontent, assert_lang=lang, valid_id_range=split_range)
                except ValueError as e:
                    # ValueError signals an unusable document (e.g. an empty
                    # body), not a language mismatch, so report the real cause.
                    print('\n\tskipping document {}: {}'.format(part+'/'+xmlfile, e))
                except (IDRangeException, ExpectedLanguageException):
                    # Documents outside the split or in another language are expected.
                    pass
                else:
                    labels.update(doc.categories)
                    request.append(doc)
                    read_documents += 1
                print('\r[{}] read {} documents'.format(part, len(request)), end='')
                if read_documents == expected:
                    break
        if read_documents == expected:
            break
    print()
    # np.min / np.max raise on an empty sequence, so only report statistics
    # when at least one document was parsed.
    if nwords:
        print('ave:{} std {} min {} max {}'.format(np.mean(nwords), np.std(nwords), np.min(nwords), np.max(nwords)))
    return request, list(labels)
|
|
|
|
|
|
def fetch_RCV2(data_path, languages=None):
    """Read RCV2 documents for the given languages from ``data_path``.

    For each language, every entry returned by ``list_files`` for the
    directory ``join(data_path, RCV2_LANG_DIR[lang])`` is opened as a zip
    archive and every member is parsed with parse_document, asserting the
    language. Documents in a different language are skipped silently, and
    documents that raise ValueError (e.g. an empty body) are skipped with a
    message.

    Args:
        data_path: root directory containing one sub-directory per language.
        languages: iterable of language codes (keys of RCV2_LANG_DIR); when
            empty or None, all languages in RCV2_LANG_DIR are read.

    Returns:
        A pair ``(documents, labels)``: the list of RCV_Document objects read
        and the list of distinct category codes found in them.

    Side effects:
        Resets the module-level ``nwords`` list, prints progress to stdout and,
        when at least one document was parsed, prints word-count statistics.
    """
    if not languages:
        languages = list(RCV2_LANG_DIR.keys())
    else:
        assert set(languages).issubset(set(RCV2_LANG_DIR.keys())), 'languages not in scope'

    request = []
    labels = set()
    global nwords
    nwords = []
    for lang in languages:
        path = join(data_path, RCV2_LANG_DIR[lang])
        lang_docs_read = 0
        for part in list_files(path):
            target_file = join(path, part)
            assert exists(target_file), \
                "You don't seem to have the file "+part+" in " + path + ", and the RCV2 corpus can not be downloaded"+\
                " w/o a formal permission. Please, refer to " + RCV2_BASE_URL + " for more information."
            # The context manager closes the archive even if parsing fails.
            with ZipFile(target_file) as zipfile:
                for xmlfile in zipfile.namelist():
                    xmlcontent = zipfile.read(xmlfile)
                    try:
                        doc = parse_document(xmlcontent, assert_lang=lang)
                    except ValueError as e:
                        # ValueError signals an unusable document (e.g. an empty
                        # body), not a language mismatch, so report the real cause.
                        print('\n\tskipping document {}: {}'.format(RCV2_LANG_DIR[lang]+'/'+part+'/'+xmlfile, e))
                    except (IDRangeException, ExpectedLanguageException):
                        # Documents in another language are expected and ignored.
                        pass
                    else:
                        labels.update(doc.categories)
                        request.append(doc)
                        lang_docs_read += 1
                    print('\r[{}] read {} documents, {} for language {}'.format(RCV2_LANG_DIR[lang]+'/'+part, len(request), lang_docs_read, lang), end='')
        print()
    # np.min / np.max raise on an empty sequence, so only report statistics
    # when at least one document was parsed.
    if nwords:
        print('ave:{} std {} min {} max {}'.format(np.mean(nwords), np.std(nwords), np.min(nwords), np.max(nwords)))
    return request, list(labels)
|
|
|
|
|
|
def fetch_topic_hierarchy(path, topics='all'):
    """Return topic codes from the RCV1 topic hierarchy file.

    The file is obtained through ``download_file_if_not_exists`` (from
    RCV1_TOPICHIER_URL into ``path``) and then read line by line; the second
    and fourth whitespace-separated fields of each line are taken as the
    parent and child codes. The entries whose parent is 'None' or 'Root' are
    discarded.

    Args:
        path: local path of the hierarchy file.
        topics: 'all' to return every remaining parent and child code, or
            'leaves' to return only the children that are never a parent.

    Returns:
        A list of topic codes, in no particular order.

    Side effects:
        Prints the parent -> children mapping to stdout.
    """
    assert topics in ['all', 'leaves']

    download_file_if_not_exists(RCV1_TOPICHIER_URL, path)
    hierarchy = {}
    # 'with' guarantees the file handle is closed once the file is read.
    with open(path, 'rt') as hierarchy_file:
        for line in hierarchy_file:
            parts = line.strip().split()
            if not parts:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            parent, child = parts[1], parts[3]
            hierarchy.setdefault(parent, []).append(child)

    # Drop the artificial top-level entries; pop() tolerates their absence.
    hierarchy.pop('None', None)
    hierarchy.pop('Root', None)
    print(hierarchy)

    parents = set(hierarchy)
    children = set()
    for descendants in hierarchy.values():
        children.update(descendants)

    if topics == 'all':
        return list(parents | children)
    elif topics == 'leaves':
        return list(children - parents)
|
|
|
|
|