diff --git a/LeQua2024/_lequa2024.py b/LeQua2024/_lequa2024.py index 5e414c5..285549d 100644 --- a/LeQua2024/_lequa2024.py +++ b/LeQua2024/_lequa2024.py @@ -1,6 +1,4 @@ -from typing import Tuple, Union import pandas as pd -import numpy as np import os from os.path import join @@ -63,9 +61,10 @@ def fetch_lequa2024(task, data_home='./data', merge_T3=False): val_true_prev_path = join(lequa_dir, task, 'public', 'dev_prevalences.txt') val_gen = SamplesFromDir(val_samples_path, val_true_prev_path, load_fn=load_fn) - test_samples_path = join(lequa_dir, task, 'public', 'test_samples') - test_true_prev_path = join(lequa_dir, task, 'public', 'test_prevalences.txt') - test_gen = SamplesFromDir(test_samples_path, test_true_prev_path, load_fn=load_fn) + # test_samples_path = join(lequa_dir, task, 'public', 'test_samples') + # test_true_prev_path = join(lequa_dir, task, 'public', 'test_prevalences.txt') + # test_gen = SamplesFromDir(test_samples_path, test_true_prev_path, load_fn=load_fn) + test_gen = None if task != 'T3': tr_path = join(lequa_dir, task, 'public', 'training_data.txt') diff --git a/LeQua2024/baselines.py b/LeQua2024/baselines.py index 28a19f0..ff4c33c 100644 --- a/LeQua2024/baselines.py +++ b/LeQua2024/baselines.py @@ -82,25 +82,10 @@ def main(args): else: quantifier.fit(train) - - # valid_error = quantifier.best_score_ - - # test_err = qp.evaluation.evaluate(quantifier, protocol=gen_test, error_metric='mrae', verbose=True) - # print(f'method={q_name} got MRAE={test_err:.4f}') - # - # results.append((q_name, valid_error, test_err)) - - print(f'saving model in {model_path}') pickle.dump(quantifier, open(model_path, 'wb'), protocol=pickle.HIGHEST_PROTOCOL) - # print('\nResults') - # print('Method\tValid-err\ttest-err') - # for q_name, valid_error, test_err in results: - # print(f'{q_name}\t{valid_error:.4}\t{test_err:.4f}') - - if __name__ == '__main__': parser = argparse.ArgumentParser(description='LeQua2024 baselines') @@ -110,12 +95,4 @@ if __name__ == '__main__': help='Path of the directory containing LeQua 2024 data', default='./data') args = parser.parse_args() - # def assert_file(filename): - # if not os.path.exists(os.path.join(args.datadir, filename)): - # raise FileNotFoundError(f'path {args.datadir} does not contain "{filename}"') - # - # assert_file('dev_prevalences.txt') - # assert_file('training_data.txt') - # assert_file('dev_samples') - main(args)