lipton bbse imported
This commit is contained in:
parent
4d28c8eccf
commit
ede348ea27
|
|
@ -0,0 +1,72 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
#---------------------- utility functions used ----------------------------
|
||||||
|
def idx2onehot(a,k):
|
||||||
|
a=a.astype(int)
|
||||||
|
b = np.zeros((a.size, k))
|
||||||
|
b[np.arange(a.size), a] = 1
|
||||||
|
return b
|
||||||
|
|
||||||
|
|
||||||
|
def confusion_matrix(ytrue, ypred,k):
|
||||||
|
# C[i,j] denotes the frequency of ypred = i, ytrue = j.
|
||||||
|
n = ytrue.size
|
||||||
|
C = np.dot(idx2onehot(ypred,k).T,idx2onehot(ytrue,k))
|
||||||
|
return C/n
|
||||||
|
|
||||||
|
def confusion_matrix_probabilistic(ytrue, ypred,k):
|
||||||
|
# Input is probabilistic classifiers in forms of n by k matrices
|
||||||
|
n,d = np.shape(ypred)
|
||||||
|
C = np.dot(ypred.T, idx2onehot(ytrue,k))
|
||||||
|
return C/n
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_marginal(y,k):
|
||||||
|
mu = np.zeros(shape=(k,1))
|
||||||
|
for i in range(k):
|
||||||
|
mu[i] = np.count_nonzero(y == i)
|
||||||
|
return mu/np.size(y)
|
||||||
|
|
||||||
|
def calculate_marginal_probabilistic(y,k):
|
||||||
|
return np.mean(y,axis=0)
|
||||||
|
|
||||||
|
def estimate_labelshift_ratio(ytrue_s, ypred_s, ypred_t,k):
|
||||||
|
if ypred_s.ndim == 2: # this indicates that it is probabilistic
|
||||||
|
C = confusion_matrix_probabilistic(ytrue_s,ypred_s,k)
|
||||||
|
mu_t = calculate_marginal_probabilistic(ypred_t, k)
|
||||||
|
else:
|
||||||
|
C = confusion_matrix(ytrue_s, ypred_s,k)
|
||||||
|
mu_t = calculate_marginal(ypred_t, k)
|
||||||
|
lamb = (1/min(len(ypred_s),len(ypred_t)))
|
||||||
|
wt = np.linalg.solve(np.dot(C.T, C)+lamb*np.eye(k), np.dot(C.T, mu_t))
|
||||||
|
return wt
|
||||||
|
|
||||||
|
def estimate_target_dist(wt, ytrue_s,k):
|
||||||
|
''' Input:
|
||||||
|
- wt: This is the output of estimate_labelshift_ratio)
|
||||||
|
- ytrue_s: This is the list of true labels from validation set
|
||||||
|
|
||||||
|
Output:
|
||||||
|
- An estimation of the true marginal distribution of the target set.
|
||||||
|
'''
|
||||||
|
mu_t = calculate_marginal(ytrue_s,k)
|
||||||
|
return wt*mu_t
|
||||||
|
|
||||||
|
# functions that convert beta to w and converge w to a corresponding weight function.
|
||||||
|
def beta_to_w(beta, y, k):
|
||||||
|
w = []
|
||||||
|
for i in range(k):
|
||||||
|
w.append(np.mean(beta[y.astype(int) == i]))
|
||||||
|
w = np.array(w)
|
||||||
|
return w
|
||||||
|
|
||||||
|
# a function that converts w to beta.
|
||||||
|
def w_to_beta(w,y):
|
||||||
|
return w[y.astype(int)]
|
||||||
|
|
||||||
|
def w_to_weightfunc(w):
|
||||||
|
return lambda x, y: w[y.astype(int)]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------------
|
||||||
|
|
@ -1,16 +1,18 @@
|
||||||
from statistics import mean
|
from statistics import mean
|
||||||
from typing import Dict, assert_type
|
from typing import Dict
|
||||||
from unittest.mock import Base
|
|
||||||
from sklearn import clone
|
import numpy as np
|
||||||
|
import quapy as qp
|
||||||
|
from quapy.data import LabelledCollection
|
||||||
from sklearn.base import BaseEstimator
|
from sklearn.base import BaseEstimator
|
||||||
from sklearn.model_selection import cross_validate
|
from sklearn.model_selection import cross_validate
|
||||||
from quapy.data import LabelledCollection
|
|
||||||
from elsahar19.rca import clone_fit
|
|
||||||
import garg22_ATC.ATC_helper as atc
|
|
||||||
import numpy as np
|
|
||||||
import jiang18_trustscore.trustscore as trustscore
|
|
||||||
import guillory21_doc.doc as doc
|
|
||||||
import elsahar19_rca.rca as rca
|
import elsahar19_rca.rca as rca
|
||||||
|
import garg22_ATC.ATC_helper as atc
|
||||||
|
import guillory21_doc.doc as doc
|
||||||
|
import jiang18_trustscore.trustscore as trustscore
|
||||||
|
import lipton_bbse.labelshift as bbse
|
||||||
|
|
||||||
|
|
||||||
def kfcv(c_model: BaseEstimator, validation: LabelledCollection) -> Dict:
|
def kfcv(c_model: BaseEstimator, validation: LabelledCollection) -> Dict:
|
||||||
scoring = ["f1_macro"]
|
scoring = ["f1_macro"]
|
||||||
|
|
@ -145,4 +147,19 @@ def rca_star_score(
|
||||||
|
|
||||||
return rca.get_score(val2_pred1, val2_pred2, validation2.y)
|
return rca.get_score(val2_pred1, val2_pred2, validation2.y)
|
||||||
|
|
||||||
|
|
||||||
|
def bbse_score(
|
||||||
|
c_model: BaseEstimator,
|
||||||
|
validation: LabelledCollection,
|
||||||
|
test: LabelledCollection,
|
||||||
|
predict_method="predict_proba",
|
||||||
|
):
|
||||||
|
|
||||||
|
c_model_predict = getattr(c_model, predict_method)
|
||||||
|
val_probs, val_labels = c_model_predict(validation.X), validation.y
|
||||||
|
test_probs = c_model_predict(test.X)
|
||||||
|
|
||||||
|
wt = bbse.estimate_labelshift_ratio(val_labels, val_probs, test_probs, 2)
|
||||||
|
estim_prev = bbse.estimate_target_dist(wt, val_labels, 2)
|
||||||
|
true_prev = test.prevalence()
|
||||||
|
return qp.error.ae(true_prev, estim_prev)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue