Compare commits
10 Commits
master
...
transducti
Author | SHA1 | Date |
---|---|---|
Alejandro Moreo Fernandez | 4e016c7596 | |
Alejandro Moreo Fernandez | 4f1ac49030 | |
Alejandro Moreo Fernandez | e267719164 | |
Alejandro Moreo Fernandez | e6e8ed87fd | |
Alejandro Moreo Fernandez | adfa235cce | |
Alejandro Moreo Fernandez | 750b44aedb | |
Alejandro Moreo Fernandez | 24e755dcc1 | |
Alejandro Moreo Fernandez | fb2390e8d7 | |
Alejandro Moreo Fernandez | bfaa5678d7 | |
Alejandro Moreo Fernandez | cbe3f410ed |
|
@ -0,0 +1,2 @@
|
|||
En old stuff hay cosas interesantes, está bien escrita la motivación, aunque quiero rehacer esos métodos
|
||||
con una abstracción mejor hecha.
|
|
@ -0,0 +1,468 @@
|
|||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
from scipy.spatial.distance import cdist
|
||||
from sklearn import clone
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
|
||||
from quapy.data import LabelledCollection
|
||||
from quapy.method.aggregative import PACC, _training_helper, PCC
|
||||
from quapy.method.base import BaseQuantifier
|
||||
|
||||
from sklearn.preprocessing import normalize
|
||||
|
||||
# ideas: the observation proves that if you have a validation set from the target distribution, then it "repairs"
|
||||
# the predictions of the classifier. This might sound as a triviliaty, but note that the classifier is trained on
|
||||
# another distribution. So one could take a look at the test set (w/o labels) and extract a portion of the entire
|
||||
# labelled collection that matches the test set well, and keep the remainder as the training set on which to train
|
||||
# the classifier. (The version implemented so far follows a different heuristic, based on having a validation split
|
||||
# which is iid wrt the training set, and using this validation split to extract another validation split closer to the
|
||||
# test distribution.
|
||||
|
||||
# note: the T3 variant (the iterative one) admits two variants: (i) the estimated test prev is used to sample, via
|
||||
# artificial sampling, a sample from the validation that reflects the desired prevalence; (ii) the test prev is used
|
||||
# to compute the weights that compensate (i.e., rebalance) the relative importance of each of the current samples
|
||||
# wrt to the believed prevalence. Both are implemented, but the current one is the (ii), and (i) is commented
|
||||
|
||||
|
||||
class TransductivePACC(BaseQuantifier):
|
||||
"""
|
||||
PACC works by adjusting the PCC estimate applying a linear correction. This correction assumes P(X|Y) is fixed
|
||||
between the training and test distributions, meaning that the missclassification rates estimated in the training
|
||||
distribution (e.g., by means of a train/val split, or by means of k-FCV) is a good representative of the
|
||||
missclassification rates in the test. In situations in which the training and test distributions are shifted, and
|
||||
in which P(X|Y) cannot be assumed to remain constant (e.g., in contexts of covariate shift), this adjustment
|
||||
can be arbitrarily harmful. Transductive quantifiers decide the correction as a function of the test set.
|
||||
TransductivePACC in particular implements this intuition by picking a validation subset from the training set
|
||||
such that it is close to the test set. In this preliminary example, we simply rely on distances for choosing
|
||||
points close to every test point. The missclassification rates are estimated in this "transductive" validation
|
||||
split.
|
||||
|
||||
:param learner:
|
||||
:param how_many:
|
||||
:param metric:
|
||||
"""
|
||||
|
||||
def __init__(self, learner, how_many=1, metric='euclidean'):
|
||||
self.learner = learner
|
||||
self.how_many = how_many
|
||||
self.metric = metric
|
||||
|
||||
def quantify(self, instances):
|
||||
validation_index = self.get_closer_val_intances(instances, how_many=self.how_many, metric=self.metric)
|
||||
validation_selected = self.validation_pool.sampling_from_index(validation_index)
|
||||
pacc = PACC(self.learner, val_split=validation_selected)
|
||||
pacc.fit(None, fit_learner=False)
|
||||
self.to_show_val_selected = validation_selected # todo: remove
|
||||
return pacc.quantify(instances)
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_learner=True, val_split=Union[float,LabelledCollection]):
|
||||
if isinstance(val_split, float):
|
||||
self.training, self.validation_pool = data.split_stratified(1-val_split)
|
||||
elif isinstance(val_split, LabelledCollection):
|
||||
self.training = data
|
||||
self.validation_pool = val_split
|
||||
else:
|
||||
raise ValueError('val_split data type not understood')
|
||||
self.learner, _ = _training_helper(self.learner, self.training, fit_learner=True, ensure_probabilistic=True)
|
||||
return self
|
||||
|
||||
def get_closer_val_intances(self, T, how_many=1, metric='euclidean'):
|
||||
"""
|
||||
Takes "how_many" instances (indices) from X that are the closes to every instance in T
|
||||
:param T: test instances
|
||||
:param how_many: how many samples to choose for every test datapoint
|
||||
:param metric: similarity function (see `scipy.spatial.distance.cdist`)
|
||||
:return: ndarray with indices of validation_pool's datapoints
|
||||
"""
|
||||
dist = cdist(T, self.validation_pool.instances, metric=metric)
|
||||
indexes = np.argsort(dist, axis=1)[:, :how_many].flatten()
|
||||
return indexes
|
||||
|
||||
|
||||
class TransductiveInvdistancePACC(BaseQuantifier):
|
||||
"""
|
||||
This is a modification of TransductivePACC. The idea is that, instead of choosing the closest validation points,
|
||||
we could select all validation points but weighted inversely proportionally to the distance.
|
||||
The main objective here is to repair the performance of the t-quantifier in cases of PPS.
|
||||
|
||||
:param learner:
|
||||
:param how_many:
|
||||
:param metric:
|
||||
"""
|
||||
|
||||
def __init__(self, learner, metric='euclidean'):
|
||||
self.learner = learner
|
||||
self.metric = metric
|
||||
|
||||
def quantify(self, instances):
|
||||
validation_similarities = self.get_val_similarities(instances, metric=self.metric)
|
||||
validation_weight = validation_similarities.sum(axis=0)
|
||||
validation_posteriors = self.learner.predict_proba(self.validation_pool.instances)
|
||||
positive_posteriors = validation_posteriors[self.validation_pool.labels == 1][:,1]
|
||||
negative_posteriors = validation_posteriors[self.validation_pool.labels == 0][:,1]
|
||||
positive_weights = validation_weight[self.validation_pool.labels == 1]
|
||||
negative_weights = validation_weight[self.validation_pool.labels == 0]
|
||||
|
||||
soft_tpr = (positive_posteriors*positive_weights).sum()/(positive_weights.sum())
|
||||
soft_fpr = (negative_posteriors*negative_weights).sum()/(negative_weights.sum())
|
||||
|
||||
pcc = PCC(learner=self.learner).quantify(instances)
|
||||
adjusted = (pcc[1] - soft_fpr)/(soft_tpr-soft_fpr)
|
||||
adjusted = np.clip(adjusted, 0, 1)
|
||||
return np.asarray([1-adjusted,adjusted])
|
||||
|
||||
def set_params(self, **parameters):
|
||||
pass
|
||||
|
||||
def get_params(self, deep=True):
|
||||
pass
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_learner=True, val_split=Union[float,LabelledCollection]):
|
||||
if isinstance(val_split, float):
|
||||
self.training, self.validation_pool = data.split_stratified(1-val_split)
|
||||
elif isinstance(val_split, LabelledCollection):
|
||||
self.training = data
|
||||
self.validation_pool = val_split
|
||||
else:
|
||||
raise ValueError('val_split data type not understood')
|
||||
self.learner, _ = _training_helper(self.learner, self.training, fit_learner=True, ensure_probabilistic=True)
|
||||
return self
|
||||
|
||||
def get_val_similarities(self, T, metric='euclidean'):
|
||||
"""
|
||||
Takes "how_many" instances (indices) from X that are the closes to every instance in T
|
||||
:param T: test instances
|
||||
:param metric: similarity function (see `scipy.spatial.distance.cdist`)
|
||||
:return: ndarray with indices of validation_pool's datapoints
|
||||
"""
|
||||
# dist = cdist(T, self.validation_pool.instances, metric=metric)
|
||||
# norm_dist = (dist/np.max(dist))
|
||||
# sim = 1 - norm_dist # other variants: divide by the max distance for each test point, and not overall distance
|
||||
# norm_sim = normalize(sim**2, norm='l1') # <-- this kinds of helps
|
||||
# return norm_sim
|
||||
|
||||
dist = cdist(T, self.validation_pool.instances, metric=metric)
|
||||
# dist = dist**4 # <--
|
||||
norm_dist = (dist / np.max(dist))
|
||||
sim = 1 - norm_dist # other variants: divide by the max distance for each test point, and not overall distance
|
||||
norm_sim = normalize(sim**4, norm='l1') # <-- this kinds helps a lot and don't know why
|
||||
return norm_sim
|
||||
|
||||
# this doesn't work at all (dont know why)
|
||||
# cut_dist = np.median(dist)/3
|
||||
# dist[dist>cut_dist]=cut_dist
|
||||
# norm_dist = (dist / cut_dist)
|
||||
# sim = 1 - norm_dist # other variants: divide by the max distance for each test point, and not overall distance
|
||||
# norm_sim = normalize(sim, norm='l1')
|
||||
# return norm_sim
|
||||
|
||||
|
||||
class TransductiveInvdistanceIterativePACC(BaseQuantifier):
|
||||
"""
|
||||
This is a modification of TransductiveInvdistancePACC.
|
||||
The idea is that, to also consider in the weight the importance prev_test / prev_train (where prev_test has to be
|
||||
estimated by means of an auxiliary quantifier).
|
||||
|
||||
:param learner:
|
||||
:param metric:
|
||||
"""
|
||||
|
||||
def __init__(self, learner, metric='euclidean', oracle_test_prev=None):
|
||||
self.learner = learner
|
||||
self.metric = metric
|
||||
self.oracle_test_prev = oracle_test_prev
|
||||
|
||||
def quantify(self, instances):
|
||||
|
||||
if self.oracle_test_prev is None:
|
||||
proxy = TransductiveInvdistancePACC(learner=clone(self.learner)).fit(training, val_split=self.validation_pool)
|
||||
test_prev = proxy.quantify(instances)
|
||||
#print(f'\ttest_prev_estimated={F.strprev(test_prev)}')
|
||||
else:
|
||||
test_prev = self.oracle_test_prev
|
||||
|
||||
#size = len(self.validation_pool)
|
||||
#validation = self.validation_pool.sampling(size, *test_prev[:-1])
|
||||
validation = self.validation_pool
|
||||
|
||||
validation_similarities = self.get_val_similarities(instances, validation, metric=self.metric, test_prev_estim=test_prev)
|
||||
validation_weight = validation_similarities.sum(axis=0)
|
||||
validation_posteriors = self.learner.predict_proba(validation.instances)
|
||||
positive_posteriors = validation_posteriors[validation.labels == 1][:,1]
|
||||
negative_posteriors = validation_posteriors[validation.labels == 0][:,1]
|
||||
positive_weights = validation_weight[validation.labels == 1]
|
||||
negative_weights = validation_weight[validation.labels == 0]
|
||||
|
||||
soft_tpr = (positive_posteriors*positive_weights).sum()/(positive_weights.sum())
|
||||
soft_fpr = (negative_posteriors*negative_weights).sum()/(negative_weights.sum())
|
||||
|
||||
pcc = PCC(learner=self.learner).quantify(instances)
|
||||
adjusted = (pcc[1] - soft_fpr)/(soft_tpr-soft_fpr)
|
||||
adjusted = np.clip(adjusted, 0, 1)
|
||||
return np.asarray([1-adjusted, adjusted])
|
||||
|
||||
def set_params(self, **parameters):
|
||||
pass
|
||||
|
||||
def get_params(self, deep=True):
|
||||
pass
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_learner=True, val_split=Union[float,LabelledCollection]):
|
||||
if isinstance(val_split, float):
|
||||
self.training, self.validation_pool = data.split_stratified(1-val_split)
|
||||
elif isinstance(val_split, LabelledCollection):
|
||||
self.training = data
|
||||
self.validation_pool = val_split
|
||||
else:
|
||||
raise ValueError('val_split data type not understood')
|
||||
self.learner, _ = _training_helper(self.learner, self.training, fit_learner=True, ensure_probabilistic=True)
|
||||
return self
|
||||
|
||||
def get_val_similarities(self, T, validation, metric='euclidean', test_prev_estim=None):
|
||||
"""
|
||||
Takes "how_many" instances (indices) from X that are the closes to every instance in T
|
||||
:param T: test instances
|
||||
:param metric: similarity function (see `scipy.spatial.distance.cdist`)
|
||||
:return: ndarray with indices of validation_pool's datapoints
|
||||
"""
|
||||
|
||||
dist = cdist(T, validation.instances, metric=metric)
|
||||
# dist = dist**4 # <--
|
||||
norm_dist = (dist / np.max(dist))
|
||||
sim = 1 - norm_dist # other variants: divide by the max distance for each test point, and not overall distance
|
||||
norm_sim = normalize(sim ** 4, norm='l1') # <-- this kinds helps a lot and don't know why
|
||||
|
||||
if test_prev_estim is not None:
|
||||
pos_reweight = test_prev_estim[1] / validation.prevalence()[1]
|
||||
neg_reweight = test_prev_estim[0] / validation.prevalence()[0]
|
||||
|
||||
pos_reweight /= (pos_reweight + neg_reweight)
|
||||
neg_reweight /= (pos_reweight + neg_reweight)
|
||||
|
||||
rebalance_weight = np.zeros(len(validation))
|
||||
rebalance_weight[validation.labels == 1] = pos_reweight
|
||||
rebalance_weight[validation.labels == 0] = neg_reweight
|
||||
|
||||
rebalance_weight /= rebalance_weight.sum()
|
||||
|
||||
# norm_sim = normalize(sim, norm='l1')
|
||||
norm_sim *= rebalance_weight
|
||||
norm_sim = normalize(norm_sim**3, norm='l1')
|
||||
|
||||
return norm_sim
|
||||
|
||||
# norm_sim = normalize(sim, norm='l1') # <-- this kinds helps a lot and don't know why
|
||||
# norm_sim = normalize(norm_sim**2, norm='l1') # <-- this kinds helps a lot and don't know why
|
||||
#return norm_sim
|
||||
|
||||
|
||||
def plot_samples(val_orig:LabelledCollection, val_sel:LabelledCollection, test):
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
|
||||
font = {'family': 'normal',
|
||||
'weight': 'bold',
|
||||
'size': 10}
|
||||
matplotlib.rc('font', **font)
|
||||
size=0.5
|
||||
alpha=0.25
|
||||
|
||||
# plot 1:
|
||||
instances, labels = val_orig.Xy
|
||||
x1 = instances[:,0]
|
||||
x2 = instances[:,1]
|
||||
|
||||
# plt.ion()
|
||||
# plt.show()
|
||||
|
||||
plt.subplot(1, 3, 1)
|
||||
plt.scatter(x1[labels==0], x2[labels==0], s=size, alpha=alpha)
|
||||
plt.scatter(x1[labels==1], x2[labels==1], s=size, alpha=alpha)
|
||||
plt.title('Validation Pool')
|
||||
|
||||
# plot 2:
|
||||
instances, labels = val_sel.Xy
|
||||
x1 = instances[:, 0]
|
||||
x2 = instances[:, 1]
|
||||
|
||||
plt.subplot(1, 3, 2)
|
||||
plt.scatter(x1[labels == 0], x2[labels == 0], s=size, alpha=alpha)
|
||||
plt.scatter(x1[labels == 1], x2[labels == 1], s=size, alpha=alpha)
|
||||
plt.title('Validation Choosen')
|
||||
|
||||
# plot 3:
|
||||
instances, labels = test.Xy
|
||||
x1 = instances[:, 0]
|
||||
x2 = instances[:, 1]
|
||||
|
||||
plt.subplot(1, 3, 3)
|
||||
# plt.scatter(x1, x2, s=size, alpha=alpha)
|
||||
plt.scatter(x1[labels == 0], x2[labels == 0], s=size, alpha=alpha)
|
||||
plt.scatter(x1[labels == 1], x2[labels == 1], s=size, alpha=alpha)
|
||||
plt.title('Test')
|
||||
|
||||
# plt.draw()
|
||||
# plt.pause(0.001)
|
||||
plt.show()
|
||||
|
||||
|
||||
class Distribution:
|
||||
def sample(self, n): pass
|
||||
|
||||
|
||||
class ThreeGMDist(Distribution):
|
||||
"""
|
||||
Three Gaussian Mixture Distribution, with one negative normal, and two positive normals
|
||||
"""
|
||||
def __init__(self, mean_neg, cov_neg, mean_pos_A, cov_pos_A, mean_pos_B, cov_pos_B, prior_pos, prior_A):
|
||||
assert 0<=prior_pos<=1, 'pos_prior out of range'
|
||||
assert len(mean_neg) == len(mean_pos_A) == len(mean_pos_B), 'dimension missmatch'
|
||||
#todo check for cov dimensions
|
||||
self.mean_neg = mean_neg
|
||||
self.cov_neg = cov_neg
|
||||
self.mean_pos_A = mean_pos_A
|
||||
self.cov_pos_A = cov_pos_A
|
||||
self.mean_pos_B = mean_pos_B
|
||||
self.cov_pos_B = cov_pos_B
|
||||
self.prior_pos = prior_pos
|
||||
self.prior_A = prior_A
|
||||
|
||||
def sample(self, n):
|
||||
npos = int(n*self.prior_pos)
|
||||
nneg = n-npos
|
||||
nposA = int(npos*self.prior_A)
|
||||
nposB = npos-nposA
|
||||
neg = np.random.multivariate_normal(mean=self.mean_neg, cov=self.cov_neg, size=nneg)
|
||||
pos_A = np.random.multivariate_normal(mean=self.mean_pos_A, cov=self.cov_pos_A, size=nposA) # hard
|
||||
pos_B = np.random.multivariate_normal(mean=self.mean_pos_B, cov=self.cov_pos_B, size=nposB) # easy
|
||||
return LabelledCollection(
|
||||
instances=np.concatenate([neg, pos_A, pos_B]),
|
||||
labels=[0]*nneg + [1]*(nposA+nposB)
|
||||
)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import quapy as qp
|
||||
import quapy.functional as F
|
||||
print('proof of concept')
|
||||
|
||||
def test(q, testset, methodtag, show=False, scores=None):
|
||||
estim_prev = q.quantify(testset.instances)
|
||||
ae = qp.error.ae(testset.prevalence(), estim_prev)
|
||||
print(f'{methodtag}\tpredicts={F.strprev(estim_prev)} true={F.strprev(testset.prevalence())} with an AE of {ae:.4f}')
|
||||
if show:
|
||||
plot_samples(q.validation_pool, q.to_show_val_selected, testset)
|
||||
if scores is not None:
|
||||
scores.append(ae)
|
||||
return ae
|
||||
|
||||
def rand():
|
||||
return np.random.rand()
|
||||
|
||||
def cls():
|
||||
return LogisticRegression()
|
||||
|
||||
def scores():
|
||||
return {
|
||||
'i-PACC': [],
|
||||
'i-PCC': [],
|
||||
't-PACC': [],
|
||||
't2-PACC': [],
|
||||
't3-PACC': [],
|
||||
}
|
||||
|
||||
score_shift = {
|
||||
'pps': scores(),
|
||||
'cov': scores(),
|
||||
'covs': scores(),
|
||||
}
|
||||
|
||||
for i in range(1000):
|
||||
|
||||
mneg, covneg = [0, 0], [[1, 0], [0, 1]]
|
||||
mposA, covposA = [2, 0], [[1, 0], [0, 1]]
|
||||
mposB, covposB = [3, 3], [[1, 0], [0, 1]]
|
||||
source_dist = ThreeGMDist(mneg, covneg, mposA, covposA, mposB, covposB, prior_pos=0.5, prior_A=0.5)
|
||||
target_dist_pps = ThreeGMDist(mneg, covneg, mposA, covposA, mposB, covposB, prior_pos=rand(), prior_A=0.5)
|
||||
target_dist_covs = ThreeGMDist(mneg, covneg, mposA, covposA, mposB, covposB, prior_pos=0.5, prior_A=rand())
|
||||
target_dist_covs_pps = ThreeGMDist(mneg, covneg, mposA, covposA, mposB, covposB, prior_pos=rand(), prior_A=rand())
|
||||
|
||||
training = source_dist.sample(1000)
|
||||
validation_iid = source_dist.sample(1000)
|
||||
test_pps = target_dist_pps.sample(1000)
|
||||
val_pps = target_dist_pps.sample(1000)
|
||||
test_cov = target_dist_covs.sample(1000)
|
||||
val_cov = target_dist_covs.sample(1000)
|
||||
test_cov_pps = target_dist_covs_pps.sample(1000)
|
||||
val_cov_pps = target_dist_covs_pps.sample(1000)
|
||||
|
||||
#print('observacion:')
|
||||
#inductive_pacc = PACC(cls())
|
||||
#inductive_pacc.fit(training, val_split=val_cov)
|
||||
#test(inductive_pacc, test_cov, 'i-PACC (val covs) on covariate shift')
|
||||
#inductive_pacc.fit(training, val_split=val_cov_pps)
|
||||
#test(inductive_pacc, test_cov_pps, 'i-PACC (val val_cov_pps) on covariate & prior shift')
|
||||
|
||||
inductive_pacc = PACC(cls())
|
||||
inductive_pacc.fit(training, val_split=validation_iid)
|
||||
|
||||
inductive_pcc = PCC(cls())
|
||||
inductive_pcc.fit(training)
|
||||
|
||||
transductive_pacc = TransductivePACC(cls(), how_many=1)
|
||||
transductive_pacc.fit(training, val_split=validation_iid)
|
||||
|
||||
transductive_pacc2 = TransductiveInvdistancePACC(cls())
|
||||
transductive_pacc2.fit(training, val_split=validation_iid)
|
||||
|
||||
transductive_pacc3 = TransductiveInvdistanceIterativePACC(cls())
|
||||
transductive_pacc3.fit(training, val_split=validation_iid)
|
||||
|
||||
print('\nPrior Probability Shift')
|
||||
print('-'*80)
|
||||
test(inductive_pacc, test_pps, 'i-PACC', scores=score_shift['pps']['i-PACC'])
|
||||
test(inductive_pcc, test_pps, 'i-PCC', scores=score_shift['pps']['i-PCC'])
|
||||
test(transductive_pacc, test_pps, 't-PACC', show=False, scores=score_shift['pps']['t-PACC'])
|
||||
test(transductive_pacc2, test_pps, 't2-PACC', show=False, scores=score_shift['pps']['t2-PACC'])
|
||||
test(transductive_pacc3, test_pps, 't3-PACC', show=False, scores=score_shift['pps']['t3-PACC'])
|
||||
|
||||
print('\nCovariate Shift')
|
||||
print('-' * 80)
|
||||
test(inductive_pacc, test_cov, 'i-PACC', scores=score_shift['cov']['i-PACC'])
|
||||
test(inductive_pcc, test_cov, 'i-PCC', scores=score_shift['cov']['i-PCC'])
|
||||
test(transductive_pacc, test_cov, 't-PACC', show=False, scores=score_shift['cov']['t-PACC'])
|
||||
test(transductive_pacc2, test_cov, 't2-PACC', show=False, scores=score_shift['cov']['t2-PACC'])
|
||||
test(transductive_pacc3, test_cov, 't3-PACC', show=False, scores=score_shift['cov']['t3-PACC'])
|
||||
|
||||
print('\nCovariate Shift- TYPEII')
|
||||
print('-' * 80)
|
||||
test(inductive_pacc, test_cov_pps, 'i-PACC', scores=score_shift['covs']['i-PACC'])
|
||||
test(inductive_pcc, test_cov_pps, 'i-PCC', scores=score_shift['covs']['i-PCC'])
|
||||
test(transductive_pacc, test_cov_pps, 't-PACC', show=False, scores=score_shift['covs']['t-PACC'])
|
||||
test(transductive_pacc2, test_cov_pps, 't2-PACC', scores=score_shift['covs']['t2-PACC'])
|
||||
test(transductive_pacc3, test_cov_pps, 't3-PACC', scores=score_shift['covs']['t3-PACC'])
|
||||
|
||||
for shift in score_shift.keys():
|
||||
print(shift)
|
||||
for method in score_shift[shift]:
|
||||
print(f'\t{method}: {np.mean(score_shift[shift][method]):.4f}')
|
||||
|
||||
# print()
|
||||
# print('-'*80)
|
||||
# # proposed method
|
||||
#
|
||||
# transductive_pacc = TransductiveInvdistanceIterativePACC(cls(), oracle_test_prev=test_pps.prevalence())
|
||||
# transductive_pacc.fit(training, val_split=validation_iid)
|
||||
# test(transductive_pacc, test_pps, 't3(oracle)-PACC on prior probability shift', show=False)
|
||||
#
|
||||
# transductive_pacc = TransductiveInvdistanceIterativePACC(cls(), oracle_test_prev=test_cov.prevalence())
|
||||
# transductive_pacc.fit(training, val_split=validation_iid)
|
||||
# test(transductive_pacc, test_cov, 't3(oracle)-PACC on covariate shift', show=False)
|
||||
#
|
||||
# transductive_pacc = TransductiveInvdistanceIterativePACC(cls(), oracle_test_prev=test_cov_pps.prevalence())
|
||||
# transductive_pacc.fit(training, val_split=validation_iid)
|
||||
# test(transductive_pacc, test_cov_pps, 't3(oracle)-PACC on covariate & prior shift')
|
||||
|
|
@ -0,0 +1,427 @@
|
|||
import itertools
|
||||
from typing import Iterable
|
||||
|
||||
from densratio import densratio
|
||||
from scipy.sparse import issparse, vstack
|
||||
from scipy.stats import multivariate_normal
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.model_selection import GridSearchCV
|
||||
|
||||
import quapy as qp
|
||||
from Transduction_office.grid_naive_quantif import GridQuantifier, binned_indexer, Indexer, GridQuantifier2, \
|
||||
classifier_indexer
|
||||
from method.non_aggregative import MLPE
|
||||
from quapy.protocol import AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol, UPP
|
||||
from quapy.data import LabelledCollection
|
||||
from quapy.method.aggregative import *
|
||||
import quapy.functional as F
|
||||
from time import time
|
||||
from scipy.spatial.distance import cdist
|
||||
from Transduction.pykliep import DensityRatioEstimator
|
||||
from quapy.protocol import AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol
|
||||
from quapy.method.aggregative import *
|
||||
import quapy.functional as F
|
||||
|
||||
|
||||
plottting = False
|
||||
|
||||
|
||||
def gaussian(mean, cov=0.1, label=0, size=100, random_state=0):
|
||||
"""
|
||||
Creates a label collection in which the instances are distributed according to a Gaussian with specified
|
||||
parameters and labels all data points with a specific label.
|
||||
|
||||
:param mean: ndarray of shape (n_dimensions) with the center
|
||||
:param cov: ndarray of shape (n_dimensions, n_dimensions) with the covariance matrix, or a number for np.eye
|
||||
:param label: the class label for the collection
|
||||
:param size: number of instances
|
||||
:param random_state: allows for replicating experiments
|
||||
:return: an instance of LabelledCollection
|
||||
"""
|
||||
mean = np.asarray(mean)
|
||||
assert mean.ndim==1, 'wrong shape for mean'
|
||||
n_features = mean.shape[0]
|
||||
if isinstance(cov, (int, float)):
|
||||
cov = np.eye(n_features) * cov
|
||||
instances = multivariate_normal.rvs(mean, cov, size, random_state=random_state)
|
||||
return LabelledCollection(instances, labels=[label]*size)
|
||||
|
||||
|
||||
def _internal_plot(train, val, test):
|
||||
if plottting:
|
||||
xmin = min(train.X[:, 0].min(), val.X[:, 0].min(), test[:, 0].min())
|
||||
xmax = max(train.X[:, 0].max(), val.X[:, 0].max(), test[:, 0].max())
|
||||
ymin = min(train.X[:, 1].min(), val.X[:, 1].min(), test[:, 1].min())
|
||||
ymax = max(train.X[:, 1].max(), val.X[:, 1].max(), test[:, 1].max())
|
||||
plot(train, 'sel_train.png', xlim=(xmin, xmax), ylim=(ymin, ymax))
|
||||
plot(val, 'sel_val.png', xlim=(xmin, xmax), ylim=(ymin, ymax))
|
||||
plot(test, 'test.png', xlim=(xmin, xmax), ylim=(ymin, ymax))
|
||||
|
||||
def plot(data: LabelledCollection, path, xlim=None, ylim=None):
|
||||
import matplotlib.pyplot as plt
|
||||
plt.clf()
|
||||
if isinstance(data, LabelledCollection):
|
||||
if data.instances.shape[1] != 2:
|
||||
return
|
||||
|
||||
negative, positive = data.separate()
|
||||
plt.scatter(negative.X[:,0], negative.X[:,1], label='neg', alpha=0.5)
|
||||
plt.scatter(positive.X[:, 0], positive.X[:, 1], label='pos', alpha=0.5)
|
||||
else:
|
||||
if data.shape[1] != 2:
|
||||
return
|
||||
plt.scatter(data[:, 0], data[:, 1], label='test', alpha=0.5)
|
||||
if xlim is not None:
|
||||
plt.xlim(*xlim)
|
||||
plt.ylim(*ylim)
|
||||
plt.legend()
|
||||
plt.savefig(path)
|
||||
|
||||
# ------------------------------------------------------------------------------------
|
||||
# Protocol for generating prior probability shift + covariate shift by mixing "domains"
|
||||
# ------------------------------------------------------------------------------------
|
||||
class CovPriorShift(AbstractStochasticSeededProtocol):
|
||||
|
||||
def __init__(self, domains: Iterable[LabelledCollection], sample_size=None, repeats=100, min_support=0, random_state=0,
|
||||
return_type='sample_prev'):
|
||||
super(CovPriorShift, self).__init__(random_state)
|
||||
self.domains = list(itertools.chain.from_iterable(lc.separate() for lc in domains))
|
||||
self.sample_size = qp._get_sample_size(sample_size)
|
||||
self.repeats = repeats
|
||||
self.min_support = min_support
|
||||
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
|
||||
|
||||
def samples_parameters(self):
|
||||
"""
|
||||
Return all the necessary parameters to replicate the samples as according to the UPP protocol.
|
||||
|
||||
:return: a list of indexes that realize the UPP sampling
|
||||
"""
|
||||
indexes = []
|
||||
tentatives = 0
|
||||
while len(indexes) < self.repeats:
|
||||
alpha = F.uniform_simplex_sampling(n_classes=len(self.domains))
|
||||
sizes = (alpha * self.sample_size).astype(int)
|
||||
if all(sizes > self.min_support):
|
||||
indexes_i = [lc.sampling_index(size) for lc, size in zip(self.domains, sizes)]
|
||||
indexes.append(indexes_i)
|
||||
tentatives = 0
|
||||
else:
|
||||
tentatives += 1
|
||||
if tentatives > 100:
|
||||
raise ValueError('the support is too strict, and it is difficult '
|
||||
'or impossible to generate valid samples')
|
||||
return indexes
|
||||
|
||||
def sample(self, params):
|
||||
indexes = params
|
||||
lcs = [lc.sampling_from_index(index) for index, lc in zip(indexes, self.domains)]
|
||||
return LabelledCollection.join(*lcs)
|
||||
|
||||
def total(self):
|
||||
"""
|
||||
Returns the number of samples that will be generated
|
||||
|
||||
:return: int
|
||||
"""
|
||||
return self.repeats
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------------------
|
||||
# Methods of "importance weight", e.g., by ratio density estimation (KLIEP, SILF, LogReg)
|
||||
# ---------------------------------------------------------------------------------------
|
||||
class ImportanceWeight:
|
||||
@abstractmethod
|
||||
def weights(self, Xtr, ytr, Xte):
|
||||
pass
|
||||
|
||||
|
||||
class KLIEP(ImportanceWeight):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def weights(self, Xtr, ytr, Xte):
|
||||
kliep = DensityRatioEstimator()
|
||||
kliep.fit(Xtr, Xte)
|
||||
return kliep.predict(Xtr)
|
||||
|
||||
|
||||
class USILF(ImportanceWeight):
|
||||
|
||||
def __init__(self, alpha=0.):
|
||||
self.alpha = alpha
|
||||
|
||||
def weights(self, Xtr, ytr, Xte):
|
||||
dense_ratio_obj = densratio(Xtr, Xte, alpha=self.alpha, verbose=False)
|
||||
return dense_ratio_obj.compute_density_ratio(Xtr)
|
||||
|
||||
|
||||
class LogReg(ImportanceWeight):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def weights(self, Xtr, ytr, Xte):
|
||||
# check "Direct Density Ratio Estimation for
|
||||
# Large-scale Covariate Shift Adaptation", Eq.28
|
||||
|
||||
if issparse(Xtr):
|
||||
X = vstack([Xtr, Xte])
|
||||
else:
|
||||
X = np.concatenate([Xtr, Xte])
|
||||
|
||||
y = [0]*len(Xtr) + [1]*len(Xte)
|
||||
|
||||
logreg = GridSearchCV(
|
||||
LogisticRegression(),
|
||||
param_grid={'C':np.logspace(-3,3,7), 'class_weight': ['balanced', None]},
|
||||
n_jobs=-1
|
||||
)
|
||||
logreg.fit(X, y)
|
||||
prob_train = logreg.predict_proba(Xtr)[:,0]
|
||||
prob_test = logreg.predict_proba(Xtr)[:,1]
|
||||
prior_train = len(Xtr)
|
||||
prior_test = len(Xte)
|
||||
w = (prior_train/prior_test)*(prob_test/prob_train)
|
||||
return w
|
||||
|
||||
|
||||
class MostTest(ImportanceWeight):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def weights(self, Xtr, ytr, Xte):
|
||||
# check "Direct Density Ratio Estimation for
|
||||
# Large-scale Covariate Shift Adaptation", Eq.28
|
||||
|
||||
if issparse(Xtr):
|
||||
X = vstack([Xtr, Xte])
|
||||
else:
|
||||
X = np.concatenate([Xtr, Xte])
|
||||
|
||||
y = [0]*len(Xtr) + [1]*len(Xte)
|
||||
|
||||
logreg = GridSearchCV(
|
||||
LogisticRegression(),
|
||||
param_grid={'C':np.logspace(-3,3,7), 'class_weight': ['balanced', None]},
|
||||
n_jobs=-1
|
||||
)
|
||||
# logreg = LogisticRegression()
|
||||
# logreg.fit(X, y)
|
||||
# prob_test = logreg.predict_proba(Xtr)[:,1]
|
||||
prob_test = cross_val_predict(logreg, X, y, n_jobs=-1, method="predict_proba")[:len(Xtr),1]
|
||||
return prob_test
|
||||
|
||||
|
||||
class Random(ImportanceWeight):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def weights(self, Xtr, ytr, Xte):
|
||||
return np.random.rand(len(Xtr))
|
||||
|
||||
|
||||
class MostSimilarK(ImportanceWeight):
|
||||
# retains the training documents that are most similar in average to the k closest test points
|
||||
|
||||
def __init__(self, k):
|
||||
self.k = k
|
||||
|
||||
def weights(self, Xtr, ytr, Xte):
|
||||
distances = cdist(Xtr, Xte)
|
||||
min_dist = np.min(distances)
|
||||
max_dist = np.max(distances)
|
||||
distances = (distances-min_dist)/(max_dist-min_dist)
|
||||
similarities = 1 / (1+distances)
|
||||
top_k_sim = np.sort(similarities, axis=1)[:,-self.k:]
|
||||
ave_sim = np.mean(top_k_sim, axis=1)
|
||||
return ave_sim
|
||||
|
||||
class MostSimilarTest(ImportanceWeight):
|
||||
# retains the training documents that are the most similar to one test document
|
||||
# i.e., for each test point, selects the K most similar train instances
|
||||
|
||||
def __init__(self, k=1):
|
||||
self.k = k
|
||||
|
||||
def weights(self, Xtr, ytr, Xte):
|
||||
distances = cdist(Xtr, Xte)
|
||||
most_similar_idx = np.argsort(distances, axis=0)[:self.k, :].flatten()
|
||||
weights = np.zeros(shape=Xtr.shape[0])
|
||||
weights[most_similar_idx] = 1
|
||||
return weights
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Quantification Methods that rely on Importance Weight for reweighting the training instances
|
||||
# --------------------------------------------------------------------------------------------
|
||||
class TransductiveQuantifier(BaseQuantifier):
|
||||
|
||||
def fit(self, data: LabelledCollection):
|
||||
self.training_ = data
|
||||
return self
|
||||
|
||||
@property
|
||||
def training(self):
|
||||
return self.training_
|
||||
|
||||
|
||||
class ReweightingAggregative(TransductiveQuantifier):
|
||||
|
||||
def __init__(self, classifier, weighter: ImportanceWeight, quantif_method=CC):
|
||||
self.classifier = classifier
|
||||
self.weighter = weighter
|
||||
self.quantif_method = quantif_method
|
||||
|
||||
def quantify(self, instances):
|
||||
# time_weight = 2.95340 time_train = 0.00619
|
||||
w = self.weighter.weights(*self.training.Xy, instances)
|
||||
self.classifier.fit(*self.training.Xy, sample_weight=w)
|
||||
quantifier = self.quantif_method(self.classifier).fit(self.training, fit_classifier=False)
|
||||
return quantifier.quantify(instances)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Quantification Methods that rely on Importance Weight for selecting a validation partition
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
|
||||
class SelectorQuantifiersTrainVal(TransductiveQuantifier):
|
||||
|
||||
def __init__(self, classifier, weighter: ImportanceWeight, quantif_method=ACC, val_split=0.4, only_positives=False):
|
||||
self.classifier = classifier
|
||||
self.weighter = weighter
|
||||
self.quantif_method = quantif_method
|
||||
self.val_split = val_split
|
||||
self.only_positives = only_positives
|
||||
|
||||
def quantify(self, instances):
|
||||
w = self.weighter.weights(*self.training.Xy, instances)
|
||||
train, val = self.select_from_weights(w, self.training, self.val_split, self.only_positives)
|
||||
_internal_plot(train, val, instances)
|
||||
# print('\ttraining size', len(train), '\tval size', len(val))
|
||||
quantifier = self.quantif_method(self.classifier).fit(train, val_split=val)
|
||||
return quantifier.quantify(instances)
|
||||
|
||||
def select_from_weights(self, w, data: LabelledCollection, val_prop=0.4, only_positives=False):
|
||||
order = np.argsort(w)
|
||||
if only_positives:
|
||||
val_prop = np.mean(w > 0)
|
||||
split_point = int(len(w) * val_prop)
|
||||
different_idx, similar_idx = order[:-split_point], order[-split_point:]
|
||||
different, similar = data.sampling_from_index(different_idx), data.sampling_from_index(similar_idx)
|
||||
# return different, similar
|
||||
train, val = similar.split_stratified(0.6)
|
||||
return train, val
|
||||
|
||||
|
||||
class SelectorQuantifiersTrain(TransductiveQuantifier):
|
||||
|
||||
def __init__(self, classifier, weighter: ImportanceWeight, quantif_method=ACC, only_positives=False):
|
||||
self.classifier = classifier
|
||||
self.weighter = weighter
|
||||
self.quantif_method = quantif_method
|
||||
self.only_positives = only_positives
|
||||
|
||||
def quantify(self, instances):
|
||||
w = self.weighter.weights(*self.training.Xy, instances)
|
||||
train = self.select_from_weights(w, self.training, select_prop=None, only_positives=self.only_positives)
|
||||
# _internal_plot(train, None, instances)
|
||||
# print('\ttraining size', len(train))
|
||||
quantifier = self.quantif_method(self.classifier).fit(train)
|
||||
return quantifier.quantify(instances)
|
||||
|
||||
def select_from_weights(self, w, data: LabelledCollection, select_prop=0.5, only_positives=False):
|
||||
order = np.argsort(w)
|
||||
if only_positives:
|
||||
select_prop = np.mean(w > 0)
|
||||
split_point = int(len(w) * select_prop)
|
||||
different_idx, similar_idx = order[:-split_point], order[-split_point:]
|
||||
different, similar = data.sampling_from_index(different_idx), data.sampling_from_index(similar_idx)
|
||||
return similar
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
qp.environ['SAMPLE_SIZE'] = 500
|
||||
|
||||
dA_l0 = gaussian(mean=[0,0], label=0, size=5000)
|
||||
dA_l1 = gaussian(mean=[1,0], label=1, size=5000)
|
||||
dB_l0 = gaussian(mean=[0,1], label=0, size=5000)
|
||||
dB_l1 = gaussian(mean=[1,1], label=1, size=5000)
|
||||
|
||||
dA = LabelledCollection.join(dA_l0, dA_l1)
|
||||
dB = LabelledCollection.join(dB_l0, dB_l1)
|
||||
|
||||
dA_train, dA_test = dA.split_stratified(0.5, random_state=0)
|
||||
dB_train, dB_test = dB.split_stratified(0.5, random_state=0)
|
||||
|
||||
train = LabelledCollection.join(dA_train, dB_train)
|
||||
|
||||
plot(train, 'train.png')
|
||||
|
||||
def lr():
|
||||
return LogisticRegression()
|
||||
|
||||
|
||||
|
||||
# EMQ.MAX_ITER*=10
|
||||
# val_split = 0.5
|
||||
k_sim = 10
|
||||
Q=ACC
|
||||
methods = [
|
||||
('MLPE', MLPE()),
|
||||
('CC', CC(lr())),
|
||||
('PCC', PCC(lr())),
|
||||
('ACC', ACC(lr())),
|
||||
('PACC', PACC(lr())),
|
||||
('HDy', HDy(lr())),
|
||||
('EMQ', EMQ(lr())),
|
||||
('GridQ', GridQuantifier2(classifier=lr())),
|
||||
# ('GridQ', GridQuantifier(Indexer(binned_indexer(train.X, nbins_by_dim=2)), cell_quantifier=Q(lr()))),
|
||||
# ('GridQ', GridQuantifier(Indexer(binned_indexer(train.X, nbins_by_dim=4)), cell_quantifier=Q(lr()))),
|
||||
# ('GridQ', GridQuantifier(Indexer(binned_indexer(train.X, nbins_by_dim=6)), cell_quantifier=Q(lr()))),
|
||||
# ('GridQ', GridQuantifier(Indexer(binned_indexer(train.X, nbins_by_dim=8)), cell_quantifier=Q(lr()))),
|
||||
# ('GridQ', GridQuantifier(Indexer(binned_indexer(train.X, nbins_by_dim=10)), cell_quantifier=Q(lr()))),
|
||||
# ('GridQ', GridQuantifier(Indexer(binned_indexer(train.X, nbins_by_dim=20)), cell_quantifier=Q(lr()))),
|
||||
# ('kSim-ACC', SelectorQuantifiers(lr(), MostSimilar(k_sim), ACC, val_split=val_split)),
|
||||
# ('kSim-PACC', SelectorQuantifiers(lr(), MostSimilar(k_sim), PACC, val_split=val_split)),
|
||||
# ('kSim-HDy', SelectorQuantifiers(lr(), MostSimilar(k_sim), HDy, val_split=val_split)),
|
||||
# ('Sel-CC', SelectorQuantifiersTrain(lr(), MostSimilarTest(k=k_sim), CC, only_positives=True)),
|
||||
# ('Sel-PCC', SelectorQuantifiersTrain(lr(), MostSimilarTest(k=k_sim), PCC, only_positives=True)),
|
||||
# ('Sel-ACC', SelectorQuantifiersTrainVal(lr(), MostSimilarTest(k=k_sim), ACC, only_positives=True)),
|
||||
# ('Sel-PACC', SelectorQuantifiersTrainVal(lr(), MostSimilarTest(k=k_sim), PACC, only_positives=True)),
|
||||
# ('Sel-HDy', SelectorQuantifiersTrainVal(lr(), MostSimilarTest(k=k_sim), HDy, only_positives=True)),
|
||||
# ('Sel-EMQ', SelectorQuantifiersTrain(lr(), MostSimilarTest(k=k_sim), EMQ, only_positives=True)),
|
||||
# ('Sel-EMQ', SelectorQuantifiersTrainVal(lr(), USILF(), PACC, only_positives=False)),
|
||||
# ('Sel-PACC', SelectorQuantifiers(lr(), MostTest(), PACC)),
|
||||
# ('Sel-HDy', SelectorQuantifiers(lr(), MostTest(), HDy)),
|
||||
# ('LogReg-CC', ReweightingAggregative(lr(), LogReg(), CC)),
|
||||
# ('LogReg-PCC', ReweightingAggregative(lr(), LogReg(), PCC)),
|
||||
# ('LogReg-EMQ', ReweightingAggregative(lr(), LogReg(), EMQ)),
|
||||
# ('KLIEP-CC', ReweightingAggregative(lr(), KLIEP(), CC)),
|
||||
# ('KLIEP-PCC', ReweightingAggregative(lr(), KLIEP(), PCC)),
|
||||
# ('KLIEP-EMQ', ReweightingAggregative(lr(), KLIEP(), EMQ)),
|
||||
# ('SILF-CC', ReweightingAggregative(lr(), USILF(), CC)),
|
||||
# ('SILF-PCC', ReweightingAggregative(lr(), USILF(), PCC)),
|
||||
# ('SILF-EMQ', ReweightingAggregative(lr(), USILF(), EMQ))
|
||||
]
|
||||
|
||||
for name, model in methods:
|
||||
with qp.util.temp_seed(5):
|
||||
# print('original training size', len(train))
|
||||
model.fit(train)
|
||||
|
||||
prot = CovPriorShift([dA_test, dB_test], repeats=1 if plottting else 150)
|
||||
# prot = UPP(dA_test+dB_test, repeats=1 if plottting else 150)
|
||||
mae = qp.evaluation.evaluate(model, protocol=prot, error_metric='mae')
|
||||
print(f'{name}: {mae = :.4f}')
|
||||
# mrae = qp.evaluation.evaluate(model, protocol=prot, error_metric='mrae')
|
||||
# print(f'{name}: {mrae = :.4f}')
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,188 @@
|
|||
import numpy as np
|
||||
import warnings
|
||||
|
||||
|
||||
class DensityRatioEstimator:
|
||||
"""
|
||||
Class to accomplish direct density estimation implementing the original KLIEP
|
||||
algorithm from Direct Importance Estimation with Model Selection
|
||||
and Its Application to Covariate Shift Adaptation by Sugiyama et al.
|
||||
|
||||
The training set is distributed via
|
||||
train ~ p(x)
|
||||
and the test set is distributed via
|
||||
test ~ q(x).
|
||||
|
||||
The KLIEP algorithm and its variants approximate w(x) = q(x) / p(x) directly. The predict function returns the
|
||||
estimate of w(x). The function w(x) can serve as sample weights for the training set during
|
||||
training to modify the expectation function that the model's loss function is optimized via,
|
||||
i.e.
|
||||
|
||||
E_{x ~ w(x)p(x)} loss(x) = E_{x ~ q(x)} loss(x).
|
||||
|
||||
Usage :
|
||||
The fit method is used to run the KLIEP algorithm using LCV and returns value of J
|
||||
trained on the entire training/test set with the best sigma found.
|
||||
Use the predict method on the training set to determine the sample weights from the KLIEP algorithm.
|
||||
"""
|
||||
|
||||
def __init__(self, max_iter=5000, num_params=[.1, .2], epsilon=1e-4, cv=3, sigmas=[.01, .1, .25, .5, .75, 1],
|
||||
random_state=None, verbose=0):
|
||||
"""
|
||||
Direct density estimation using an inner LCV loop to estimate the proper model. Can be used with sklearn
|
||||
cross validation methods with or without storing the inner CV. To use a standard grid search.
|
||||
|
||||
|
||||
max_iter : Number of iterations to perform
|
||||
num_params : List of number of test set vectors used to construct the approximation for inner LCV.
|
||||
Must be a float. Original paper used 10%, i.e. =.1
|
||||
sigmas : List of sigmas to be used in inner LCV loop.
|
||||
epsilon : Additive factor in the iterative algorithm for numerical stability.
|
||||
"""
|
||||
self.max_iter = max_iter
|
||||
self.num_params = num_params
|
||||
self.epsilon = epsilon
|
||||
self.verbose = verbose
|
||||
self.sigmas = sigmas
|
||||
self.cv = cv
|
||||
self.random_state = 0
|
||||
|
||||
def fit(self, X_train, X_test, alpha_0=None):
|
||||
""" Uses cross validation to select sigma as in the original paper (LCV).
|
||||
In a break from sklearn convention, y=X_test.
|
||||
The parameter cv corresponds to R in the original paper.
|
||||
Once found, the best sigma is used to train on the full set."""
|
||||
|
||||
# LCV loop, shuffle a copy in place for performance.
|
||||
cv = self.cv
|
||||
chunk = int(X_test.shape[0] / float(cv))
|
||||
if self.random_state is not None:
|
||||
np.random.seed(self.random_state)
|
||||
X_test_shuffled = X_test.copy()
|
||||
np.random.shuffle(X_test_shuffled)
|
||||
|
||||
j_scores = {}
|
||||
|
||||
if type(self.sigmas) != list:
|
||||
self.sigmas = [self.sigmas]
|
||||
|
||||
if type(self.num_params) != list:
|
||||
self.num_params = [self.num_params]
|
||||
|
||||
if len(self.sigmas) * len(self.num_params) > 1:
|
||||
# Inner LCV loop
|
||||
for num_param in self.num_params:
|
||||
for sigma in self.sigmas:
|
||||
j_scores[(num_param, sigma)] = np.zeros(cv)
|
||||
for k in range(1, cv + 1):
|
||||
if self.verbose > 0:
|
||||
print('Training: sigma: %s R: %s' % (sigma, k))
|
||||
X_test_fold = X_test_shuffled[(k - 1) * chunk:k * chunk, :]
|
||||
j_scores[(num_param, sigma)][k - 1] = self._fit(X_train=X_train,
|
||||
X_test=X_test_fold,
|
||||
num_parameters=num_param,
|
||||
sigma=sigma)
|
||||
j_scores[(num_param, sigma)] = np.mean(j_scores[(num_param, sigma)])
|
||||
|
||||
sorted_scores = sorted([x for x in j_scores.items() if np.isfinite(x[1])], key=lambda x: x[1],
|
||||
reverse=True)
|
||||
if len(sorted_scores) == 0:
|
||||
warnings.warn('LCV failed to converge for all values of sigma.')
|
||||
return self
|
||||
self._sigma = sorted_scores[0][0][1]
|
||||
self._num_parameters = sorted_scores[0][0][0]
|
||||
self._j_scores = sorted_scores
|
||||
else:
|
||||
self._sigma = self.sigmas[0]
|
||||
self._num_parameters = self.num_params[0]
|
||||
# best sigma
|
||||
self._j = self._fit(X_train=X_train, X_test=X_test_shuffled, num_parameters=self._num_parameters,
|
||||
sigma=self._sigma)
|
||||
|
||||
return self # Compatibility with sklearn
|
||||
|
||||
def _fit(self, X_train, X_test, num_parameters, sigma, alpha_0=None):
|
||||
""" Fits the estimator with the given parameters w-hat and returns J"""
|
||||
|
||||
num_parameters = num_parameters
|
||||
|
||||
if type(num_parameters) == float:
|
||||
num_parameters = int(X_test.shape[0] * num_parameters)
|
||||
|
||||
self._select_param_vectors(X_test=X_test,
|
||||
sigma=sigma,
|
||||
num_parameters=num_parameters)
|
||||
|
||||
X_train = self._reshape_X(X_train)
|
||||
X_test = self._reshape_X(X_test)
|
||||
|
||||
if alpha_0 is None:
|
||||
alpha_0 = np.ones(shape=(num_parameters, 1)) / float(num_parameters)
|
||||
|
||||
self._find_alpha(X_train=X_train,
|
||||
X_test=X_test,
|
||||
num_parameters=num_parameters,
|
||||
epsilon=self.epsilon,
|
||||
alpha_0=alpha_0,
|
||||
sigma=sigma)
|
||||
|
||||
return self._calculate_j(X_test, sigma=sigma)
|
||||
|
||||
def _calculate_j(self, X_test, sigma):
|
||||
pred = self.predict(X_test, sigma=sigma)+0.0000001
|
||||
log = np.log(pred).sum()
|
||||
return log / (X_test.shape[0])
|
||||
|
||||
def score(self, X_test):
|
||||
""" Return the J score, similar to sklearn's API """
|
||||
return self._calculate_j(X_test=X_test, sigma=self._sigma)
|
||||
|
||||
@staticmethod
|
||||
def _reshape_X(X):
|
||||
""" Reshape input from mxn to mx1xn to take advantage of numpy broadcasting. """
|
||||
if len(X.shape) != 3:
|
||||
return X.reshape((X.shape[0], 1, X.shape[1]))
|
||||
return X
|
||||
|
||||
def _select_param_vectors(self, X_test, sigma, num_parameters):
|
||||
""" X_test is the test set. b is the number of parameters. """
|
||||
indices = np.random.choice(X_test.shape[0], size=num_parameters, replace=False)
|
||||
self._test_vectors = X_test[indices, :].copy()
|
||||
self._phi_fitted = True
|
||||
|
||||
def _phi(self, X, sigma=None):
|
||||
|
||||
if sigma is None:
|
||||
sigma = self._sigma
|
||||
|
||||
if self._phi_fitted:
|
||||
return np.exp(-np.sum((X - self._test_vectors) ** 2, axis=-1) / (2 * sigma ** 2))
|
||||
raise Exception('Phi not fitted.')
|
||||
|
||||
def _find_alpha(self, alpha_0, X_train, X_test, num_parameters, sigma, epsilon):
|
||||
A = np.zeros(shape=(X_test.shape[0], num_parameters))
|
||||
b = np.zeros(shape=(num_parameters, 1))
|
||||
|
||||
A = self._phi(X_test, sigma)
|
||||
b = self._phi(X_train, sigma).sum(axis=0) / X_train.shape[0]
|
||||
b = b.reshape((num_parameters, 1))
|
||||
|
||||
out = alpha_0.copy()
|
||||
for k in range(self.max_iter):
|
||||
mat = np.dot(A, out)
|
||||
mat += 0.000000001
|
||||
out += epsilon * np.dot(np.transpose(A), 1. / mat)
|
||||
out += b * (((1 - np.dot(np.transpose(b), out)) / np.dot(np.transpose(b), b)))
|
||||
out = np.maximum(0, out)
|
||||
out /= (np.dot(np.transpose(b), out))
|
||||
|
||||
self._alpha = out
|
||||
self._fitted = True
|
||||
|
||||
def predict(self, X, sigma=None):
|
||||
""" Equivalent of w(X) from the original paper."""
|
||||
|
||||
X = self._reshape_X(X)
|
||||
if not self._fitted:
|
||||
raise Exception('Not fitted!')
|
||||
return np.dot(self._phi(X, sigma=sigma), self._alpha).reshape((X.shape[0],))
|
|
@ -0,0 +1,152 @@
|
|||
from copy import deepcopy
|
||||
|
||||
import quapy as qp
|
||||
from sklearn.calibration import CalibratedClassifierCV
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from classification.methods import LowRankLogisticRegression
|
||||
from quapy.method.meta import QuaNet
|
||||
from quapy.protocol import APP
|
||||
from quapy.method.aggregative import CC, ACC, PCC, PACC, MAX, MS, MS2, EMQ, HDy, newSVMAE
|
||||
from quapy.method.meta import EHDy
|
||||
import numpy as np
|
||||
import os
|
||||
import pickle
|
||||
import itertools
|
||||
import argparse
|
||||
import torch
|
||||
import shutil
|
||||
|
||||
|
||||
N_JOBS = -1
|
||||
CUDA_N_JOBS = 2
|
||||
ENSEMBLE_N_JOBS = -1
|
||||
|
||||
qp.environ['SAMPLE_SIZE'] = 100
|
||||
|
||||
|
||||
def newLR():
|
||||
return LogisticRegression(max_iter=1000, solver='lbfgs', n_jobs=-1)
|
||||
|
||||
|
||||
def calibratedLR():
|
||||
return CalibratedClassifierCV(LogisticRegression(max_iter=1000, solver='lbfgs', n_jobs=-1))
|
||||
|
||||
|
||||
__C_range = np.logspace(-3, 3, 7)
|
||||
lr_params = {'classifier__C': __C_range, 'classifier__class_weight': [None, 'balanced']}
|
||||
svmperf_params = {'classifier__C': __C_range}
|
||||
|
||||
|
||||
def quantification_models():
|
||||
yield 'cc', CC(newLR()), lr_params
|
||||
yield 'acc', ACC(newLR()), lr_params
|
||||
yield 'pcc', PCC(newLR()), lr_params
|
||||
yield 'pacc', PACC(newLR()), lr_params
|
||||
yield 'MAX', MAX(newLR()), lr_params
|
||||
yield 'MS', MS(newLR()), lr_params
|
||||
yield 'MS2', MS2(newLR()), lr_params
|
||||
yield 'sldc', EMQ(newLR(), recalib='platt'), lr_params
|
||||
yield 'svmmae', newSVMAE(), svmperf_params
|
||||
yield 'hdy', HDy(newLR()), lr_params
|
||||
|
||||
|
||||
def quantification_cuda_models():
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
print(f'Running QuaNet in {device}')
|
||||
learner = LowRankLogisticRegression()
|
||||
yield 'quanet', QuaNet(learner, checkpointdir=args.checkpointdir, device=device), lr_params
|
||||
|
||||
|
||||
def evaluate_experiment(true_prevalences, estim_prevalences):
|
||||
print('\nEvaluation Metrics:\n' + '=' * 22)
|
||||
for eval_measure in [qp.error.mae, qp.error.mrae]:
|
||||
err = eval_measure(true_prevalences, estim_prevalences)
|
||||
print(f'\t{eval_measure.__name__}={err:.4f}')
|
||||
print()
|
||||
|
||||
|
||||
def result_path(path, dataset_name, model_name, run, optim_loss):
|
||||
return os.path.join(path, f'{dataset_name}-{model_name}-run{run}-{optim_loss}.pkl')
|
||||
|
||||
|
||||
def is_already_computed(dataset_name, model_name, run, optim_loss):
|
||||
return os.path.exists(result_path(args.results, dataset_name, model_name, run, optim_loss))
|
||||
|
||||
|
||||
def save_results(dataset_name, model_name, run, optim_loss, *results):
|
||||
rpath = result_path(args.results, dataset_name, model_name, run, optim_loss)
|
||||
qp.util.create_parent_dir(rpath)
|
||||
with open(rpath, 'wb') as foo:
|
||||
pickle.dump(tuple(results), foo, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
|
||||
def run(experiment):
|
||||
optim_loss, dataset_name, (model_name, model, hyperparams) = experiment
|
||||
if dataset_name in ['acute.a', 'acute.b', 'iris.1']: return
|
||||
|
||||
collection = qp.datasets.fetch_UCILabelledCollection(dataset_name)
|
||||
for run, data in enumerate(qp.data.Dataset.kFCV(collection, nfolds=5, nrepeats=1)):
|
||||
if is_already_computed(dataset_name, model_name, run=run, optim_loss=optim_loss):
|
||||
print(f'result for dataset={dataset_name} model={model_name} loss={optim_loss} run={run+1}/5 already computed.')
|
||||
continue
|
||||
|
||||
print(f'running dataset={dataset_name} model={model_name} loss={optim_loss} run={run+1}/5')
|
||||
# model selection (hyperparameter optimization for a quantification-oriented loss)
|
||||
train, test = data.train_test
|
||||
train, val = train.split_stratified()
|
||||
if hyperparams is not None:
|
||||
model_selection = qp.model_selection.GridSearchQ(
|
||||
deepcopy(model),
|
||||
param_grid=hyperparams,
|
||||
protocol=APP(val, n_prevalences=21, repeats=25),
|
||||
error=optim_loss,
|
||||
refit=True,
|
||||
timeout=60*60,
|
||||
verbose=True
|
||||
)
|
||||
model_selection.fit(data.training)
|
||||
model = model_selection.best_model()
|
||||
best_params = model_selection.best_params_
|
||||
else:
|
||||
model.fit(data.training)
|
||||
best_params = {}
|
||||
|
||||
# model evaluation
|
||||
true_prevalences, estim_prevalences = qp.evaluation.prediction(
|
||||
model,
|
||||
protocol=APP(test, n_prevalences=21, repeats=100)
|
||||
)
|
||||
test_true_prevalence = data.test.prevalence()
|
||||
|
||||
evaluate_experiment(true_prevalences, estim_prevalences)
|
||||
save_results(dataset_name, model_name, run, optim_loss,
|
||||
true_prevalences, estim_prevalences,
|
||||
data.training.prevalence(), test_true_prevalence,
|
||||
best_params)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Run experiments for Tweeter Sentiment Quantification')
|
||||
parser.add_argument('results', metavar='RESULT_PATH', type=str,
|
||||
help='path to the directory where to store the results')
|
||||
parser.add_argument('--svmperfpath', metavar='SVMPERF_PATH', type=str, default='../svm_perf_quantification',
|
||||
help='path to the directory with svmperf')
|
||||
parser.add_argument('--checkpointdir', metavar='PATH', type=str, default='./checkpoint',
|
||||
help='path to the directory where to dump QuaNet checkpoints')
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f'Result folder: {args.results}')
|
||||
np.random.seed(0)
|
||||
|
||||
qp.environ['SVMPERF_HOME'] = args.svmperfpath
|
||||
|
||||
optim_losses = ['mae']
|
||||
datasets = qp.datasets.UCI_DATASETS[:4]
|
||||
|
||||
models = quantification_models()
|
||||
qp.util.parallel(run, itertools.product(optim_losses, datasets, models), n_jobs=N_JOBS)
|
||||
|
||||
models = quantification_cuda_models()
|
||||
qp.util.parallel(run, itertools.product(optim_losses, datasets, models), n_jobs=CUDA_N_JOBS)
|
||||
|
||||
shutil.rmtree(args.checkpointdir, ignore_errors=True)
|
|
@ -19,7 +19,7 @@ class LowRankLogisticRegression(BaseEstimator):
|
|||
|
||||
def __init__(self, n_components=100, **kwargs):
|
||||
self.n_components = n_components
|
||||
self.learner = LogisticRegression(**kwargs)
|
||||
self.classifier = LogisticRegression(**kwargs)
|
||||
|
||||
def get_params(self):
|
||||
"""
|
||||
|
@ -28,7 +28,7 @@ class LowRankLogisticRegression(BaseEstimator):
|
|||
:return: a dictionary with parameter names mapped to their values
|
||||
"""
|
||||
params = {'n_components': self.n_components}
|
||||
params.update(self.learner.get_params())
|
||||
params.update(self.classifier.get_params())
|
||||
return params
|
||||
|
||||
def set_params(self, **params):
|
||||
|
@ -43,7 +43,7 @@ class LowRankLogisticRegression(BaseEstimator):
|
|||
if 'n_components' in params_:
|
||||
self.n_components = params_['n_components']
|
||||
del params_['n_components']
|
||||
self.learner.set_params(**params_)
|
||||
self.classifier.set_params(**params_)
|
||||
|
||||
def fit(self, X, y):
|
||||
"""
|
||||
|
@ -59,8 +59,8 @@ class LowRankLogisticRegression(BaseEstimator):
|
|||
if nF > self.n_components:
|
||||
self.pca = TruncatedSVD(self.n_components).fit(X)
|
||||
X = self.transform(X)
|
||||
self.learner.fit(X, y)
|
||||
self.classes_ = self.learner.classes_
|
||||
self.classifier.fit(X, y)
|
||||
self.classes_ = self.classifier.classes_
|
||||
return self
|
||||
|
||||
def predict(self, X):
|
||||
|
@ -72,7 +72,7 @@ class LowRankLogisticRegression(BaseEstimator):
|
|||
instances in `X`
|
||||
"""
|
||||
X = self.transform(X)
|
||||
return self.learner.predict(X)
|
||||
return self.classifier.predict(X)
|
||||
|
||||
def predict_proba(self, X):
|
||||
"""
|
||||
|
@ -82,7 +82,7 @@ class LowRankLogisticRegression(BaseEstimator):
|
|||
:return: array-like of shape `(n_samples, n_classes)` with the posterior probabilities
|
||||
"""
|
||||
X = self.transform(X)
|
||||
return self.learner.predict_proba(X)
|
||||
return self.classifier.predict_proba(X)
|
||||
|
||||
def transform(self, X):
|
||||
"""
|
||||
|
|
|
@ -322,6 +322,22 @@ class LabelledCollection:
|
|||
classes = np.unique(labels).sort()
|
||||
return LabelledCollection(instances, labels, classes=classes)
|
||||
|
||||
def separate(self):
|
||||
"""
|
||||
Breaks down this labelled collection into a list of labelled collections such that each element in the list
|
||||
contains all instances from a different class. The order in the list is consistent with the order in
|
||||
`self.classes_`. If some class has 0 elements, then None will be returned in that position in the list.
|
||||
|
||||
:return: list `L` of :class:`LabelledCollection` with `len(L)==len(self.classes_)`
|
||||
"""
|
||||
lcs = []
|
||||
for class_label in self.classes_:
|
||||
instances = self.instances[self.labels == class_label]
|
||||
n_instances = len(instances)
|
||||
new_lc = LabelledCollection(instances, [class_label]*n_instances) if (n_instances > 0) else None
|
||||
lcs.append(new_lc)
|
||||
return lcs
|
||||
|
||||
@property
|
||||
def Xy(self):
|
||||
"""
|
||||
|
|
|
@ -207,7 +207,7 @@ def fetch_UCIDataset(dataset_name, data_home=None, test_split=0.3, verbose=False
|
|||
return Dataset(*data.split_stratified(1 - test_split, random_state=0))
|
||||
|
||||
|
||||
def fetch_UCILabelledCollection(dataset_name, data_home=None, verbose=False) -> Dataset:
|
||||
def fetch_UCILabelledCollection(dataset_name, data_home=None, verbose=False) -> LabelledCollection:
|
||||
"""
|
||||
Loads a UCI collection as an instance of :class:`quapy.data.base.LabelledCollection`, as used in
|
||||
`Pérez-Gállego, P., Quevedo, J. R., & del Coz, J. J. (2017).
|
||||
|
@ -223,7 +223,7 @@ def fetch_UCILabelledCollection(dataset_name, data_home=None, verbose=False) ->
|
|||
|
||||
>>> import quapy as qp
|
||||
>>> collection = qp.datasets.fetch_UCILabelledCollection("yeast")
|
||||
>>> for data in qp.data.Dataset.kFCV(collection, nfolds=5, nrepeats=2):
|
||||
>>> for data in qp.domains.Dataset.kFCV(collection, nfolds=5, nrepeats=2):
|
||||
>>> ...
|
||||
|
||||
The list of valid dataset names can be accessed in `quapy.data.datasets.UCI_DATASETS`
|
||||
|
@ -233,7 +233,7 @@ def fetch_UCILabelledCollection(dataset_name, data_home=None, verbose=False) ->
|
|||
~/quay_data/ directory)
|
||||
:param test_split: proportion of documents to be included in the test set. The rest conforms the training set
|
||||
:param verbose: set to True (default is False) to get information (from the UCI ML repository) about the datasets
|
||||
:return: a :class:`quapy.data.base.Dataset` instance
|
||||
:return: a :class:`quapy.data.base.LabelledCollection` instance
|
||||
"""
|
||||
|
||||
assert dataset_name in UCI_DATASETS, \
|
||||
|
|
|
@ -444,24 +444,28 @@ class EMQ(AggregativeProbabilisticQuantifier):
|
|||
|
||||
def __init__(self, classifier: BaseEstimator, exact_train_prev=True, recalib=None):
|
||||
self.classifier = classifier
|
||||
self.non_calibrated = classifier
|
||||
self.exact_train_prev = exact_train_prev
|
||||
self.recalib = recalib
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
if self.recalib is not None:
|
||||
if self.recalib == 'nbvs':
|
||||
self.classifier = NBVSCalibration(self.classifier)
|
||||
self.classifier = NBVSCalibration(self.non_calibrated)
|
||||
elif self.recalib == 'bcts':
|
||||
self.classifier = BCTSCalibration(self.classifier)
|
||||
self.classifier = BCTSCalibration(self.non_calibrated)
|
||||
elif self.recalib == 'ts':
|
||||
self.classifier = TSCalibration(self.classifier)
|
||||
self.classifier = TSCalibration(self.non_calibrated)
|
||||
elif self.recalib == 'vs':
|
||||
self.classifier = VSCalibration(self.classifier)
|
||||
self.classifier = VSCalibration(self.non_calibrated)
|
||||
elif self.recalib == 'platt':
|
||||
self.classifier = CalibratedClassifierCV(self.classifier, ensemble=False)
|
||||
else:
|
||||
raise ValueError('invalid param argument for recalibration method; available ones are '
|
||||
'"nbvs", "bcts", "ts", and "vs".')
|
||||
self.recalib = None
|
||||
else:
|
||||
self.classifier = self.non_calibrated
|
||||
self.classifier, _ = _training_helper(self.classifier, data, fit_classifier, ensure_probabilistic=True)
|
||||
if self.exact_train_prev:
|
||||
self.train_prevalence = F.prevalence_from_labels(data.labels, self.classes_)
|
||||
|
|
|
@ -9,6 +9,7 @@ from torch.nn.functional import relu
|
|||
from quapy.protocol import UPP
|
||||
from quapy.method.aggregative import *
|
||||
from quapy.util import EarlyStop
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class QuaNetTrainer(BaseQuantifier):
|
||||
|
@ -28,7 +29,7 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
>>>
|
||||
>>> # load the kindle dataset as text, and convert words to numerical indexes
|
||||
>>> dataset = qp.datasets.fetch_reviews('kindle', pickle=True)
|
||||
>>> qp.data.preprocessing.index(dataset, min_df=5, inplace=True)
|
||||
>>> qp.domains.preprocessing.index(dataset, min_df=5, inplace=True)
|
||||
>>>
|
||||
>>> # the text classifier is a CNN trained by NeuralClassifierTrainer
|
||||
>>> cnn = CNNnet(dataset.vocabulary_size, dataset.n_classes)
|
||||
|
@ -263,15 +264,19 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
f'patience={early_stop.patience}/{early_stop.PATIENCE_LIMIT}')
|
||||
|
||||
def get_params(self, deep=True):
|
||||
return {**self.classifier.get_params(), **self.quanet_params}
|
||||
classifier_params = self.classifier.get_params()
|
||||
classifier_params = {'classifier__'+k:v for k,v in classifier_params.items()}
|
||||
return {**classifier_params, **self.quanet_params}
|
||||
|
||||
def set_params(self, **parameters):
|
||||
learner_params = {}
|
||||
for key, val in parameters.items():
|
||||
if key in self.quanet_params:
|
||||
self.quanet_params[key] = val
|
||||
elif key.startswith('classifier__'):
|
||||
learner_params[key.replace('classifier__', '')] = val
|
||||
else:
|
||||
learner_params[key] = val
|
||||
raise ValueError('unknown parameter ', key)
|
||||
self.classifier.set_params(**learner_params)
|
||||
|
||||
def __check_params_colision(self, quanet_params, learner_params):
|
||||
|
|
|
@ -33,3 +33,5 @@ class MaximumLikelihoodPrevalenceEstimation(BaseQuantifier):
|
|||
"""
|
||||
return self.estimated_prevalence
|
||||
|
||||
|
||||
MLPE = MaximumLikelihoodPrevalenceEstimation
|
|
@ -56,7 +56,7 @@ class GridSearchQ(BaseQuantifier):
|
|||
|
||||
def _sout(self, msg):
|
||||
if self.verbose:
|
||||
print(f'[{self.__class__.__name__}]: {msg}')
|
||||
print(f'[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}')
|
||||
|
||||
def __check_error(self, error):
|
||||
if error in qp.error.QUANTIFICATION_ERROR:
|
||||
|
|
|
@ -9,9 +9,9 @@ import math
|
|||
|
||||
import quapy as qp
|
||||
|
||||
plt.rcParams['figure.figsize'] = [12, 8]
|
||||
plt.rcParams['figure.figsize'] = [10, 6]
|
||||
plt.rcParams['figure.dpi'] = 200
|
||||
plt.rcParams['font.size'] = 16
|
||||
plt.rcParams['font.size'] = 18
|
||||
|
||||
|
||||
def binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title=None, show_std=True, legend=True,
|
||||
|
|
|
@ -218,7 +218,7 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
|
|||
to "labelled_collection" to get instead instances of LabelledCollection
|
||||
"""
|
||||
|
||||
def __init__(self, data:LabelledCollection, sample_size=None, n_prevalences=21, repeats=10,
|
||||
def __init__(self, data: LabelledCollection, sample_size=None, n_prevalences=21, repeats=10,
|
||||
smooth_limits_epsilon=0, random_state=0, return_type='sample_prev'):
|
||||
super(APP, self).__init__(random_state)
|
||||
self.data = data
|
||||
|
|
Loading…
Reference in New Issue