fixing code to handle different categories
This commit is contained in:
parent
2a685cec1e
commit
8f9d19dd5f
|
@ -63,7 +63,7 @@ class TextRankings:
|
|||
O = self.obj
|
||||
docs_ids = [doc_id for doc_id, query_id in O['qid'].items() if query_id == sample_id]
|
||||
texts = [O['text'][doc_id] for doc_id in docs_ids]
|
||||
labels = [O['continent'][doc_id] for doc_id in docs_ids]
|
||||
labels = [O[self.class_name][doc_id] for doc_id in docs_ids]
|
||||
if max_lines > 0 and len(texts) > max_lines:
|
||||
ranks = [int(O['rank'][doc_id]) for doc_id in docs_ids]
|
||||
sel = np.argsort(ranks)[:max_lines]
|
||||
|
|
|
@ -104,26 +104,33 @@ RANK_AT_K = -1
|
|||
REDUCE_TR = 50000
|
||||
qp.environ['SAMPLE_SIZE'] = RANK_AT_K
|
||||
|
||||
data_path = './newExperimentalSetup'
|
||||
train_path = join(data_path, 'train3000samples.json')
|
||||
data_path = {
|
||||
'first_letter_category': './first_letter_categoryDataset',
|
||||
'continent': './newExperimentalSetup'
|
||||
}
|
||||
|
||||
def scape_latex(string):
    r"""Return *string* with every underscore escaped for LaTeX.

    Each ``_`` is replaced by ``\_`` so that names containing underscores
    (e.g. ``first_letter_category``) can be used as text in generated
    LaTeX tables. No other characters are modified.

    :param string: the text to escape.
    :return: a copy of *string* with ``_`` replaced by ``\_``.
    """
    # A raw string is used because '\_' is not a valid escape sequence in a
    # normal string literal; recent Python versions warn about it.
    return string.replace('_', r'\_')
|
||||
|
||||
|
||||
Ks = [10, 50, 100, 250, 500, 1000, 2000]
|
||||
# Ks = [500]
|
||||
|
||||
for CLASS_NAME in ['continent']: #, 'gender', 'gender_category', 'occupations', 'source_countries', 'source_subcont_regions', 'years_category', 'relative_pageviews_category']:
|
||||
for CLASS_NAME in ['first_letter_category']: #['continent']: #, 'gender', 'gender_category', 'occupations', 'source_countries', 'source_subcont_regions', 'years_category', 'relative_pageviews_category']:
|
||||
|
||||
train_path = join(data_path[CLASS_NAME], 'train3000samples.json')
|
||||
|
||||
tfidf, classifier_trained = qp.util.pickled_resource(f'classifier_{CLASS_NAME}.pkl', train_classifier)
|
||||
trained=True
|
||||
|
||||
experiment_prot = RetrievedSamples(data_path,
|
||||
experiment_prot = RetrievedSamples(data_path[CLASS_NAME],
|
||||
load_fn=load_json_sample,
|
||||
vectorizer=tfidf,
|
||||
max_train_lines=None,
|
||||
max_test_lines=RANK_AT_K, classes=classifier_trained.classes_, class_name=CLASS_NAME)
|
||||
|
||||
method_names = [name for name, *other in methods()]
|
||||
benchmarks = [f'{CLASS_NAME}@{k}' for k in Ks]
|
||||
benchmarks = [f'{scape_latex(CLASS_NAME)}@{k}' for k in Ks]
|
||||
table_mae = Table(benchmarks, method_names, color_mode='global')
|
||||
table_mrae = Table(benchmarks, method_names, color_mode='global')
|
||||
|
||||
|
@ -158,8 +165,9 @@ for CLASS_NAME in ['continent']: #, 'gender', 'gender_category', 'occupations',
|
|||
pbar.set_description(f'{method_name}')
|
||||
|
||||
for k in Ks:
|
||||
table_mae.add(benchmark=f'{CLASS_NAME}@{k}', method=method_name, values=mae_errors[k])
|
||||
table_mrae.add(benchmark=f'{CLASS_NAME}@{k}', method=method_name, values=mrae_errors[k])
|
||||
|
||||
table_mae.add(benchmark=f'{scape_latex(CLASS_NAME)}@{k}', method=method_name, values=mae_errors[k])
|
||||
table_mrae.add(benchmark=f'{scape_latex(CLASS_NAME)}@{k}', method=method_name, values=mrae_errors[k])
|
||||
|
||||
table_mae.latexPDF('./latex', 'table_mae.tex')
|
||||
table_mrae.latexPDF('./latex', 'table_mrae.tex')
|
||||
|
|
Loading…
Reference in New Issue