cleaning
This commit is contained in:
parent
451c938171
commit
6c5bd674ea
|
@ -1,6 +1,4 @@
|
|||
from typing import Tuple, Union
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import os
|
||||
from os.path import join
|
||||
|
||||
|
@ -63,9 +61,10 @@ def fetch_lequa2024(task, data_home='./data', merge_T3=False):
|
|||
val_true_prev_path = join(lequa_dir, task, 'public', 'dev_prevalences.txt')
|
||||
val_gen = SamplesFromDir(val_samples_path, val_true_prev_path, load_fn=load_fn)
|
||||
|
||||
test_samples_path = join(lequa_dir, task, 'public', 'test_samples')
|
||||
test_true_prev_path = join(lequa_dir, task, 'public', 'test_prevalences.txt')
|
||||
test_gen = SamplesFromDir(test_samples_path, test_true_prev_path, load_fn=load_fn)
|
||||
# test_samples_path = join(lequa_dir, task, 'public', 'test_samples')
|
||||
# test_true_prev_path = join(lequa_dir, task, 'public', 'test_prevalences.txt')
|
||||
# test_gen = SamplesFromDir(test_samples_path, test_true_prev_path, load_fn=load_fn)
|
||||
test_gen = None
|
||||
|
||||
if task != 'T3':
|
||||
tr_path = join(lequa_dir, task, 'public', 'training_data.txt')
|
||||
|
|
|
@ -82,25 +82,10 @@ def main(args):
|
|||
else:
|
||||
quantifier.fit(train)
|
||||
|
||||
|
||||
# valid_error = quantifier.best_score_
|
||||
|
||||
# test_err = qp.evaluation.evaluate(quantifier, protocol=gen_test, error_metric='mrae', verbose=True)
|
||||
# print(f'method={q_name} got MRAE={test_err:.4f}')
|
||||
#
|
||||
# results.append((q_name, valid_error, test_err))
|
||||
|
||||
|
||||
print(f'saving model in {model_path}')
|
||||
pickle.dump(quantifier, open(model_path, 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
|
||||
# print('\nResults')
|
||||
# print('Method\tValid-err\ttest-err')
|
||||
# for q_name, valid_error, test_err in results:
|
||||
# print(f'{q_name}\t{valid_error:.4}\t{test_err:.4f}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser(description='LeQua2024 baselines')
|
||||
|
@ -110,12 +95,4 @@ if __name__ == '__main__':
|
|||
help='Path of the directory containing LeQua 2024 data', default='./data')
|
||||
args = parser.parse_args()
|
||||
|
||||
# def assert_file(filename):
|
||||
# if not os.path.exists(os.path.join(args.datadir, filename)):
|
||||
# raise FileNotFoundError(f'path {args.datadir} does not contain "{filename}"')
|
||||
#
|
||||
# assert_file('dev_prevalences.txt')
|
||||
# assert_file('training_data.txt')
|
||||
# assert_file('dev_samples')
|
||||
|
||||
main(args)
|
||||
|
|
Loading…
Reference in New Issue