ATC baseline added, rcv1 dataset added
This commit is contained in:
parent
b47b229ba7
commit
37392c6545
|
@ -0,0 +1,34 @@
|
|||
import numpy as np
|
||||
|
||||
def get_entropy(probs):
    """Return ``sum(p * log(p + 1e-20))`` along axis 1 of ``probs``.

    Note that no sign flip is applied: for rows that are probability
    distributions the value is the *negative* of the Shannon entropy, so
    values closer to zero correspond to more peaked rows.

    Args:
        probs: 2-D array with one row per sample.

    Returns:
        1-D array holding one score per row.
    """
    # The small epsilon keeps the logarithm finite for zero entries.
    log_probs = np.log(probs + 1e-20)
    weighted = np.multiply(probs, log_probs)
    return weighted.sum(axis=1)
|
||||
|
||||
def get_max_conf(probs):
    """Return the largest value along the last axis of ``probs``.

    Args:
        probs: array-like of per-class scores, one row per sample.

    Returns:
        Array with the maximum of each row (the last axis is reduced).
    """
    top_confidence = np.amax(probs, axis=-1)
    return top_confidence
|
||||
|
||||
def find_ATC_threshold(scores, labels):
    """Find the score threshold that best balances the two error counts.

    Samples are visited in ascending score order. Before the sweep every
    sample with ``label == 0`` counts as a false positive and there are no
    false negatives. Each visited sample either removes one false positive
    (``label == 0``) or adds one false negative (any other label). The score
    of the visited sample at which ``|fp - fn|`` first reaches its smallest
    value is returned as the threshold.

    Args:
        scores: 1-D array-like of confidence scores, one per sample.
        labels: 1-D array-like aligned with ``scores``; zero/False marks a
            sample the classifier got wrong, anything else a correct one.

    Returns:
        Tuple ``(min_fp_fn, thres)``: the smallest ``|fp - fn|`` found and
        the matching score. If no position improves on the initial
        imbalance, ``thres`` stays at its initial value ``0.0``.
    """
    # Accept plain sequences as well as arrays; fancy indexing needs ndarrays.
    scores = np.asarray(scores)
    labels = np.asarray(labels)

    order = np.argsort(scores)
    sorted_scores = scores[order]
    sorted_labels = labels[order]

    fp = np.sum(labels == 0)
    fn = 0.0

    min_fp_fn = np.abs(fp - fn)
    thres = 0.0
    for score, label in zip(sorted_scores, sorted_labels):
        if label == 0:
            fp -= 1
        else:
            fn += 1

        gap = np.abs(fp - fn)
        # Strict comparison keeps the first (lowest-score) best position.
        if gap < min_fp_fn:
            min_fp_fn = gap
            thres = score

    return min_fp_fn, thres
|
||||
|
||||
|
||||
def get_ATC_acc(thres, scores):
    """Return the percentage of ``scores`` that are at least ``thres``.

    Args:
        thres: score threshold.
        scores: array of confidence scores.

    Returns:
        Share of entries with ``score >= thres``, scaled to the 0-100 range.
    """
    at_or_above = scores >= thres
    return np.mean(at_or_above) * 100.0
|
Binary file not shown.
|
@ -1,14 +1,77 @@
|
|||
|
||||
from statistics import mean
|
||||
from typing import Dict
|
||||
from sklearn.base import BaseEstimator
|
||||
from sklearn.model_selection import cross_validate
|
||||
from quapy.data import LabelledCollection
|
||||
from garg22_ATC.ATC_helper import (
|
||||
find_ATC_threshold,
|
||||
get_ATC_acc,
|
||||
get_entropy,
|
||||
get_max_conf,
|
||||
)
|
||||
import numpy as np
|
||||
|
||||
|
||||
def kfcv(c_model: BaseEstimator, validation: LabelledCollection) -> Dict:
    """Estimate the macro-F1 of ``c_model`` by cross-validation.

    Runs scikit-learn's ``cross_validate`` (with its default splitting) on
    the instances and labels of ``validation`` and averages the per-fold
    ``f1_macro`` test scores.

    Args:
        c_model: scikit-learn estimator to evaluate.
        validation: labelled data exposing ``X`` and ``y``.

    Returns:
        Dict with a single key ``"f1_score"`` holding the mean fold score.
    """
    scoring = ["f1_macro"]
    scores = cross_validate(c_model, validation.X, validation.y, scoring=scoring)
    return {"f1_score": mean(scores["test_f1_macro"])}
|
||||
|
||||
|
||||
def ATC_MC(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict_proba",
):
    """Estimate test accuracy with ATC using maximum confidence as score.

    A threshold is learned on ``validation`` with ``find_ATC_threshold`` and
    the predicted accuracy is the percentage of test samples whose maximum
    confidence reaches that threshold.

    Args:
        c_model: fitted classifier exposing ``predict_method``.
        validation: labelled validation data (``X`` and ``y``).
        test: labelled test data (``X`` and ``y``).
        predict_method: name of the ``c_model`` method that returns per-class
            scores for a batch of instances.

    Returns:
        Dict with ``"true_acc"`` (measured test accuracy, in percent) and
        ``"pred_acc"`` (ATC estimate, in percent).

    NOTE(review): predictions are argmax column indices compared directly
    with the labels, which presumably requires labels encoded as
    ``0..K-1`` in column order; confirm against the datasets in use.
    """
    c_model_predict = getattr(c_model, predict_method)

    ## Load ID validation data probs and labels
    val_probs, val_labels = c_model_predict(validation.X), validation.y

    ## Load OOD test data probs
    test_probs = c_model_predict(test.X)

    ## score function: argmax confidence
    val_scores = get_max_conf(val_probs)
    val_preds = np.argmax(val_probs, axis=-1)

    test_scores = get_max_conf(test_probs)

    _, ATC_thres = find_ATC_threshold(val_scores, val_labels == val_preds)
    ATC_accuracy = get_ATC_acc(ATC_thres, test_scores)

    return {
        "true_acc": 100 * np.mean(np.argmax(test_probs, axis=-1) == test.y),
        "pred_acc": ATC_accuracy,
    }
|
||||
|
||||
def ATC_NE(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict_proba",
):
    """Estimate test accuracy with ATC using ``get_entropy`` as score.

    The threshold is learned on ``validation`` via ``find_ATC_threshold``;
    the predicted accuracy is the percentage of test samples whose score
    reaches that threshold.

    Args:
        c_model: fitted classifier exposing ``predict_method``.
        validation: labelled validation data (``X`` and ``y``).
        test: labelled test data (``X`` and ``y``).
        predict_method: name of the ``c_model`` method that returns per-class
            scores for a batch of instances.

    Returns:
        Dict with ``"true_acc"`` (measured test accuracy, in percent) and
        ``"pred_acc"`` (ATC estimate, in percent).
    """
    predict = getattr(c_model, predict_method)

    # Class scores for the validation split, then for the test split.
    validation_probs = predict(validation.X)
    validation_labels = validation.y
    test_probs = predict(test.X)

    # Per-sample scores and a mask of correctly classified validation samples.
    validation_scores = get_entropy(validation_probs)
    validation_hits = validation_labels == np.argmax(validation_probs, axis=-1)
    test_scores = get_entropy(test_probs)

    threshold = find_ATC_threshold(validation_scores, validation_hits)[1]
    estimated_accuracy = get_ATC_acc(threshold, test_scores)

    measured_accuracy = 100 * np.mean(np.argmax(test_probs, axis=-1) == test.y)
    return {"true_acc": measured_accuracy, "pred_acc": estimated_accuracy}
|
||||
|
||||
|
|
|
@ -1,7 +1,35 @@
|
|||
from typing import Tuple
|
||||
import numpy as np
|
||||
from quapy.data.base import LabelledCollection
|
||||
import quapy as qp
|
||||
from sklearn.conftest import fetch_rcv1
|
||||
|
||||
def get_imdb_traintest():
    """Fetch the tf-idf "imdb" reviews dataset and return its train/test pair."""
    imdb = qp.datasets.fetch_reviews("imdb", tfidf=True)
    return imdb.train_test
|
||||
# Value passed as ``train_prop`` to ``split_stratified`` when the original
# training collection is divided into a training part and a validation part.
TRAIN_VAL_PROP = 0.5
|
||||
|
||||
def get_spambase_traintest():
    """Fetch the UCI "spambase" dataset and return its train/test pair."""
    spambase = qp.datasets.fetch_UCIDataset("spambase", verbose=False)
    return spambase.train_test
|
||||
|
||||
def get_imdb() -> Tuple[LabelledCollection, LabelledCollection, LabelledCollection]:
    """Return the tf-idf "imdb" reviews data as train, validation and test.

    The dataset's own training collection is split with
    ``split_stratified(train_prop=TRAIN_VAL_PROP)`` into a new training part
    and a validation part; the test collection is returned unchanged.

    NOTE(review): this function takes no arguments, so callers must not pass
    a dataset name.

    Returns:
        Tuple ``(train, validation, test)``.
    """
    train, test = qp.datasets.fetch_reviews("imdb", tfidf=True).train_test
    train, validation = train.split_stratified(train_prop=TRAIN_VAL_PROP)
    return train, validation, test
|
||||
|
||||
|
||||
def get_spambase() -> Tuple[LabelledCollection, LabelledCollection, LabelledCollection]:
    """Return the UCI "spambase" data as train, validation and test.

    The dataset's own training collection is split with
    ``split_stratified(train_prop=TRAIN_VAL_PROP)`` into a new training part
    and a validation part; the test collection is returned unchanged.

    Returns:
        Tuple ``(train, validation, test)``.
    """
    train, test = qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
    train, validation = train.split_stratified(train_prop=TRAIN_VAL_PROP)
    return train, validation, test
|
||||
|
||||
|
||||
def get_rcv1(sample_size=100):
    """Build one binary ``LabelledCollection`` per sufficiently frequent RCV1 target.

    For every entry of ``dataset.target_names`` the matching column of
    ``dataset.target`` is densified into a flat label vector. Targets whose
    label vector sums to at least ``sample_size`` are kept.

    Args:
        sample_size: minimum label sum (presumably the number of positive
            documents, assuming a 0/1 indicator matrix) a target needs in
            order to be included.

    Returns:
        Dict mapping each kept target name to a ``LabelledCollection`` built
        from the full document matrix and that target's labels, with
        ``classes=[0, 1]``.
    """
    # Import from the public datasets module; ``sklearn.conftest`` is test
    # configuration and only re-exports this loader.
    from sklearn.datasets import fetch_rcv1

    dataset = fetch_rcv1()

    collections = {}
    for ind, target in enumerate(dataset.target_names):
        labels = dataset.target[:, ind].toarray().flatten()
        # Filter while iterating so rejected label vectors are not kept alive.
        if np.sum(labels) >= sample_size:
            collections[target] = LabelledCollection(
                dataset.data, labels, classes=[0, 1]
            )
    return collections
|
||||
|
|
|
@ -9,7 +9,7 @@ from quacc.estimator import (
|
|||
MulticlassAccuracyEstimator,
|
||||
)
|
||||
|
||||
from quacc.dataset import get_imdb_traintest
|
||||
from quacc.dataset import get_imdb
|
||||
|
||||
# Set the "SAMPLE_SIZE" entry of quapy's environment mapping to 100.
qp.environ["SAMPLE_SIZE"] = 100
|
||||
|
||||
|
@ -20,7 +20,7 @@ dataset_name = "imdb"
|
|||
|
||||
def estimate_multiclass():
|
||||
print(dataset_name)
|
||||
train, test = get_imdb_traintest(dataset_name)
|
||||
train, validation, test = get_imdb(dataset_name)
|
||||
|
||||
model = LogisticRegression()
|
||||
|
||||
|
@ -59,7 +59,7 @@ def estimate_multiclass():
|
|||
|
||||
def estimate_binary():
|
||||
print(dataset_name)
|
||||
train, test = get_imdb_traintest(dataset_name)
|
||||
train, validation, test = get_imdb(dataset_name)
|
||||
|
||||
model = LogisticRegression()
|
||||
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from quacc.baseline import kfcv
|
||||
from quacc.dataset import get_spambase_traintest
|
||||
from quacc.dataset import get_spambase
|
||||
|
||||
|
||||
class TestBaseline:
    """Tests for the baseline estimators."""

    def test_kfcv(self):
        """``kfcv`` must report an ``"f1_score"`` entry in its result."""
        # Only the training split is needed; validation and test are ignored.
        train, _, _ = get_spambase()
        c_model = LogisticRegression()
        assert "f1_score" in kfcv(c_model, train)
|
Loading…
Reference in New Issue