gFun/refactor/data/reader/rcv_reader.py

226 lines
7.8 KiB
Python

from zipfile import ZipFile
import xml.etree.ElementTree as ET
from data.languages import RCV2_LANGS_WITH_NLTK_STEMMING, RCV2_LANGS
from util.file import list_files
from sklearn.datasets import get_data_home
import gzip
from os.path import join, exists
from util.file import download_file_if_not_exists
import re
from collections import Counter
import numpy as np
import sys
"""
RCV2's Nomenclature:
ru = Russian
da = Danish
de = German
es = Spanish
lat = Latin-American Spanish (these documents are actually also tagged 'es' in the collection)
fr = French
it = Italian
nl = Dutch
pt = Portuguese
sv = Swedish
ja = Japanese (keyed as 'jp' in RCV2_LANG_DIR; documents are tagged either 'ja' or 'jp')
htw = Chinese
no = Norwegian
"""
# Location of the original RCV1 topic-hierarchy file (downloaded on demand by
# fetch_topic_hierarchy).
RCV1_TOPICHIER_URL = "http://www.ai.mit.edu/projects/jmlr/papers/volume5/lewis04a/a02-orig-topics-hierarchy/rcv1.topics.hier.orig"
# NOTE(review): not referenced in the visible part of this file; presumably the
# base URL of the pre-tokenised RCV1 files listed below -- confirm.
RCV1PROC_BASE_URL = 'http://www.ai.mit.edu/projects/jmlr/papers/volume5/lewis04a/a12-token-files'
# Reference pages quoted in the "corpus can not be downloaded" messages.
RCV1_BASE_URL = "http://www.daviddlewis.com/resources/testcollections/rcv1/"
RCV2_BASE_URL = "http://trec.nist.gov/data/reuters/reuters.html"

# File names of the gzipped RCV1 token/qrels resources.
rcv1_test_data_gz = [
    'lyrl2004_tokens_test_pt0.dat.gz',
    'lyrl2004_tokens_test_pt1.dat.gz',
    'lyrl2004_tokens_test_pt2.dat.gz',
    'lyrl2004_tokens_test_pt3.dat.gz',
]
rcv1_train_data_gz = ['lyrl2004_tokens_train.dat.gz']
rcv1_doc_cats_data_gz = 'rcv1-v2.topics.qrels.gz'

# Language code -> name of the sub-directory (below the RCV2 data path) that
# fetch_RCV2 scans for that language.
RCV2_LANG_DIR = {
    'ru': 'REUTE000',
    'de': 'REUTE00A',
    'fr': 'REUTE00B',
    'sv': 'REUTE001',
    'no': 'REUTE002',
    'da': 'REUTE003',
    'pt': 'REUTE004',
    'it': 'REUTE005',
    'es': 'REUTE006',
    'lat': 'REUTE007',
    'jp': 'REUTE008',
    'htw': 'REUTE009',
    'nl': 'REUTERS_',
}
class RCV_Document:
    """Plain record for one parsed Reuters document.

    Every constructor argument is stored unchanged under an attribute of the
    same name: ``id``, ``text``, ``categories``, ``date`` (default ``''``) and
    ``lang`` (default ``None``).
    """

    def __init__(self, id, text, categories, date='', lang=None):
        # Identification / provenance fields first, then the content fields.
        self.id, self.date, self.lang = id, date, lang
        self.text, self.categories = text, categories
class ExpectedLanguageException(Exception):
    """Raised by parse_document when the expected language code is not among the document's root attributes."""
class IDRangeException(Exception):
    """Raised by parse_document when a document id falls outside the requested id range."""
# Word counts (whitespace tokens) of the documents parsed so far; appended to
# by parse_document and reset at the start of fetch_RCV1 / fetch_RCV2.
nwords = list()
def parse_document(xml_content, assert_lang=None, valid_id_range=None):
    """Parse the XML of one Reuters news item into an RCV_Document.

    The document text is built from the title, the headline (dropped when it
    is missing or already contained in the title) and the non-empty paragraphs
    found under ``text/p``, joined by newlines and stripped. As a side effect
    the number of whitespace-separated tokens of that text is appended to the
    module-level ``nwords`` list.

    :param xml_content: XML source of a single document, as accepted by
        ``ET.fromstring`` (the fetchers in this file pass raw bytes read from
        a zip entry).
    :param assert_lang: if truthy, a language code that must appear among the
        values of the root element's attributes; it is also stored as the
        ``lang`` of the returned document.
    :param valid_id_range: optional ``(low, high)`` pair; the ``itemid``
        attribute, converted to int, must lie within this inclusive range.
    :return: the parsed RCV_Document.
    :raises ExpectedLanguageException: if ``assert_lang`` is given and not
        found among the root attributes.
    :raises IDRangeException: if the document id is outside ``valid_id_range``.
    :raises ValueError: if the document has no paragraph text (or if the id is
        not an integer while ``valid_id_range`` is given).
    """
    root = ET.fromstring(xml_content)

    if assert_lang:
        attribute_values = root.attrib.values()
        if assert_lang not in attribute_values:
            # some documents are attributed to 'ja', others to 'jp'
            # NOTE(review): the nomenclature notes in this file say that 'lat'
            # documents are tagged 'es'; if so, assert_lang='lat' rejects all
            # of them here -- confirm whether an alias is needed.
            if assert_lang != 'jp' or 'ja' not in attribute_values:
                raise ExpectedLanguageException('error: document of a different language')

    doc_id = root.attrib['itemid']
    if valid_id_range is not None:
        if not valid_id_range[0] <= int(doc_id) <= valid_id_range[1]:
            raise IDRangeException

    doc_categories = [cat.attrib['code'] for cat in
                      root.findall('.//metadata/codes[@class="bip:topics:1.0"]/code')]
    doc_date = root.attrib['date']

    # A missing <title>/<headline> element is treated like an empty one
    # instead of failing with AttributeError on ``None.text``.
    title_node = root.find('.//title')
    headline_node = root.find('.//headline')
    doc_title = title_node.text if title_node is not None else None
    doc_headline = headline_node.text if headline_node is not None else None

    # Paragraphs without text (``<p/>``) have ``text is None``; skip them so
    # the join cannot fail with TypeError and so that a document made only of
    # empty paragraphs is reported as empty.
    doc_body = '\n'.join([p.text for p in root.findall('.//text/p') if p.text])
    if not doc_body:
        raise ValueError('Empty document')

    if doc_title is None:
        doc_title = ''
    if doc_headline is None or doc_headline in doc_title:
        doc_headline = ''
    text = '\n'.join([doc_title, doc_headline, doc_body]).strip()

    # Record the document length for the statistics printed by the fetchers.
    text_length = len(text.split())
    global nwords
    nwords.append(text_length)

    return RCV_Document(id=doc_id, text=text, categories=doc_categories, date=doc_date, lang=assert_lang)
def fetch_RCV1(data_path, split='all'):
    """Read the English RCV1 documents stored as zip archives in ``data_path``.

    Only entries of ``data_path`` whose name starts with ``<digits>.zip`` are
    opened. Every XML member is parsed with ``parse_document``; documents
    whose id is outside the id range of the requested split, or that are not
    tagged 'en', are silently skipped. Reading stops as soon as the number of
    documents expected for the split has been read. The module-level
    ``nwords`` list is reset and then filled with the length of each parsed
    document; summary statistics are printed at the end.

    :param data_path: directory containing the RCV1 zip archives.
    :param split: 'train' (ids 2286-26150), 'test' (ids 26151-810596) or
        'all' (ids 2286-810596).
    :return: ``(documents, labels)`` where ``documents`` is a list of
        RCV_Document and ``labels`` the list of distinct category codes found.
    """
    assert split in ['train', 'test', 'all'], 'split should be "train", "test", or "all"'

    request = []
    labels = set()
    read_documents = 0
    lang = 'en'

    training_documents = 23149
    test_documents = 781265

    if split == 'all':
        split_range = (2286, 810596)
        expected = training_documents+test_documents
    elif split == 'train':
        split_range = (2286, 26150)
        expected = training_documents
    else:
        split_range = (26151, 810596)
        expected = test_documents

    global nwords
    nwords = []
    for part in list_files(data_path):
        # raw string: '\d' and '\.' are not valid escapes in a plain literal
        if not re.match(r'\d+\.zip', part):
            continue
        target_file = join(data_path, part)
        assert exists(target_file), \
            "You don't seem to have the file "+part+" in " + data_path + ", and the RCV1 corpus can not be downloaded"+\
            " w/o a formal permission. Please, refer to " + RCV1_BASE_URL + " for more information."
        # the context manager closes the archive even if parsing fails
        with ZipFile(target_file) as zipfile:
            for xmlfile in zipfile.namelist():
                xmlcontent = zipfile.read(xmlfile)
                try:
                    doc = parse_document(xmlcontent, assert_lang=lang, valid_id_range=split_range)
                except ValueError as err:
                    # parse_document raises ValueError for documents without
                    # body text (or a non-integer id), not for language
                    # mismatches, so report the actual reason.
                    print('\n\tskipping document {}: {}'.format(part+'/'+xmlfile, err))
                except (IDRangeException, ExpectedLanguageException):
                    # outside the requested split / not English: expected, skip
                    pass
                else:
                    labels.update(doc.categories)
                    request.append(doc)
                    read_documents += 1
                print('\r[{}] read {} documents'.format(part, len(request)), end='')
                if read_documents == expected:
                    break
        if read_documents == expected:
            break
    print()
    if nwords:
        # np.min / np.max raise on an empty sequence, so only report
        # statistics when at least one document was parsed
        print('ave:{} std {} min {} max {}'.format(np.mean(nwords), np.std(nwords), np.min(nwords), np.max(nwords)))
    return request, list(labels)
def fetch_RCV2(data_path, languages=None):
    """Read the RCV2 documents of the requested languages from ``data_path``.

    For each language, every entry of the sub-directory
    ``RCV2_LANG_DIR[lang]`` of ``data_path`` is opened as a zip archive and
    each of its members is parsed with ``parse_document``; documents that are
    not tagged with the expected language are silently skipped. The
    module-level ``nwords`` list is reset and then filled with the length of
    each parsed document; summary statistics are printed at the end.

    :param data_path: directory containing one sub-directory per language.
    :param languages: iterable of language codes (keys of RCV2_LANG_DIR);
        if empty or None, all languages in RCV2_LANG_DIR are read.
    :return: ``(documents, labels)`` where ``documents`` is a list of
        RCV_Document and ``labels`` the list of distinct category codes found.
    """
    if not languages:
        languages = list(RCV2_LANG_DIR.keys())
    else:
        assert set(languages).issubset(set(RCV2_LANG_DIR.keys())), 'languages not in scope'

    request = []
    labels = set()
    global nwords
    nwords = []
    for lang in languages:
        path = join(data_path, RCV2_LANG_DIR[lang])
        lang_docs_read = 0
        for part in list_files(path):
            target_file = join(path, part)
            assert exists(target_file), \
                "You don't seem to have the file "+part+" in " + path + ", and the RCV2 corpus can not be downloaded"+\
                " w/o a formal permission. Please, refer to " + RCV2_BASE_URL + " for more information."
            # the context manager closes the archive even if parsing fails
            with ZipFile(target_file) as zipfile:
                for xmlfile in zipfile.namelist():
                    xmlcontent = zipfile.read(xmlfile)
                    try:
                        doc = parse_document(xmlcontent, assert_lang=lang)
                    except ValueError as err:
                        # parse_document raises ValueError for documents
                        # without body text, not for language mismatches, so
                        # report the actual reason.
                        print('\n\tskipping document {}: {}'.format(RCV2_LANG_DIR[lang]+'/'+part+'/'+xmlfile, err))
                    except (IDRangeException, ExpectedLanguageException):
                        # document of another language: expected, skip
                        pass
                    else:
                        labels.update(doc.categories)
                        request.append(doc)
                        lang_docs_read += 1
                    print('\r[{}] read {} documents, {} for language {}'.format(RCV2_LANG_DIR[lang]+'/'+part, len(request), lang_docs_read, lang), end='')
    print()
    if nwords:
        # np.min / np.max raise on an empty sequence, so only report
        # statistics when at least one document was parsed
        print('ave:{} std {} min {} max {}'.format(np.mean(nwords), np.std(nwords), np.min(nwords), np.max(nwords)))
    return request, list(labels)
def fetch_topic_hierarchy(path, topics='all'):
    """Return RCV1 topic codes read from the topic-hierarchy file at ``path``.

    ``download_file_if_not_exists`` is called first with RCV1_TOPICHIER_URL
    and ``path``. In every line of the file the second and fourth
    whitespace-separated fields are taken as parent and child code; the
    entries whose parent is 'None' or 'Root' are then dropped. The resulting
    parent -> children mapping is printed.

    :param path: local path of the hierarchy file.
    :param topics: 'all' to return every remaining parent and child code,
        'leaves' to return only the codes that never appear as a parent.
    :return: list of topic codes, in no particular order.
    """
    assert topics in ['all', 'leaves']

    download_file_if_not_exists(RCV1_TOPICHIER_URL, path)
    hierarchy = {}
    # 'with' guarantees that the file handle is closed
    with open(path, 'rt') as hier_file:
        for line in hier_file:
            parts = line.strip().split()
            if len(parts) < 4:
                # blank or truncated line: no parent/child pair to read
                continue
            parent, child = parts[1], parts[3]
            hierarchy.setdefault(parent, []).append(child)

    # drop the artificial top of the tree; tolerate files that lack it
    hierarchy.pop('None', None)
    hierarchy.pop('Root', None)
    print(hierarchy)

    parents = set(hierarchy)
    children = set()
    for child_codes in hierarchy.values():
        children.update(child_codes)

    if topics == 'all':
        return list(parents | children)
    # topics == 'leaves' (guaranteed by the assertion above)
    return list(children.difference(parents))