Cleaning and refactoring; also trying to repair the Monte Carlo approximation

This commit is contained in:
Alejandro Moreo Fernandez 2023-09-28 17:47:39 +02:00
parent 9221b3f775
commit 56dbe744df
10 changed files with 355 additions and 197 deletions

.gitignore

@@ -164,3 +164,4 @@ dmypy.json
 .pyre/
 *__pycache__*
+*.dataframe


@@ -0,0 +1,83 @@
import pickle
import os
from sklearn.linear_model import LogisticRegression as LR  # needed for the CC fallback below
from distribution_matching.commons import BIN_METHODS, new_method, show_results
import quapy as qp
from quapy.model_selection import GridSearchQ
from quapy.protocol import UPP

SEED = 1

if __name__ == '__main__':

    qp.environ['SAMPLE_SIZE'] = 100
    qp.environ['N_JOBS'] = -1
    n_bags_val = 250
    n_bags_test = 1000
    optim = 'mae'
    result_dir = f'results/binary/{optim}'

    os.makedirs(result_dir, exist_ok=True)

    for method in BIN_METHODS:

        print('Init method', method)

        global_result_path = f'{result_dir}/{method}'

        if not os.path.exists(global_result_path + '.csv'):
            with open(global_result_path + '.csv', 'wt') as csv:
                csv.write(f'Method\tDataset\tMAE\tMRAE\tKLD\n')

        with open(global_result_path + '.csv', 'at') as csv:
            for dataset in qp.datasets.UCI_DATASETS:
                if dataset in ['acute.a', 'acute.b', 'iris.1']: continue  # , 'pageblocks.5', 'spambase', 'wdbc']: continue
                print('init', dataset)

                local_result_path = global_result_path + '_' + dataset
                if os.path.exists(local_result_path + '.dataframe'):
                    print(f'result file {local_result_path}.dataframe already exists; skipping')
                    continue

                with qp.util.temp_seed(SEED):
                    param_grid, quantifier = new_method(method, max_iter=3000)

                    data = qp.datasets.fetch_UCIDataset(dataset)

                    # model selection
                    train, test = data.train_test
                    train, val = train.split_stratified()

                    protocol = UPP(val, repeats=n_bags_val)
                    modsel = GridSearchQ(
                        quantifier, param_grid, protocol, refit=True, n_jobs=-1, verbose=1, error=optim
                    )

                    try:
                        modsel.fit(train)
                        print(f'best params {modsel.best_params_}')
                        print(f'best score {modsel.best_score_}')
                        pickle.dump(
                            (modsel.best_params_, modsel.best_score_,),
                            open(f'{local_result_path}.hyper.pkl', 'wb'), pickle.HIGHEST_PROTOCOL)
                        quantifier = modsel.best_model()
                    except Exception:
                        print('something went wrong... reporting CC')
                        quantifier = qp.method.aggregative.CC(LR()).fit(train)

                    protocol = UPP(test, repeats=n_bags_test)
                    report = qp.evaluation.evaluation_report(
                        quantifier, protocol, error_metrics=['mae', 'mrae', 'kld'], verbose=True)
                    report.to_csv(f'{local_result_path}.dataframe')
                    means = report.mean()
                    csv.write(f'{method}\t{data.name}\t{means["mae"]:.5f}\t{means["mrae"]:.5f}\t{means["kld"]:.5f}\n')
                    csv.flush()

        show_results(global_result_path)


@@ -0,0 +1,78 @@
import numpy as np
import pandas as pd
from distribution_matching.method_kdey import KDEy
from quapy.method.aggregative import EMQ, CC, PCC, DistributionMatching, PACC, HDy, OneVsAllAggregative, ACC
from distribution_matching.method_dirichlety import DIRy
from sklearn.linear_model import LogisticRegression

METHODS = ['ACC', 'PACC', 'HDy-OvA', 'DIR', 'DM', 'KDEy-DM', 'EMQ', 'KDEy-ML']
BIN_METHODS = ['ACC', 'PACC', 'HDy', 'DIR', 'DM', 'KDEy-DM', 'EMQ', 'KDEy-ML']

hyper_LR = {
    'classifier__C': np.logspace(-3, 3, 7),
    'classifier__class_weight': ['balanced', None]
}


def new_method(method, **lr_kwargs):
    lr = LogisticRegression(**lr_kwargs)

    if method == 'CC':
        param_grid = hyper_LR
        quantifier = CC(lr)
    elif method == 'PCC':
        param_grid = hyper_LR
        quantifier = PCC(lr)
    elif method == 'ACC':
        param_grid = hyper_LR
        quantifier = ACC(lr)
    elif method == 'PACC':
        param_grid = hyper_LR
        quantifier = PACC(lr)
    elif method == 'KDEy-ML':
        method_params = {'bandwidth': np.linspace(0.01, 0.2, 20)}
        param_grid = {**method_params, **hyper_LR}
        quantifier = KDEy(lr, target='max_likelihood', val_split=10)
    elif method in ['KDEy-DM']:
        method_params = {'bandwidth': np.linspace(0.01, 0.2, 20)}
        param_grid = {**method_params, **hyper_LR}
        quantifier = KDEy(lr, target='min_divergence', divergence='l2', montecarlo_trials=5000, val_split=10)
    elif method == 'DIR':
        param_grid = hyper_LR
        quantifier = DIRy(lr)
    elif method == 'EMQ':
        param_grid = hyper_LR
        quantifier = EMQ(lr)
    elif method == 'HDy-OvA':
        param_grid = {'binary_quantifier__' + key: val for key, val in hyper_LR.items()}
        quantifier = OneVsAllAggregative(HDy(lr))
    elif method == 'DM':
        method_params = {
            'nbins': [4, 8, 16, 32],
            'val_split': [10, 0.4],
            'divergence': ['HD', 'topsoe', 'l2']
        }
        param_grid = {**method_params, **hyper_LR}
        quantifier = DistributionMatching(lr)
    elif method in ['KDE-DMkld']:
        method_params = {'bandwidth': np.linspace(0.01, 0.2, 20)}
        param_grid = {**method_params, **hyper_LR}
        quantifier = KDEy(lr, target='min_divergence', divergence='KLD', montecarlo_trials=5000, val_split=10)
    else:
        raise NotImplementedError('unknown method', method)

    return param_grid, quantifier


def show_results(result_path):
    df = pd.read_csv(result_path + '.csv', sep='\t')
    pd.set_option('display.max_columns', None)
    pd.set_option('display.max_rows', None)
    pv = df.pivot_table(index='Dataset', columns="Method", values=["MAE", "MRAE"])
    print(pv)
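As a usage sketch (not part of the diff, purely illustrative): new_method acts as a factory returning a (hyperparameter grid, quantifier) pair, and show_results pretty-prints the tab-separated results written by the experiment scripts; any extra keyword arguments are forwarded to LogisticRegression.

from distribution_matching.commons import new_method, show_results

# build the KDEy maximum-likelihood variant; max_iter is forwarded to LogisticRegression
param_grid, quantifier = new_method('KDEy-ML', max_iter=3000)
print(sorted(param_grid))    # bandwidth, classifier__C, classifier__class_weight
# quantifier.fit(train)      # `train` would be a quapy LabelledCollection (not defined here)
# show_results('results/binary/mae/KDEy-ML')   # expects the Method/Dataset/MAE/MRAE/KLD csv written by the scripts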


@@ -0,0 +1,26 @@
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

"""
This script generates sensitivity plots for the hyperparameter "bandwidth".
It plots results for MAE, MRAE, and KLD.
The remaining hyperparameters were set to their default values.
"""

df_tweet = pd.read_csv('../results_tweet_sensibility/KDEy-MLE.csv', sep='\t')
df_lequa = pd.read_csv('../results_lequa_sensibility/KDEy-MLE.csv', sep='\t')
df = pd.concat([df_tweet, df_lequa])

for err in ['MAE', 'MRAE', 'KLD']:
    piv = df.pivot_table(index='Bandwidth', columns='Dataset', values=err)
    g = sns.lineplot(data=piv, markers=True, dashes=False)
    g.set(xlim=(0.01, 0.2))
    g.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    g.set_ylabel(err)
    g.set_xticks(np.linspace(0.01, 0.2, 20))
    plt.xticks(rotation=90)
    plt.grid()
    plt.savefig(f'./sensibility_{err}.pdf', bbox_inches='tight')
    plt.clf()


@@ -0,0 +1,56 @@
import numpy as np
from sklearn.linear_model import LogisticRegression
import os
import pandas as pd
import quapy as qp
from method_kdey import KDEy

SEED = 1


def task(bandwidth):
    print('job-init', dataset, bandwidth)
    train, val_gen, test_gen = qp.datasets.fetch_lequa2022(dataset)

    with qp.util.temp_seed(SEED):
        quantifier = KDEy(LogisticRegression(), target='max_likelihood', val_split=10, bandwidth=bandwidth)
        quantifier.fit(train)
        report = qp.evaluation.evaluation_report(
            quantifier, protocol=test_gen, error_metrics=['mae', 'mrae', 'kld'], verbose=True)
    return report


if __name__ == '__main__':

    qp.environ['SAMPLE_SIZE'] = qp.datasets.LEQUA2022_SAMPLE_SIZE['T1B']
    qp.environ['N_JOBS'] = -1
    result_dir = f'results_lequa_sensibility'
    os.makedirs(result_dir, exist_ok=True)

    method = 'KDEy-MLE'
    global_result_path = f'{result_dir}/{method}'

    if not os.path.exists(global_result_path + '.csv'):
        with open(global_result_path + '.csv', 'wt') as csv:
            csv.write(f'Method\tDataset\tBandwidth\tMAE\tMRAE\tKLD\n')

    dataset = 'T1B'
    bandwidths = np.linspace(0.01, 0.2, 20)

    reports = qp.util.parallel(task, bandwidths, n_jobs=-1)
    with open(global_result_path + '.csv', 'at') as csv:
        for bandwidth, report in zip(bandwidths, reports):
            means = report.mean()
            local_result_path = global_result_path + '_' + dataset + f'_{bandwidth:.3f}'
            report.to_csv(f'{local_result_path}.dataframe')
            csv.write(f'{method}\tLeQua-T1B\t{bandwidth}\t{means["mae"]:.5f}\t{means["mrae"]:.5f}\t{means["kld"]:.5f}\n')
            csv.flush()

    df = pd.read_csv(global_result_path + '.csv', sep='\t')
    pd.set_option('display.max_columns', None)
    pd.set_option('display.max_rows', None)
    pv = df.pivot_table(index='Dataset', columns="Method", values=["MAE", "MRAE"])
    print(pv)


@@ -1,16 +1,12 @@
 import pickle
 import numpy as np
-from sklearn.linear_model import LogisticRegression
 import os
-import sys
 import pandas as pd
+from distribution_matching.commons import METHODS, new_method, show_results
 import quapy as qp
-from quapy.method.aggregative import EMQ, DistributionMatching, PACC, HDy, OneVsAllAggregative, ACC
-from method_kdey import KDEy
-from method_dirichlety import DIRy
 from quapy.model_selection import GridSearchQ
-from quapy.protocol import UPP


 if __name__ == '__main__':
@@ -18,16 +14,11 @@ if __name__ == '__main__':
     qp.environ['SAMPLE_SIZE'] = qp.datasets.LEQUA2022_SAMPLE_SIZE['T1B']
     qp.environ['N_JOBS'] = -1
     optim = 'mrae'
-    result_dir = f'results/results_lequa_{optim}'
+    result_dir = f'results/lequa/{optim}'
     os.makedirs(result_dir, exist_ok=True)

-    hyper_LR = {
-        'classifier__C': np.logspace(-3,3,7),
-        'classifier__class_weight': ['balanced', None]
-    }
-
-    for method in ['DIR']:#'HDy-OvA', 'SLD', 'ACC-tv', 'PACC-tv']: #['DM', 'DIR']: #'KDEy-MLE', 'KDE-DM', 'DM', 'DIR']:
+    for method in METHODS:

         print('Init method', method)
@@ -43,39 +34,7 @@ if __name__ == '__main__':
         dataset = 'T1B'
         train, val_gen, test_gen = qp.datasets.fetch_lequa2022(dataset)
         print(f'init {dataset} #instances: {len(train)}')

-        if method == 'KDEy-MLE':
-            method_params = {'bandwidth': np.linspace(0.01, 0.2, 20)}
-            param_grid = {**method_params, **hyper_LR}
-            quantifier = KDEy(LogisticRegression(), target='max_likelihood', val_split=10)
-        elif method in ['KDE-DM']:
-            method_params = {'bandwidth': np.linspace(0.01, 0.2, 20)}
-            param_grid = {**method_params, **hyper_LR}
-            quantifier = KDEy(LogisticRegression(), target='min_divergence', divergence='l2', montecarlo_trials=5000, val_split=10)
-        elif method == 'DIR':
-            param_grid = hyper_LR
-            quantifier = DIRy(LogisticRegression())
-        elif method == 'SLD':
-            param_grid = hyper_LR
-            quantifier = EMQ(LogisticRegression())
-        elif method == 'PACC-tv':
-            param_grid = hyper_LR
-            quantifier = PACC(LogisticRegression())
-        elif method == 'ACC-tv':
-            param_grid = hyper_LR
-            quantifier = ACC(LogisticRegression())
-        elif method == 'HDy-OvA':
-            param_grid = {'binary_quantifier__' + key: val for key, val in hyper_LR.items()}
-            quantifier = OneVsAllAggregative(HDy(LogisticRegression()))
-        elif method == 'DM':
-            method_params = {
-                'nbins': [4,8,16,32],
-                'val_split': [10, 0.4],
-                'divergence': ['HD', 'topsoe', 'l2']
-            }
-            param_grid = {**method_params, **hyper_LR}
-            quantifier = DistributionMatching(LogisticRegression())
-        else:
-            raise NotImplementedError('unknown method', method)
+        param_grid, quantifier = new_method(method)

         if param_grid is not None:
             modsel = GridSearchQ(quantifier, param_grid, protocol=val_gen, refit=False, n_jobs=-1, verbose=1, error=optim)
@@ -98,9 +57,4 @@ if __name__ == '__main__':
             csv.write(f'{method}\tLeQua-T1B\t{means["mae"]:.5f}\t{means["mrae"]:.5f}\t{means["kld"]:.5f}\n')
             csv.flush()

-    df = pd.read_csv(result_path+'.csv', sep='\t')
-    pd.set_option('display.max_columns', None)
-    pd.set_option('display.max_rows', None)
-    pv = df.pivot_table(index='Dataset', columns="Method", values=["MAE", "MRAE"])
-    print(pv)
+    show_results(result_path)


@@ -1,3 +1,4 @@
+from cgi import test
 import os
 import sys
 from typing import Union, Callable
@@ -34,6 +35,7 @@ class KDEy(AggregativeProbabilisticQuantifier):
             f'unknown bandwidth_method, valid ones are {KDEy.BANDWIDTH_METHOD}'
         assert engine in KDEy.ENGINE, f'unknown engine, valid ones are {KDEy.ENGINE}'
         assert target in KDEy.TARGET, f'unknown target, valid ones are {KDEy.TARGET}'
+        assert divergence=='KLD', 'in this version I will only allow KLD as a divergence'
         self.classifier = classifier
         self.val_split = val_split
         self.divergence = divergence
@@ -54,7 +56,7 @@ class KDEy(AggregativeProbabilisticQuantifier):
             print(f'auto: bandwidth={bandwidth:.5f}')
         return bandwidth

-    def get_kde(self, posteriors):
+    def get_kde_function(self, posteriors):
         # if self.bandwidth == 'auto':
         #     print('adjusting bandwidth')
         #
@@ -116,27 +118,26 @@ class KDEy(AggregativeProbabilisticQuantifier):
         self.classifier, y, posteriors, classes, class_count = cross_generate_predictions(
             data, self.classifier, val_split, probabilistic=True, fit_classifier=fit_classifier, n_jobs=self.n_jobs
         )
+        print('classifier fit done')

         if self.bandwidth == 'auto':
             self.bandwidth = self.search_bandwidth_maxlikelihood(posteriors, y)

-        self.val_densities = [self.get_kde(posteriors[y == cat]) for cat in range(data.n_classes)]
+        self.val_densities = [self.get_kde_function(posteriors[y == cat]) for cat in range(data.n_classes)]
         self.val_posteriors = posteriors

-        if self.target == 'min_divergence':
+        if self.target == 'min_divergence_depr':
             self.samples = qp.functional.uniform_prevalence_sampling(n_classes=data.n_classes, size=self.montecarlo_trials)
             self.sample_densities = [self.pdf(kde_i, self.samples) for kde_i in self.val_densities]

+        if self.target == 'min_divergence':
+            self.class_samples = [kde_i.sample(self.montecarlo_trials, random_state=self.random_state) for kde_i in self.val_densities]
+            self.class_sample_densities = {}
+            for ci, samples_i in enumerate(self.class_samples):
+                self.class_sample_densities[ci] = np.asarray([self.pdf(kde_j, samples_i) for kde_j in self.val_densities]).T
+
+        print('kde fit done')
         return self

-    #def val_pdf(self, prev):
-    """
-    Returns a function that computes the mixture model with the given prev as mixture factor
-    :param prev: a prevalence vector, ndarray
-    :return: a function implementing the validation distribution with fixed mixture factor
-    """
-    # return lambda posteriors: sum(prev_i * self.pdf(kde_i, posteriors) for kde_i, prev_i in zip(self.val_densities, prev))

     def aggregate(self, posteriors: np.ndarray):
         if self.target == 'min_divergence':
             return self._target_divergence(posteriors)
@@ -145,43 +146,63 @@ class KDEy(AggregativeProbabilisticQuantifier):
         else:
             raise ValueError('unknown target')

-    def _target_divergence_depr(self, posteriors):
-        # this variant is, I think, ill-formed, since it evaluates the likelihood on the test points, which are
-        # overconfident in the KDE-test.
-        test_density = self.get_kde(posteriors)
-        # val_test_posteriors = np.concatenate([self.val_posteriors, posteriors])
-        test_likelihood = self.pdf(test_density, posteriors)
-        divergence = _get_divergence(self.divergence)
-        n_classes = len(self.val_densities)
-
-        def match(prev):
-            val_pdf = self.val_pdf(prev)
-            val_likelihood = val_pdf(posteriors)
-            return divergence(val_likelihood, test_likelihood)
-
-        # the initial point is set as the uniform distribution
-        uniform_distribution = np.full(fill_value=1 / n_classes, shape=(n_classes,))
-
-        # solutions are bounded to those contained in the unit-simplex
-        bounds = tuple((0, 1) for _ in range(n_classes))  # values in [0,1]
-        constraints = ({'type': 'eq', 'fun': lambda x: 1 - sum(x)})  # values summing up to 1
-        r = optimize.minimize(match, x0=uniform_distribution, method='SLSQP', bounds=bounds, constraints=constraints)
-        return r.x
+    # this is the variant I have in the current results, which I think is bugged
+    # def _target_divergence_depr(self, posteriors):
+    #     # in this variant we evaluate the divergence using a Montecarlo approach
+    #     n_classes = len(self.val_densities)
+    #
+    #     test_kde = self.get_kde_function(posteriors)
+    #     test_likelihood = self.pdf(test_kde, self.samples)
+    #
+    #     divergence = _get_divergence(self.divergence)
+    #
+    #     def match(prev):
+    #         val_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip(prev, self.sample_densities))
+    #         return divergence(val_likelihood, test_likelihood)
+    #
+    #     # the initial point is set as the uniform distribution
+    #     uniform_distribution = np.full(fill_value=1 / n_classes, shape=(n_classes,))
+    #
+    #     # solutions are bounded to those contained in the unit-simplex
+    #     bounds = tuple((0, 1) for _ in range(n_classes))  # values in [0,1]
+    #     constraints = ({'type': 'eq', 'fun': lambda x: 1 - sum(x)})  # values summing up to 1
+    #     r = optimize.minimize(match, x0=uniform_distribution, method='SLSQP', bounds=bounds, constraints=constraints)
+    #     return r.x

     def _target_divergence(self, posteriors):
         # in this variant we evaluate the divergence using a Montecarlo approach
         n_classes = len(self.val_densities)

-        test_kde = self.get_kde(posteriors)
-        test_likelihood = self.pdf(test_kde, self.samples)
-        divergence = _get_divergence(self.divergence)
+        test_kde = self.get_kde_function(posteriors)
+        test_densities_per_class = [self.pdf(test_kde, samples_i) for samples_i in self.class_samples]
+        # divergence = _get_divergence(self.divergence)
+
+        def kld_monte(pi, qi, eps=1e-8):
+            # there is no pi in front of the log because the samples are already drawn according to pi
+            smooth_pi = pi + eps
+            smooth_qi = qi + eps
+            return np.mean(np.log(smooth_pi / smooth_qi))

         def match(prev):
-            val_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip(prev, self.sample_densities))
-            return divergence(val_likelihood, test_likelihood)
+            # pick the samples according to the prevalence vector; e.g., prev = [0.5, 0.3, 0.2] takes
+            # 50% of the points from KDE_0, 30% from KDE_1, and 20% from KDE_2. The points are
+            # pre-sampled and their densities pre-computed, so all that remains is to select a
+            # proportional number from each class (and the same for the test densities).
+            num_variates_per_class = np.round(prev * self.montecarlo_trials).astype(int)
+            sample_densities = np.vstack(
+                [self.class_sample_densities[ci][:num_i] for ci, num_i in enumerate(num_variates_per_class)]
+            )
+            val_likelihood = prev @ sample_densities.T
+            test_likelihood = np.concatenate(
+                [samples_i[:num_i] for samples_i, num_i in zip(test_densities_per_class, num_variates_per_class)]
+            )
+            return kld_monte(val_likelihood, test_likelihood)

         # the initial point is set as the uniform distribution
         uniform_distribution = np.full(fill_value=1 / n_classes, shape=(n_classes,))
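To make the repaired Monte Carlo approximation easier to follow outside the diff, here is a self-contained toy sketch (illustrative only; names such as class_kdes, mixture_densities, and divergence are made up and are not the attributes or methods of KDEy). For a candidate prevalence vector it takes a proportional number of pre-drawn points from each class-conditional KDE, evaluates the mixture density and the test-KDE density at those points, and averages log(p/q), exactly as kld_monte does above.

import numpy as np
from sklearn.neighbors import KernelDensity

rng = np.random.RandomState(0)

# toy 1D "posteriors": two classes with well-separated validation scores
val_scores = [rng.normal(0.3, 0.05, size=(500, 1)), rng.normal(0.7, 0.05, size=(500, 1))]
# a test set whose true class prevalence is (0.3, 0.7)
test_scores = np.vstack([rng.normal(0.3, 0.05, size=(300, 1)), rng.normal(0.7, 0.05, size=(700, 1))])

bandwidth, trials = 0.05, 2000
class_kdes = [KernelDensity(bandwidth=bandwidth).fit(X) for X in val_scores]
test_kde = KernelDensity(bandwidth=bandwidth).fit(test_scores)

# pre-sample from each class KDE and pre-compute all densities, mirroring what fit() stores
class_samples = [kde.sample(trials, random_state=0) for kde in class_kdes]
mixture_densities = [np.exp(np.vstack([k.score_samples(s) for k in class_kdes]).T) for s in class_samples]
test_densities = [np.exp(test_kde.score_samples(s)) for s in class_samples]

def kld_monte(p, q, eps=1e-8):
    # the points are drawn from p, so KL(p||q) is estimated by a plain average of log(p/q)
    return np.mean(np.log((p + eps) / (q + eps)))

def divergence(prev):
    n = np.round(prev * trials).astype(int)                                  # points taken from each class KDE
    p = np.vstack([d[:ni] for d, ni in zip(mixture_densities, n)]) @ prev    # mixture density under `prev`
    q = np.concatenate([t[:ni] for t, ni in zip(test_densities, n)])         # test-KDE density at the same points
    return kld_monte(p, q)

# the divergence should be smaller at the true test prevalence than at a wrong one
print(divergence(np.array([0.3, 0.7])), divergence(np.array([0.7, 0.3])))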


@@ -1,14 +1,9 @@
 import pickle
-import numpy as np
-from sklearn.linear_model import LogisticRegression
 import os
-import sys
 import pandas as pd
+from distribution_matching.commons import METHODS, new_method, show_results
 import quapy as qp
-from quapy.method.aggregative import EMQ, DistributionMatching, PACC, ACC, CC, PCC, HDy, OneVsAllAggregative
-from method_kdey import KDEy
-from method_dirichlety import DIRy
 from quapy.model_selection import GridSearchQ
 from quapy.protocol import UPP
@@ -21,16 +16,11 @@ if __name__ == '__main__':
     n_bags_val = 250
     n_bags_test = 1000
     optim = 'mae'
-    result_dir = f'results/results_tweet_{optim}_redohyper'
+    result_dir = f'results/tweet/{optim}'
     os.makedirs(result_dir, exist_ok=True)

-    hyper_LR = {
-        'classifier__C': np.logspace(-3,3,7),
-        'classifier__class_weight': ['balanced', None]
-    }
-
-    for method in ['CC', 'SLD', 'PCC', 'PACC-tv', 'ACC-tv', 'DM', 'HDy-OvA', 'KDEy-MLE', 'KDE-DM', 'DIR']:
+    for method in METHODS:

         print('Init method', method)
@@ -59,67 +49,7 @@ if __name__ == '__main__':
             if not is_semeval or not semeval_trained:

-                if method == 'KDE':  # not used
-                    method_params = {'bandwidth': np.linspace(0.01, 0.2, 20)}
-                    param_grid = {**method_params, **hyper_LR}
-                    quantifier = KDEy(LogisticRegression(), target='max_likelihood')
-                elif method == 'KDEy-MLE':
-                    method_params = {'bandwidth': np.linspace(0.01, 0.2, 20)}
-                    param_grid = {**method_params, **hyper_LR}
-                    quantifier = KDEy(LogisticRegression(), target='max_likelihood', val_split=10)
-                elif method in ['KDE-DM']:
-                    method_params = {'bandwidth': np.linspace(0.01, 0.2, 20)}
-                    param_grid = {**method_params, **hyper_LR}
-                    quantifier = KDEy(LogisticRegression(), target='min_divergence', divergence='l2', montecarlo_trials=5000, val_split=10)
-                elif method == 'DIR':
-                    param_grid = hyper_LR
-                    quantifier = DIRy(LogisticRegression())
-                elif method == 'SLD':
-                    param_grid = hyper_LR
-                    quantifier = EMQ(LogisticRegression())
-                elif method == 'PACC-tv':
-                    param_grid = hyper_LR
-                    quantifier = PACC(LogisticRegression())
-                #elif method == 'PACC-kfcv':
-                #    param_grid = hyper_LR
-                #    quantifier = PACC(LogisticRegression(), val_split=10)
-                elif method == 'PACC':
-                    method_params = {'val_split': [10, 0.4]}
-                    param_grid = {**method_params, **hyper_LR}
-                    quantifier = PACC(LogisticRegression())
-                elif method == 'ACC':
-                    method_params = {'val_split': [10, 0.4]}
-                    param_grid = {**method_params, **hyper_LR}
-                    quantifier = ACC(LogisticRegression())
-                elif method == 'PCC':
-                    param_grid = hyper_LR
-                    quantifier = PCC(LogisticRegression())
-                elif method == 'ACC-tv':
-                    param_grid = hyper_LR
-                    quantifier = ACC(LogisticRegression())
-                elif method == 'CC':
-                    param_grid = hyper_LR
-                    quantifier = CC(LogisticRegression())
-                elif method == 'HDy-OvA':
-                    param_grid = {'binary_quantifier__'+key: val for key, val in hyper_LR.items()}
-                    quantifier = OneVsAllAggregative(HDy(LogisticRegression()))
-                #elif method == 'DM':
-                #    param_grid = {
-                #        'nbins': [5,10,15],
-                #        'classifier__C': np.logspace(-4,4,9),
-                #        'classifier__class_weight': ['balanced', None]
-                #    }
-                #    quantifier = DistributionMatching(LogisticRegression())
-                elif method == 'DM':
-                    method_params = {
-                        'nbins': [4,8,16,32],
-                        'val_split': [10, 0.4],
-                        'divergence': ['HD', 'topsoe', 'l2']
-                    }
-                    param_grid = {**method_params, **hyper_LR}
-                    quantifier = DistributionMatching(LogisticRegression())
-                else:
-                    raise NotImplementedError('unknown method', method)
+                param_grid, quantifier = new_method(method)

                 # model selection
                 data = qp.datasets.fetch_twitter(dataset, min_df=3, pickle=True, for_model_selection=True)
@@ -151,9 +81,4 @@ if __name__ == '__main__':
             csv.write(f'{method}\t{data.name}\t{means["mae"]:.5f}\t{means["mrae"]:.5f}\t{means["kld"]:.5f}\n')
             csv.flush()

-    df = pd.read_csv(global_result_path+'.csv', sep='\t')
-    pd.set_option('display.max_columns', None)
-    pd.set_option('display.max_rows', None)
-    pv = df.pivot_table(index='Dataset', columns="Method", values=["MAE", "MRAE"])
-    print(pv)
+    show_results(global_result_path)


@@ -11,7 +11,8 @@ def prediction(
         model: BaseQuantifier,
         protocol: AbstractProtocol,
         aggr_speedup: Union[str, bool] = 'auto',
-        verbose=False):
+        verbose=False,
+        verbose_error=None):
     """
     Uses a quantification model to generate predictions for the samples generated via a specific protocol.
     This function is central to all evaluation processes, and is endowed with an optimization to speed-up the
@@ -32,6 +33,7 @@ def prediction(
         in the samples to be generated. Set to True or "auto" (default) for letting QuaPy decide whether it is
        convenient or not. Set to False to deactivate.
     :param verbose: boolean, show or not information in stdout
+    :param verbose_error: an evaluation function to be used to display intermediate results if verbose=True (default None)
     :return: a tuple `(true_prevs, estim_prevs)` in which each element in the tuple is an array of shape
         `(n_samples, n_classes)` containing the true, or predicted, prevalence values for each sample
     """
@@ -61,16 +63,21 @@ def prediction(
     if apply_optimization:
         pre_classified = model.classify(protocol.get_labelled_collection().instances)
         protocol_with_predictions = protocol.on_preclassified_instances(pre_classified)
-        return __prediction_helper(model.aggregate, protocol_with_predictions, verbose)
+        return __prediction_helper(model.aggregate, protocol_with_predictions, verbose, verbose_error)
     else:
-        return __prediction_helper(model.quantify, protocol, verbose)
+        return __prediction_helper(model.quantify, protocol, verbose, verbose_error)


-def __prediction_helper(quantification_fn, protocol: AbstractProtocol, verbose=False):
+def __prediction_helper(quantification_fn, protocol: AbstractProtocol, verbose=False, verbose_error=None):
     true_prevs, estim_prevs = [], []
-    for sample_instances, sample_prev in tqdm(protocol(), total=protocol.total(), desc='predicting') if verbose else protocol():
+    if verbose:
+        pbar = tqdm(protocol(), total=protocol.total(), desc='predicting')
+    for sample_instances, sample_prev in pbar if verbose else protocol():
         estim_prevs.append(quantification_fn(sample_instances))
         true_prevs.append(sample_prev)
+        if verbose and verbose_error is not None:
+            err = verbose_error(true_prevs, estim_prevs)
+            pbar.set_description(f'predicting: ongoing error={err:.5f}')
     true_prevs = np.asarray(true_prevs)
     estim_prevs = np.asarray(estim_prevs)
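A usage sketch of the new verbose_error argument (illustrative only, on a toy dataset): any callable with signature fn(true_prevs, estim_prevs) should work, qp.error.mae being a natural choice; the running error is shown in the tqdm progress bar.

import numpy as np
import quapy as qp
from quapy.data import LabelledCollection
from quapy.evaluation import prediction
from quapy.method.aggregative import PACC
from quapy.protocol import UPP
from sklearn.linear_model import LogisticRegression

# toy binary dataset, just to have something fittable
X = np.random.rand(1000, 5)
y = (X[:, 0] > 0.5).astype(int)
train, test = LabelledCollection(X, y).split_stratified()

qp.environ['SAMPLE_SIZE'] = 100
quantifier = PACC(LogisticRegression()).fit(train)

# the MAE over the samples processed so far is displayed while predicting
true_prevs, estim_prevs = prediction(
    quantifier, protocol=UPP(test, repeats=50), verbose=True, verbose_error=qp.error.mae)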


@@ -1091,7 +1091,7 @@ class T50(ThresholdOptimization):
     Threshold Optimization variant for :class:`ACC` as proposed by
     `Forman 2006 <https://dl.acm.org/doi/abs/10.1145/1150402.1150423>`_ and
     `Forman 2008 <https://link.springer.com/article/10.1007/s10618-008-0097-y>`_ that looks
-    for the threshold that makes `tpr` cosest to 0.5.
+    for the threshold that makes `tpr` closest to 0.5.
     The goal is to bring improved stability to the denominator of the adjustment.

     :param classifier: a sklearn's Estimator that generates a classifier
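For context (an illustrative aside, not part of the diff): the adjustment whose denominator these threshold variants try to stabilise is the usual ACC correction, shown below with made-up numbers.

# toy numbers: the prevalence adjustment used by ACC and the threshold variants (T50, MS, MS2)
cc_prev = 0.40                              # fraction of test items classified positive (classify & count)
tpr, fpr = 0.50, 0.10                       # validation estimates at the chosen threshold
adjusted = (cc_prev - fpr) / (tpr - fpr)    # = 0.75; T50 picks the threshold with tpr closest to 0.5
print(min(1.0, max(0.0, adjusted)))         # clip to [0, 1], as the new MS.aggregate below does with np.clip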
@@ -1179,7 +1179,7 @@ class MS(ThresholdOptimization):
         super().__init__(classifier, val_split)

     def _condition(self, tpr, fpr) -> float:
-        pass
+        return True

     def _optimize_threshold(self, y, probabilities):
         tprs = []
@@ -1190,9 +1190,26 @@ class MS(ThresholdOptimization):
             TP, FP, FN, TN = self._compute_table(y, y_)
             tpr = self._compute_tpr(TP, FP)
             fpr = self._compute_fpr(FP, TN)
+            if self._condition(tpr, fpr):
                 tprs.append(tpr)
                 fprs.append(fpr)
-        return np.median(tprs), np.median(fprs)
+        return tprs, fprs
+
+    def aggregate(self, classif_predictions):
+        prevs_estim = self.cc.aggregate(classif_predictions)
+        positive_prevs = []
+        for tpr, fpr in zip(self.tpr, self.fpr):
+            if tpr - fpr > 0:
+                acc = np.clip((prevs_estim[1] - fpr) / (tpr - fpr), 0, 1)
+                positive_prevs.append(acc)
+        if len(positive_prevs) > 0:
+            adjusted_positive_prev = np.median(positive_prevs)
+            adjusted_prevs_estim = np.array((1 - adjusted_positive_prev, adjusted_positive_prev))
+            return adjusted_prevs_estim
+        else:
+            return prevs_estim


 class MS2(MS):
@@ -1215,19 +1232,9 @@ class MS2(MS):
     def __init__(self, classifier: BaseEstimator, val_split=0.4):
         super().__init__(classifier, val_split)

-    def _optimize_threshold(self, y, probabilities):
-        tprs = [0, 1]
-        fprs = [0, 1]
-        candidate_thresholds = np.unique(probabilities[:, 1])
-        for candidate_threshold in candidate_thresholds:
-            y_ = [self.classes_[1] if p > candidate_threshold else self.classes_[0] for p in probabilities[:, 1]]
-            TP, FP, FN, TN = self._compute_table(y, y_)
-            tpr = self._compute_tpr(TP, FP)
-            fpr = self._compute_fpr(FP, TN)
-            if (tpr - fpr) > 0.25:
-                tprs.append(tpr)
-                fprs.append(fpr)
-        return np.median(tprs), np.median(fprs)
+    def _condition(self, tpr, fpr) -> float:
+        return (tpr - fpr) > 0.25


 ClassifyAndCount = CC