1
0
Fork 0
QuaPy/LeQua2022/predict.py

61 lines
2.1 KiB
Python
Raw Normal View History

2021-10-28 15:54:27 +02:00
import argparse
import quapy as qp
from data import ResultSubmission, evaluate_submission
import constants
import os
import pickle
from tqdm import tqdm
from data import gen_load_samples_T1, load_category_map
from glob import glob
import constants
"""
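LeQua2022 prediction script: applies a pickled quantifier (via its `quantify` method) to every sample in a directory and writes the predictions to a submission file.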
"""
def main(args):
    """Run a pickled quantifier over every sample and write a submission file.

    Args:
        args: parsed command-line namespace. The attributes read here are
            ``samples`` (directory holding the ``*.txt`` sample files),
            ``model`` (path of the pickled model), ``nf`` (forwarded to
            ``gen_load_samples_T1``) and ``output`` (path of the predictions
            file to write).
    """
    # check the number of samples; a mismatch only prints a warning and the
    # run continues
    nsamples = len(glob(os.path.join(args.samples, '*.txt')))
    if nsamples not in {constants.DEV_SAMPLES, constants.TEST_SAMPLES}:
        print(f'Warning: The number of samples does neither coincide with the expected number of '
              f'dev samples ({constants.DEV_SAMPLES}) nor with the expected number of '
              f'test samples ({constants.TEST_SAMPLES}).')

    # load pickled model
    # NOTE(security): unpickling can execute arbitrary code -- only load model
    # files that come from a trusted source.
    with open(args.model, 'rb') as fin:
        model = pickle.load(fin)

    # predictions: one entry per (sampleid, sample) pair yielded by the loader
    predictions = ResultSubmission()
    for sampleid, sample in tqdm(gen_load_samples_T1(args.samples, args.nf),
                                 desc='predicting', total=nsamples):
        predictions.add(sampleid, model.quantify(sample))

    # saving: create the *parent* directory of the output file if needed.
    # dirname (not basename) is required here, otherwise the file name itself
    # would be created as a directory and the dump below would fail.
    parentdir = os.path.dirname(args.output)
    if parentdir:
        os.makedirs(parentdir, exist_ok=True)
    predictions.dump(args.output)
if __name__ == '__main__':
    # Command-line interface: four positional arguments, declared in order as
    # (destination, metavar, type, help text).
    parser = argparse.ArgumentParser(description='LeQua2022 prediction script')
    positional_args = (
        ('model', 'MODEL-PATH', str, 'Path of saved model'),
        ('samples', 'SAMPLES-PATH', str, 'Path to the directory containing the samples'),
        ('output', 'PREDICTIONS-PATH', str, 'Path where to store the predictions file'),
        ('nf', 'NUM-FEATURES', int, 'Number of features seen during training'),
    )
    for dest, metavar, argtype, helptext in positional_args:
        parser.add_argument(dest, metavar=metavar, type=argtype, help=helptext)
    args = parser.parse_args()

    # The samples path must exist and must be a directory before predicting.
    samples_path = args.samples
    if not os.path.exists(samples_path):
        raise FileNotFoundError(f'path {samples_path} does not exist')
    if not os.path.isdir(samples_path):
        raise ValueError(f'path {samples_path} is not a valid directory')

    main(args)