lipton bbse imported

Lorenzo Volpi 2023-09-22 01:40:36 +02:00
parent 4d28c8eccf
commit ede348ea27
2 changed files with 99 additions and 10 deletions

lipton_bbse/labelshift.py (new file, 72 lines)

@@ -0,0 +1,72 @@
import numpy as np


# ---------------------- utility functions used ----------------------------

def idx2onehot(a, k):
    a = a.astype(int)
    b = np.zeros((a.size, k))
    b[np.arange(a.size), a] = 1
    return b


def confusion_matrix(ytrue, ypred, k):
    # C[i,j] denotes the frequency of ypred = i, ytrue = j.
    n = ytrue.size
    C = np.dot(idx2onehot(ypred, k).T, idx2onehot(ytrue, k))
    return C / n


def confusion_matrix_probabilistic(ytrue, ypred, k):
    # Input is a probabilistic classifier's output in the form of an n by k matrix
    n, d = np.shape(ypred)
    C = np.dot(ypred.T, idx2onehot(ytrue, k))
    return C / n


def calculate_marginal(y, k):
    mu = np.zeros(shape=(k, 1))
    for i in range(k):
        mu[i] = np.count_nonzero(y == i)
    return mu / np.size(y)


def calculate_marginal_probabilistic(y, k):
    return np.mean(y, axis=0)


def estimate_labelshift_ratio(ytrue_s, ypred_s, ypred_t, k):
    if ypred_s.ndim == 2:  # this indicates that it is probabilistic
        C = confusion_matrix_probabilistic(ytrue_s, ypred_s, k)
        mu_t = calculate_marginal_probabilistic(ypred_t, k)
    else:
        C = confusion_matrix(ytrue_s, ypred_s, k)
        mu_t = calculate_marginal(ypred_t, k)
    lamb = 1 / min(len(ypred_s), len(ypred_t))
    # ridge-regularized least-squares solution of C w = mu_t
    wt = np.linalg.solve(np.dot(C.T, C) + lamb * np.eye(k), np.dot(C.T, mu_t))
    return wt


def estimate_target_dist(wt, ytrue_s, k):
    '''Input:
    - wt: This is the output of estimate_labelshift_ratio
    - ytrue_s: This is the list of true labels from the validation set
    Output:
    - An estimate of the true marginal distribution of the target set.
    '''
    mu_t = calculate_marginal(ytrue_s, k)
    return wt * mu_t


# functions that convert beta to w and convert w to a corresponding weight function.
def beta_to_w(beta, y, k):
    w = []
    for i in range(k):
        w.append(np.mean(beta[y.astype(int) == i]))
    w = np.array(w)
    return w


# a function that converts w to beta.
def w_to_beta(w, y):
    return w[y.astype(int)]


def w_to_weightfunc(w):
    return lambda x, y: w[y.astype(int)]

# ----------------------------------------------------------------------------
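For orientation: the module above is the black-box shift estimation (BBSE) routine of Lipton et al., which estimates the importance weights w(y) = q(y)/p(y) by solving the linear system C w = mu_t, where C is the classifier's confusion matrix on labelled source data and mu_t is the marginal of its predictions on the unlabelled target data; estimate_labelshift_ratio solves a ridge-regularised least-squares version of that system. Below is a minimal sketch (not part of the commit) of how the two main functions chain together; the synthetic data and the import path lipton_bbse.labelshift are assumptions for illustration only.

import numpy as np

import lipton_bbse.labelshift as bbse

rng = np.random.default_rng(0)
k = 2

# labelled "source" (validation) data: true labels plus ~90%-accurate hard predictions
ytrue_s = rng.integers(0, k, 1000)
ypred_s = np.where(rng.random(1000) < 0.1, 1 - ytrue_s, ytrue_s)

# unlabelled "target" data: only the classifier's predictions are observed
ytrue_t = (rng.random(1000) < 0.8).astype(int)   # shifted label marginal, unseen in practice
ypred_t = np.where(rng.random(1000) < 0.1, 1 - ytrue_t, ytrue_t)

wt = bbse.estimate_labelshift_ratio(ytrue_s, ypred_s, ypred_t, k)   # estimated q(y)/p(y)
mu_t = bbse.estimate_target_dist(wt, ytrue_s, k)                    # estimated target prevalence
print(wt.ravel(), mu_t.ravel())

With probabilistic n-by-k outputs (e.g. from predict_proba), estimate_labelshift_ratio instead routes through the *_probabilistic variants, which is the path the bbse_score function added below relies on.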

(second changed file)

@@ -1,16 +1,18 @@
from statistics import mean
from typing import Dict, assert_type
from unittest.mock import Base
from sklearn import clone
from typing import Dict
import numpy as np
import quapy as qp
from quapy.data import LabelledCollection
from sklearn.base import BaseEstimator
from sklearn.model_selection import cross_validate
from quapy.data import LabelledCollection
from elsahar19.rca import clone_fit
import garg22_ATC.ATC_helper as atc
import numpy as np
import jiang18_trustscore.trustscore as trustscore
import guillory21_doc.doc as doc
import elsahar19_rca.rca as rca
import garg22_ATC.ATC_helper as atc
import guillory21_doc.doc as doc
import jiang18_trustscore.trustscore as trustscore
import lipton_bbse.labelshift as bbse
def kfcv(c_model: BaseEstimator, validation: LabelledCollection) -> Dict:
    scoring = ["f1_macro"]
@@ -146,3 +148,18 @@ def rca_star_score(
    return rca.get_score(val2_pred1, val2_pred2, validation2.y)


def bbse_score(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict_proba",
):
    c_model_predict = getattr(c_model, predict_method)

    # source (validation) labels and posteriors, target (test) posteriors
    val_probs, val_labels = c_model_predict(validation.X), validation.y
    test_probs = c_model_predict(test.X)

    # BBSE: importance weights w = q(y)/p(y), then the estimated target label marginal
    wt = bbse.estimate_labelshift_ratio(val_labels, val_probs, test_probs, 2)
    estim_prev = bbse.estimate_target_dist(wt, val_labels, 2)
    true_prev = test.prevalence()
    return qp.error.ae(true_prev, estim_prev)
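A hedged usage sketch for the new scorer: the dataset, split proportion, and classifier below are illustrative assumptions (any binary quapy dataset and any sklearn classifier exposing predict_proba would do), and the call assumes it runs alongside the module that defines bbse_score.

from sklearn.linear_model import LogisticRegression

import quapy as qp

# example data: IMDb reviews with tf-idf features, split into train/validation/test
dataset = qp.datasets.fetch_reviews("imdb", tfidf=True)
training, test = dataset.training, dataset.test
train, validation = training.split_stratified(train_prop=0.7)

c_model = LogisticRegression().fit(*train.Xy)
err = bbse_score(c_model, validation, test, predict_method="predict_proba")
print(err)  # absolute error between the true test prevalence and the BBSE estimate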