"""Confidence metrics built on predicted class probabilities.

Every metric subclasses ``ConfidenceMetric`` and is registered under a string
key with the ``metric`` decorator, so callers can instantiate metrics by name
through ``get_metrics``. All registered metrics return one score per row of
``probas`` as a column of shape (n, 1).
"""

from typing import List, Optional

import numpy as np
import scipy.sparse as sp
from sklearn.linear_model import LinearRegression

import baselines.atc as atc

# Registry mapping a metric name to its ConfidenceMetric subclass; populated by
# the ``metric`` decorator when the classes below are defined.
__confs = {}

# Small positive value used before taking logs of probabilities so that exact
# zeros yield a large finite value instead of -inf.
_LOG_EPS = 1e-20


def metric(name: str):
    """Class decorator registering a ConfidenceMetric subclass under ``name``.

    A later registration with the same name replaces the earlier one.
    """

    def wrapper(cl):
        __confs[name] = cl
        return cl

    return wrapper


class ConfidenceMetric:
    """Base class for confidence metrics.

    ``probas`` is expected to be a 2-D array with one row per sample and one
    column per class. Subclasses override ``fit`` when they need labelled data
    and ``conf`` to compute their score.
    """

    def fit(self, X, y, probas):
        """Fit on features ``X``, labels ``y`` and ``probas``; no-op here."""
        pass

    def conf(self, X, probas):
        """Return confidence scores; the base implementation returns ``probas``."""
        return probas


@metric("max_conf")
class MaxConf(ConfidenceMetric):
    """Highest predicted probability of each sample."""

    def conf(self, X, probas):
        return np.max(probas, axis=1, keepdims=True)


@metric("entropy")
class Entropy(ConfidenceMetric):
    """Negative Shannon entropy of each row, ``sum(p * log p)``.

    Values are <= 0; a value closer to 0 means a more peaked (more confident)
    distribution.
    """

    def conf(self, X, probas):
        # The epsilon keeps log(0) finite, so zero entries contribute 0 * log(eps) = 0.
        return np.sum(
            np.multiply(probas, np.log(probas + _LOG_EPS)), axis=1, keepdims=True
        )


@metric("isoft")
class InverseSoftmax(ConfidenceMetric):
    """Maximum of the row-centred log-probabilities.

    Since ``log softmax(z) = z - logsumexp(z)``, subtracting the row mean of
    the log-probabilities recovers the logits centred to zero mean; the score
    is the largest of those centred logits.
    """

    def conf(self, X, probas):
        _probas = probas / np.sum(probas, axis=1, keepdims=True)
        # Floor at a tiny epsilon: a single exact zero would otherwise give
        # log(0) = -inf, making the row mean -inf and the whole row NaN
        # (-inf - -inf). Probabilities >= _LOG_EPS are left untouched.
        _log = np.log(np.maximum(_probas, _LOG_EPS))
        _centred = _log - np.mean(_log, axis=1, keepdims=True)
        return np.max(_centred, axis=1, keepdims=True)


@metric("threshold")
class Threshold(ConfidenceMetric):
    """Max-probability score shifted by a threshold fitted with ATC.

    ``conf`` returns ``max_prob - threshold``, so positive values are above the
    fitted threshold. ``fit`` must be called before ``conf``.
    """

    def get_scores(self, probas, keepdims=False):
        """Maximum probability per row (shape (n,) or (n, 1) with ``keepdims``)."""
        return np.max(probas, axis=1, keepdims=keepdims)

    def fit(self, X, y, probas):
        scores = self.get_scores(probas)
        # NOTE(review): ``y`` is passed unchanged; LinReg.fit treats ``y`` as
        # class indices. If find_ATC_threshold expects per-sample correctness
        # indicators (e.g. argmax(probas, axis=1) == y) rather than labels,
        # this should pass those instead -- confirm against baselines/atc.
        _, self.threshold = atc.find_ATC_threshold(scores, y)

    def conf(self, X, probas):
        # A binary variant (1.0 where score >= threshold, else 0.0) was
        # previously used here; the signed distance is kept instead.
        return self.get_scores(probas, keepdims=True) - self.threshold


@metric("linreg")
class LinReg(ConfidenceMetric):
    """Linear regression from ``[X | probas]`` to the probability of the label.

    The regressor is trained to predict ``probas[i, y[i]]``, the probability
    assigned to each sample's given label, from the features concatenated
    with ``probas``.
    """

    def extend(self, X, probas):
        """Append ``probas`` as extra columns of ``X`` (sparse-aware)."""
        if sp.issparse(X):
            return sp.hstack([X, probas])
        return np.concatenate([X, probas], axis=1)

    def fit(self, X, y, probas):
        reg_X = self.extend(X, probas)
        reg_y = probas[np.arange(probas.shape[0]), y]
        self.reg = LinearRegression()
        self.reg.fit(reg_X, reg_y)

    def conf(self, X, probas):
        reg_X = self.extend(X, probas)
        return self.reg.predict(reg_X)[:, np.newaxis]


def get_metrics(names: Optional[List[str]]):
    """Instantiate the registered metrics listed in ``names``.

    Returns ``None`` when ``names`` is ``None``. Names that are not registered
    are silently skipped; the result keeps the order of ``names``.
    """
    if names is None:
        return None
    return [__confs[n]() for n in names if n in __confs]