forked from moreo/QuaPy
adapting to the new format
This commit is contained in:
parent
4cd47cdf9f
commit
f63575ff55
|
@ -44,7 +44,7 @@ param_grid = {'C': np.logspace(-3,3,7), 'class_weight': ['balanced', None]}
|
|||
|
||||
|
||||
def gen_samples():
|
||||
return gen_load_samples_T1(T1A_devvectors_path, nF, ground_truth_path=T1A_devprevalence_path, return_filename=False)
|
||||
return gen_load_samples_T1(T1A_devvectors_path, nF, ground_truth_path=T1A_devprevalence_path, return_id=False)
|
||||
|
||||
|
||||
for quantifier in [CC, ACC, PCC, PACC, EMQ, HDy]:
|
||||
|
|
|
@ -37,32 +37,33 @@ def load_binary_vectors(path, nF=None):
|
|||
return sklearn.datasets.load_svmlight_file(path, n_features=nF)
|
||||
|
||||
|
||||
def __gen_load_samples_with_groudtruth(path_dir:str, return_filename:bool, ground_truth_path:str, load_fn, **load_kwargs):
|
||||
def __gen_load_samples_with_groudtruth(path_dir:str, return_id:bool, ground_truth_path:str, load_fn, **load_kwargs):
|
||||
true_prevs = ResultSubmission.load(ground_truth_path)
|
||||
for filename, prevalence in true_prevs.iterrows():
|
||||
sample, _ = load_fn(os.path.join(path_dir, filename), **load_kwargs)
|
||||
if return_filename:
|
||||
yield filename, sample, prevalence
|
||||
for id, prevalence in true_prevs.iterrows():
|
||||
sample, _ = load_fn(os.path.join(path_dir, f'{id}.txt'), **load_kwargs)
|
||||
if return_id:
|
||||
yield id, sample, prevalence
|
||||
else:
|
||||
yield sample, prevalence
|
||||
|
||||
|
||||
def __gen_load_samples_without_groudtruth(path_dir:str, return_filename:bool, load_fn, **load_kwargs):
|
||||
for filepath in glob(os.path.join(path_dir, '*_sample_*.txt')):
|
||||
sample, _ = load_fn(filepath, **load_kwargs)
|
||||
if return_filename:
|
||||
yield os.path.basename(filepath), sample
|
||||
def __gen_load_samples_without_groudtruth(path_dir:str, return_id:bool, load_fn, **load_kwargs):
|
||||
nsamples = len(glob(os.path.join(path_dir, '*.txt')))
|
||||
for id in range(nsamples):
|
||||
sample, _ = load_fn(os.path.join(path_dir, f'{id}.txt'), **load_kwargs)
|
||||
if return_id:
|
||||
yield id, sample
|
||||
else:
|
||||
yield sample
|
||||
|
||||
|
||||
def gen_load_samples_T1(path_dir:str, nF:int, ground_truth_path:str = None, return_filename=True):
|
||||
def gen_load_samples_T1(path_dir:str, nF:int, ground_truth_path:str = None, return_id=True):
|
||||
if ground_truth_path is None:
|
||||
# the generator function returns tuples (filename:str, sample:csr_matrix)
|
||||
gen_fn = __gen_load_samples_without_groudtruth(path_dir, return_filename, load_binary_vectors, nF=nF)
|
||||
gen_fn = __gen_load_samples_without_groudtruth(path_dir, return_id, load_binary_vectors, nF=nF)
|
||||
else:
|
||||
# the generator function returns tuples (filename:str, sample:csr_matrix, prevalence:ndarray)
|
||||
gen_fn = __gen_load_samples_with_groudtruth(path_dir, return_filename, ground_truth_path, load_binary_vectors, nF=nF)
|
||||
gen_fn = __gen_load_samples_with_groudtruth(path_dir, return_id, ground_truth_path, load_binary_vectors, nF=nF)
|
||||
for r in gen_fn:
|
||||
yield r
|
||||
|
||||
|
@ -83,47 +84,35 @@ class ResultSubmission:
|
|||
if not isinstance(categories, list) or len(categories) < 2:
|
||||
raise TypeError('wrong format for categories; a list with at least two category names (str) was expected')
|
||||
self.categories = categories
|
||||
self.df = pd.DataFrame(columns=['filename'] + list(categories))
|
||||
self.inferred_type = None
|
||||
self.df = pd.DataFrame(columns=list(categories))
|
||||
self.df.index.rename('id', inplace=True)
|
||||
|
||||
def add(self, sample_name:str, prevalence_values:np.ndarray):
|
||||
if not isinstance(sample_name, str):
|
||||
raise TypeError(f'error: expected str for sample_sample, found {type(sample_name)}')
|
||||
def add(self, sample_id:int, prevalence_values:np.ndarray):
|
||||
if not isinstance(sample_id, int):
|
||||
raise TypeError(f'error: expected int for sample_sample, found {type(sample_id)}')
|
||||
if not isinstance(prevalence_values, np.ndarray):
|
||||
raise TypeError(f'error: expected np.ndarray for prevalence_values, found {type(prevalence_values)}')
|
||||
|
||||
if self.inferred_type is None:
|
||||
if sample_name.startswith('test'):
|
||||
self.inferred_type = 'test'
|
||||
elif sample_name.startswith('dev'):
|
||||
self.inferred_type = 'dev'
|
||||
else:
|
||||
if not sample_name.startswith(self.inferred_type):
|
||||
raise ValueError(f'error: sample "{sample_name}" is not a valid entry for type "{self.inferred_type}"')
|
||||
|
||||
if not re.match("(test|dev)_sample_\d+\.txt", sample_name):
|
||||
raise ValueError(f'error: wrong format "{sample_name}"; right format is (test|dev)_sample_<number>.txt')
|
||||
if sample_name in self.df.filename.values:
|
||||
raise ValueError(f'error: prevalence values for "{sample_name}" already added')
|
||||
if sample_id in self.df.index.values:
|
||||
raise ValueError(f'error: prevalence values for "{sample_id}" already added')
|
||||
if prevalence_values.ndim!=1 and prevalence_values.size != len(self.categories):
|
||||
raise ValueError(f'error: wrong shape found for prevalence vector {prevalence_values}')
|
||||
if (prevalence_values<0).any() or (prevalence_values>1).any():
|
||||
raise ValueError(f'error: prevalence values out of range [0,1] for "{sample_name}"')
|
||||
raise ValueError(f'error: prevalence values out of range [0,1] for "{sample_id}"')
|
||||
if np.abs(prevalence_values.sum()-1) > constants.ERROR_TOL:
|
||||
raise ValueError(f'error: prevalence values do not sum up to one for "{sample_name}"'
|
||||
raise ValueError(f'error: prevalence values do not sum up to one for "{sample_id}"'
|
||||
f'(error tolerance {constants.ERROR_TOL})')
|
||||
|
||||
new_entry = dict([('filename',sample_name)]+[(col_i,prev_i) for col_i, prev_i in zip(self.categories, prevalence_values)])
|
||||
self.df = self.df.append(new_entry, ignore_index=True)
|
||||
# new_entry = dict([('id', sample_id)] + [(col_i, prev_i) for col_i, prev_i in enumerate(prevalence_values)])
|
||||
new_entry = pd.DataFrame(prevalence_values.reshape(1,2), index=[sample_id], columns=self.df.columns)
|
||||
self.df = self.df.append(new_entry, ignore_index=False)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.df)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: str) -> 'ResultSubmission':
|
||||
df, inferred_type = ResultSubmission.check_file_format(path, return_inferred_type=True)
|
||||
r = ResultSubmission(categories=df.columns.values[1:].tolist())
|
||||
r.inferred_type = inferred_type
|
||||
df = ResultSubmission.check_file_format(path)
|
||||
r = ResultSubmission(categories=df.columns.values.tolist())
|
||||
r.df = df
|
||||
return r
|
||||
|
||||
|
@ -140,59 +129,56 @@ class ResultSubmission:
|
|||
|
||||
def iterrows(self):
|
||||
for index, row in self.df.iterrows():
|
||||
filename = row.filename
|
||||
prevalence = row[self.df.columns[1]:].values.flatten()
|
||||
yield filename, prevalence
|
||||
# filename = row.filename
|
||||
prevalence = row.values.flatten()
|
||||
yield index, prevalence
|
||||
|
||||
@classmethod
|
||||
def check_file_format(cls, path, return_inferred_type=False) -> Union[pd.DataFrame, Tuple[pd.DataFrame, str]]:
|
||||
def check_file_format(cls, path) -> Union[pd.DataFrame, Tuple[pd.DataFrame, str]]:
|
||||
df = pd.read_csv(path, index_col=0)
|
||||
return ResultSubmission.check_dataframe_format(df, path=path, return_inferred_type=return_inferred_type)
|
||||
return ResultSubmission.check_dataframe_format(df, path=path)
|
||||
|
||||
@classmethod
|
||||
def check_dataframe_format(cls, df, path=None, return_inferred_type=False) -> Union[pd.DataFrame, Tuple[pd.DataFrame, str]]:
|
||||
def check_dataframe_format(cls, df, path=None) -> Union[pd.DataFrame, Tuple[pd.DataFrame, str]]:
|
||||
hint_path = '' # if given, show the data path in the error message
|
||||
if path is not None:
|
||||
hint_path = f' in {path}'
|
||||
|
||||
if 'filename' not in df.columns or len(df.columns) < 3:
|
||||
raise ValueError(f'wrong header{hint_path}, the format of the header should be ",filename,<cat_1>,...,<cat_n>"')
|
||||
if df.index.name != 'id' or len(df.columns) < 2:
|
||||
raise ValueError(f'wrong header{hint_path}, '
|
||||
f'the format of the header should be "id,<cat_1>,...,<cat_n>"')
|
||||
if [int(ci) for ci in df.columns.values] != list(range(len(df.columns))):
|
||||
raise ValueError(f'wrong header{hint_path}, category ids should be 0,1,2,...,n')
|
||||
|
||||
if df.empty:
|
||||
raise ValueError(f'error{hint_path}: results file is empty')
|
||||
elif len(df) == constants.DEV_SAMPLES:
|
||||
inferred_type = 'dev'
|
||||
expected_len = constants.DEV_SAMPLES
|
||||
elif len(df) == constants.TEST_SAMPLES:
|
||||
inferred_type = 'test'
|
||||
expected_len = constants.TEST_SAMPLES
|
||||
else:
|
||||
elif len(df) != constants.DEV_SAMPLES and len(df) != constants.TEST_SAMPLES:
|
||||
raise ValueError(f'wrong number of prevalence values found{hint_path}; '
|
||||
f'expected {constants.DEV_SAMPLES} for development sets and '
|
||||
f'{constants.TEST_SAMPLES} for test sets; found {len(df)}')
|
||||
|
||||
set_names = frozenset(df.filename)
|
||||
for i in range(expected_len):
|
||||
if f'{inferred_type}_sample_{i}.txt' not in set_names:
|
||||
raise ValueError(f'error{hint_path} a file with {len(df)} entries is assumed to be of type '
|
||||
f'"{inferred_type}" but entry {inferred_type}_sample_{i}.txt is missing '
|
||||
f'(among perhaps many others)')
|
||||
ids = set(df.index.values)
|
||||
expected_ids = set(range(len(df)))
|
||||
if ids != expected_ids:
|
||||
missing = expected_ids - ids
|
||||
if missing:
|
||||
raise ValueError(f'there are {len(missing)} missing ids{hint_path}: {sorted(missing)}')
|
||||
unexpected = ids - expected_ids
|
||||
if unexpected:
|
||||
raise ValueError(f'there are {len(missing)} unexpected ids{hint_path}: {sorted(unexpected)}')
|
||||
|
||||
for category_name in df.columns[1:]:
|
||||
for category_name in df.columns:
|
||||
if (df[category_name] < 0).any() or (df[category_name] > 1).any():
|
||||
raise ValueError(f'error{hint_path} column "{category_name}" contains values out of range [0,1]')
|
||||
|
||||
prevs = df.loc[:, df.columns[1]:].values
|
||||
prevs = df.values
|
||||
round_errors = np.abs(prevs.sum(axis=-1) - 1.) > constants.ERROR_TOL
|
||||
if round_errors.any():
|
||||
raise ValueError(f'warning: prevalence values in rows with id {np.where(round_errors)[0].tolist()} '
|
||||
f'do not sum up to 1 (error tolerance {constants.ERROR_TOL}), '
|
||||
f'probably due to some rounding errors.')
|
||||
|
||||
if return_inferred_type:
|
||||
return df, inferred_type
|
||||
else:
|
||||
return df
|
||||
return df
|
||||
|
||||
def sort_categories(self):
|
||||
self.df = self.df.reindex([self.df.columns[0]] + sorted(self.df.columns[1:]), axis=1)
|
||||
|
|
|
@ -20,7 +20,7 @@ def main(args):
|
|||
|
||||
if __name__=='__main__':
|
||||
parser = argparse.ArgumentParser(description='LeQua2022 official format-checker script')
|
||||
parser.add_argument('prevalence_file', metavar='PREV-PATH', type=str,
|
||||
parser.add_argument('prevalence_file', metavar='PREVALENCEFILE-PATH', type=str,
|
||||
help='Path of the file containing prevalence values to check')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
|
@ -28,10 +28,10 @@ def main(args):
|
|||
model = pickle.load(open(args.model, 'rb'))
|
||||
|
||||
# predictions
|
||||
predictions = ResultSubmission(categories=categories)
|
||||
for samplename, sample in tqdm(gen_load_samples_T1(args.samples, args.nf),
|
||||
predictions = ResultSubmission(categories=list(range(len(categories))))
|
||||
for sampleid, sample in tqdm(gen_load_samples_T1(args.samples, args.nf),
|
||||
desc='predicting', total=nsamples):
|
||||
predictions.add(samplename, model.quantify(sample))
|
||||
predictions.add(sampleid, model.quantify(sample))
|
||||
|
||||
# saving
|
||||
basedir = os.path.basename(args.output)
|
||||
|
|
Loading…
Reference in New Issue