Compare commits

...

2 Commits

2 changed files with 18 additions and 10 deletions

View File

@ -63,7 +63,7 @@ class TextRankings:
O = self.obj
docs_ids = [doc_id for doc_id, query_id in O['qid'].items() if query_id == sample_id]
texts = [O['text'][doc_id] for doc_id in docs_ids]
labels = [O['continent'][doc_id] for doc_id in docs_ids]
labels = [O[self.class_name][doc_id] for doc_id in docs_ids]
if max_lines > 0 and len(texts) > max_lines:
ranks = [int(O['rank'][doc_id]) for doc_id in docs_ids]
sel = np.argsort(ranks)[:max_lines]

View File

@ -104,26 +104,33 @@ RANK_AT_K = -1
REDUCE_TR = 50000
qp.environ['SAMPLE_SIZE'] = RANK_AT_K
data_path = './newExperimentalSetup'
train_path = join(data_path, 'train3000samples.json')
# Dataset directory for each class name. It is indexed by CLASS_NAME to build
# the training-file path and is passed to RetrievedSamples for that class.
data_path = {
'first_letter_category': './first_letter_categoryDataset',
'continent': './newExperimentalSetup'
}
def scape_latex(string):
    """Return ``string`` with every underscore escaped for LaTeX output.

    Each "_" is replaced by a backslash followed by "_", so names such as
    class names can be used as LaTeX table labels.

    :param string: the text to escape
    :return: a new string with underscores backslash-escaped
    """
    # The backslash is written explicitly ('\\_'): a bare '\_' is an invalid
    # escape sequence that newer Python versions warn about.
    return string.replace('_', '\\_')
Ks = [10, 50, 100, 250, 500, 1000, 2000]
# Ks = [500]
for CLASS_NAME in ['continent']: #, 'gender', 'gender_category', 'occupations', 'source_countries', 'source_subcont_regions', 'years_category', 'relative_pageviews_category']:
for CLASS_NAME in ['first_letter_category']: #['continent']: #, 'gender', 'gender_category', 'occupations', 'source_countries', 'source_subcont_regions', 'years_category', 'relative_pageviews_category']:
train_path = join(data_path[CLASS_NAME], 'train3000samples.json')
tfidf, classifier_trained = qp.util.pickled_resource(f'classifier_{CLASS_NAME}.pkl', train_classifier)
trained=True
experiment_prot = RetrievedSamples(data_path,
experiment_prot = RetrievedSamples(data_path[CLASS_NAME],
load_fn=load_json_sample,
vectorizer=tfidf,
max_train_lines=None,
max_test_lines=RANK_AT_K, classes=classifier_trained.classes_, class_name=CLASS_NAME)
method_names = [name for name, *other in methods()]
benchmarks = [f'{CLASS_NAME}@{k}' for k in Ks]
benchmarks = [f'{scape_latex(CLASS_NAME)}@{k}' for k in Ks]
table_mae = Table(benchmarks, method_names, color_mode='global')
table_mrae = Table(benchmarks, method_names, color_mode='global')
@ -158,11 +165,12 @@ for CLASS_NAME in ['continent']: #, 'gender', 'gender_category', 'occupations',
pbar.set_description(f'{method_name}')
for k in Ks:
table_mae.add(benchmark=f'{CLASS_NAME}@{k}', method=method_name, values=mae_errors[k])
table_mrae.add(benchmark=f'{CLASS_NAME}@{k}', method=method_name, values=mrae_errors[k])
table_mae.latexPDF('./latex', 'table_mae.tex')
table_mrae.latexPDF('./latex', 'table_mrae.tex')
table_mae.add(benchmark=f'{scape_latex(CLASS_NAME)}@{k}', method=method_name, values=mae_errors[k])
table_mrae.add(benchmark=f'{scape_latex(CLASS_NAME)}@{k}', method=method_name, values=mrae_errors[k])
table_mae.latexPDF('./latex', f'table_{CLASS_NAME}_mae.tex')
table_mrae.latexPDF('./latex', f'table_{CLASS_NAME}_mrae.tex')