2021-10-28 15:54:27 +02:00
|
|
|
import argparse
|
|
|
|
import quapy as qp
|
2021-11-09 15:44:57 +01:00
|
|
|
from data import ResultSubmission
|
2021-10-28 15:54:27 +02:00
|
|
|
import os
|
|
|
|
import pickle
|
|
|
|
from tqdm import tqdm
|
2021-11-24 11:20:42 +01:00
|
|
|
from data import gen_load_samples
|
2021-10-28 15:54:27 +02:00
|
|
|
from glob import glob
|
|
|
|
import constants
|
|
|
|
|
|
|
|
"""
|
|
|
|
LeQua2022 prediction script
|
|
|
|
"""
|
|
|
|
|
|
|
|
def main(args):
    """Quantify every sample in ``args.samples`` with a pickled model and dump the predictions.

    Steps:
      1. Count the ``*.txt`` files in ``args.samples`` and print a warning
         (without aborting) if the count matches neither
         ``constants.DEV_SAMPLES`` nor ``constants.TEST_SAMPLES``.
      2. Unpickle the model stored at ``args.model``.
      3. For each ``(sampleid, sample)`` yielded by ``gen_load_samples``,
         record ``model.quantify(sample)`` in a ``ResultSubmission``.
      4. Create the parent directory of ``args.output`` and dump the
         predictions there.

    Args:
        args: parsed command-line namespace; this function reads its
            ``model``, ``samples`` and ``output`` attributes.

    Security:
        ``pickle.load`` can execute arbitrary code embedded in the file;
        only pass model files from a trusted source.
    """
    # check the number of samples
    nsamples = len(glob(os.path.join(args.samples, '*.txt')))
    if nsamples not in {constants.DEV_SAMPLES, constants.TEST_SAMPLES}:
        print(f'Warning: The number of samples does neither coincide with the expected number of '
              f'dev samples ({constants.DEV_SAMPLES}) nor with the expected number of '
              f'test samples ({constants.TEST_SAMPLES}).')

    # load pickled model; the context manager guarantees the file handle is closed.
    # SECURITY: unpickling untrusted data can execute arbitrary code.
    with open(args.model, 'rb') as model_file:
        model = pickle.load(model_file)

    # predictions
    predictions = ResultSubmission()
    # NOTE(review): no ``load_fn`` is passed, so this relies on gen_load_samples
    # providing a suitable default loader -- confirm against data.py.
    # args.nf ("Number of features seen during training") is parsed by the CLI but
    # not used here -- TODO confirm whether it should be forwarded to the loader.
    samples = gen_load_samples(args.samples, return_id=True)
    # nsamples is only a progress-bar hint for tqdm.
    for sampleid, sample in tqdm(samples, desc='predicting', total=nsamples):
        predictions.add(sampleid, model.quantify(sample))

    # saving
    qp.util.create_parent_dir(args.output)
    predictions.dump(args.output)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    cli = argparse.ArgumentParser(description='LeQua2022 prediction script')

    # All four arguments are required positionals, registered in this order:
    # (dest, metavar, type, help text).
    positional_args = (
        ('model', 'MODEL-PATH', str, 'Path of saved model'),
        ('samples', 'SAMPLES-PATH', str, 'Path to the directory containing the samples'),
        ('output', 'PREDICTIONS-PATH', str, 'Path where to store the predictions file'),
        ('nf', 'NUM-FEATURES', int, 'Number of features seen during training'),
    )
    for dest, metavar, arg_type, help_text in positional_args:
        cli.add_argument(dest, metavar=metavar, type=arg_type, help=help_text)

    args = cli.parse_args()

    # Fail fast before loading the model: the samples path must exist
    # (FileNotFoundError) and must be a directory (ValueError).
    samples_dir = args.samples
    if not os.path.exists(samples_dir):
        raise FileNotFoundError(f'path {samples_dir} does not exist')
    if not os.path.isdir(samples_dir):
        raise ValueError(f'path {samples_dir} is not a valid directory')

    main(args)
|