This commit is contained in:
Alejandro Moreo Fernandez 2024-02-14 18:54:07 +01:00
parent 451c938171
commit 6c5bd674ea
2 changed files with 4 additions and 28 deletions

View File

@ -1,6 +1,4 @@
from typing import Tuple, Union
import pandas as pd
import numpy as np
import os
from os.path import join
@ -63,9 +61,10 @@ def fetch_lequa2024(task, data_home='./data', merge_T3=False):
val_true_prev_path = join(lequa_dir, task, 'public', 'dev_prevalences.txt')
val_gen = SamplesFromDir(val_samples_path, val_true_prev_path, load_fn=load_fn)
test_samples_path = join(lequa_dir, task, 'public', 'test_samples')
test_true_prev_path = join(lequa_dir, task, 'public', 'test_prevalences.txt')
test_gen = SamplesFromDir(test_samples_path, test_true_prev_path, load_fn=load_fn)
# test_samples_path = join(lequa_dir, task, 'public', 'test_samples')
# test_true_prev_path = join(lequa_dir, task, 'public', 'test_prevalences.txt')
# test_gen = SamplesFromDir(test_samples_path, test_true_prev_path, load_fn=load_fn)
test_gen = None
if task != 'T3':
tr_path = join(lequa_dir, task, 'public', 'training_data.txt')

View File

@ -82,25 +82,10 @@ def main(args):
else:
quantifier.fit(train)
# valid_error = quantifier.best_score_
# test_err = qp.evaluation.evaluate(quantifier, protocol=gen_test, error_metric='mrae', verbose=True)
# print(f'method={q_name} got MRAE={test_err:.4f}')
#
# results.append((q_name, valid_error, test_err))
print(f'saving model in {model_path}')
pickle.dump(quantifier, open(model_path, 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
# print('\nResults')
# print('Method\tValid-err\ttest-err')
# for q_name, valid_error, test_err in results:
# print(f'{q_name}\t{valid_error:.4}\t{test_err:.4f}')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='LeQua2024 baselines')
@ -110,12 +95,4 @@ if __name__ == '__main__':
help='Path of the directory containing LeQua 2024 data', default='./data')
args = parser.parse_args()
# def assert_file(filename):
# if not os.path.exists(os.path.join(args.datadir, filename)):
# raise FileNotFoundError(f'path {args.datadir} does not contain "{filename}"')
#
# assert_file('dev_prevalences.txt')
# assert_file('training_data.txt')
# assert_file('dev_samples')
main(args)