49 lines
1.8 KiB
Python
49 lines
1.8 KiB
Python
|
from sklearn.model_selection import GridSearchCV
|
||
|
import numpy as np
|
||
|
import quapy as qp
|
||
|
from sklearn.linear_model import LogisticRegression
|
||
|
|
||
|
# Size passed to the evaluation routine in gen_data(); the same value is
# registered in QuaPy's environment under the 'SAMPLE_SIZE' key.
sample_size = 500
qp.environ['SAMPLE_SIZE'] = sample_size
|
||
|
|
||
|
|
||
|
|
||
|
def gen_data():
    """Train every quantifier on the 'kindle' reviews and collect its predictions.

    For each quantification method listed below, a fresh logistic-regression
    classifier is wrapped by the method, fitted on the training split, and
    evaluated on the test split through
    ``qp.evaluation.artificial_sampling_prediction``.

    Returns:
        A 4-tuple of parallel lists, one entry per method, in the order the
        methods are listed:
        ``(method_names, true_prevs, estim_prevs, tr_prevs)`` where
        ``method_names`` holds each method's ``__name__``, ``true_prevs`` and
        ``estim_prevs`` hold the two values returned by the evaluation call,
        and ``tr_prevs`` holds the training-set prevalence.
    """
    dataset = qp.datasets.fetch_reviews('kindle', tfidf=True, min_df=5)

    quantifier_factories = (
        qp.method.aggregative.CC,
        qp.method.aggregative.ACC,
        qp.method.aggregative.PCC,
        qp.method.aggregative.PACC,
        qp.method.aggregative.HDy,
        qp.method.aggregative.EMQ,
        qp.method.meta.ECC,
        qp.method.meta.EACC,
        qp.method.meta.EHDy,
    )

    names = []
    true_per_method = []
    estim_per_method = []
    train_prev_per_method = []

    for make_quantifier in quantifier_factories:
        name = make_quantifier.__name__
        print(f'training {name}')

        # A new classifier per method so that no fitted state is shared.
        # Optional (disabled): wrap the classifier in GridSearchCV with
        # param_grid={'C': np.logspace(-3, 3, 7)} and n_jobs=-1 to tune C.
        classifier = LogisticRegression(max_iter=1000, class_weight='balanced')

        quantifier = make_quantifier(classifier)
        fitted = quantifier.fit(dataset.training)

        observed, predicted = qp.evaluation.artificial_sampling_prediction(
            fitted,
            dataset.test,
            sample_size,
            n_repetitions=20,
            n_prevpoints=11,
        )

        names.append(name)
        true_per_method.append(observed)
        estim_per_method.append(predicted)
        train_prev_per_method.append(dataset.training.prevalence())

    return names, true_per_method, estim_per_method, train_prev_per_method
|
||
|
|
||
|
# Obtain the evaluation results through qp.util.pickled_resource, which is
# given the pickle path and gen_data as the generator function.
method_names, true_prevs, estim_prevs, tr_prevs = qp.util.pickled_resource(
    './plots/plot_data.pkl',
    gen_data,
)

# Error-by-drift plot; the only plot that also takes the training prevalences.
qp.plot.error_by_drift(
    method_names,
    true_prevs,
    estim_prevs,
    tr_prevs,
    n_bins=11,
    savepath='./plots/err_drift.png',
)

# Binary plots: diagonal, global bias, and bias per bin.
qp.plot.binary_diagonal(
    method_names,
    true_prevs,
    estim_prevs,
    savepath='./plots/bin_diag.png',
)
qp.plot.binary_bias_global(
    method_names,
    true_prevs,
    estim_prevs,
    savepath='./plots/bin_bias.png',
)
qp.plot.binary_bias_bins(
    method_names,
    true_prevs,
    estim_prevs,
    nbins=11,
    savepath='./plots/bin_bias_bin.png',
)
|