diff --git a/Retrieval/commons.py b/Retrieval/commons.py index 79bf029..89cff39 100644 --- a/Retrieval/commons.py +++ b/Retrieval/commons.py @@ -45,6 +45,8 @@ def load_json_sample(path, class_name, max_lines=-1): obj = json.load(open(path, 'rt')) keys = [f'{id}' for id in range(len(obj['text'].keys()))] text = [obj['text'][id] for id in keys] + #print(list(obj.keys())) + #import sys; sys.exit(0) classes = [obj[class_name][id] for id in keys] if max_lines is not None and max_lines>0: text = text[:max_lines] diff --git a/Retrieval/fifth.py b/Retrieval/fifth.py index 1cdcdd0..acd7191 100644 --- a/Retrieval/fifth.py +++ b/Retrieval/fifth.py @@ -112,11 +112,11 @@ def scape_latex(string): Ks = [10, 50, 100, 250, 500, 1000, 2000] # Ks = [500] -for CLASS_NAME in ['continent'] : #'years_category']: #['continent', 'first_letter_category']: #, 'gender', 'gender_category', 'occupations', 'source_countries', 'source_subcont_regions', 'years_category', 'relative_pageviews_category']: +for CLASS_NAME in ['gender_category'] : #'years_category']: #['continent', 'first_letter_category']: #, 'gender', 'gender_category', 'occupations', 'source_countries', 'source_subcont_regions', 'years_category', 'relative_pageviews_category']: data_path = './' + CLASS_NAME - if CLASS_NAME in ['years_category', 'continent']: + if CLASS_NAME in ['years_category', 'continent', 'gender_category']: train_path = join(data_path, 'train500PerGroup.json') else: train_path = join(data_path, 'train3000samples.json')