ATC baseline added, rcv1 dataset added
This commit is contained in:
parent
b47b229ba7
commit
37392c6545
|
@ -0,0 +1,34 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def get_entropy(probs):
    """Return the negative entropy score of each probability vector.

    Despite its name, the function computes ``sum(p * log(p))`` over the
    last axis, i.e. the *negative* Shannon entropy: values are <= 0 and are
    larger (closer to 0) for more confident distributions.

    Args:
        probs: array-like of probabilities whose last axis indexes the
            classes (a single vector or a batch of vectors).

    Returns:
        The per-vector score, with the last axis reduced.
    """
    probs = np.asarray(probs)
    # The tiny offset keeps log() finite when a probability is exactly 0.
    # Reducing over the last axis matches get_max_conf and also accepts a
    # single 1-D probability vector.
    return np.sum(probs * np.log(probs + 1e-20), axis=-1)
|
||||||
|
|
||||||
|
def get_max_conf(probs):
    """Return the maximum-confidence score of each probability vector.

    The score is the largest entry along the last axis of ``probs``.
    """
    top_confidence = np.amax(probs, axis=-1)
    return top_confidence
|
||||||
|
|
||||||
|
def find_ATC_threshold(scores, labels):
    """Find the score threshold that best balances the two error counts.

    Samples are visited in ascending score order. After visiting a sample,
    ``fp`` is the number of zero-labelled samples not yet visited and ``fn``
    is the number of non-zero-labelled samples already visited. The returned
    threshold is the score of the first visited sample at which
    ``|fp - fn|`` reaches its minimum.

    Args:
        scores: 1-D array-like of confidence scores (any real range).
        labels: 1-D array-like aligned with ``scores``; an entry equal to 0
            (or ``False``) marks a wrong prediction, anything else a
            correct one.

    Returns:
        A ``(min_fp_fn, thres)`` tuple: the smallest ``|fp - fn|`` found and
        the matching threshold. When no visited sample improves on the
        starting state (every sample accepted), ``thres`` is ``-inf`` so
        that ``scores >= thres`` still accepts every sample, whatever the
        range of the scores.
    """
    scores = np.asarray(scores)
    labels = np.asarray(labels)

    order = np.argsort(scores)
    sorted_scores = scores[order]
    is_wrong = labels[order] == 0

    # Starting state: nothing visited yet, so every wrong sample counts as
    # fp and fn is 0. The threshold for that state must lie below every
    # possible score; a literal 0.0 would lie *above* all non-positive
    # scores (e.g. negative entropy) and reject everything.
    total_wrong = np.sum(is_wrong)
    min_fp_fn = float(total_wrong)
    thres = -np.inf
    if sorted_scores.size == 0:
        return min_fp_fn, thres

    # fp / fn after visiting each sample, computed for all positions at once.
    fp_after = total_wrong - np.cumsum(is_wrong)
    fn_after = np.cumsum(~is_wrong)
    gaps = np.abs(fp_after - fn_after)

    # argmin returns the first minimum, which is the position a sequential
    # scan with a strict "<" comparison would end up keeping.
    best = int(np.argmin(gaps))
    if gaps[best] < min_fp_fn:
        min_fp_fn = float(gaps[best])
        thres = sorted_scores[best]

    return min_fp_fn, thres
|
||||||
|
|
||||||
|
|
||||||
|
def get_ATC_acc(thres, scores):
    """Return the percentage of ``scores`` that are at least ``thres``."""
    above_threshold = scores >= thres
    fraction = np.mean(above_threshold)
    return fraction * 100.0
|
Binary file not shown.
|
@ -1,14 +1,77 @@
|
||||||
|
|
||||||
from statistics import mean
|
from statistics import mean
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from sklearn.base import BaseEstimator
|
from sklearn.base import BaseEstimator
|
||||||
from sklearn.model_selection import cross_validate
|
from sklearn.model_selection import cross_validate
|
||||||
from quapy.data import LabelledCollection
|
from quapy.data import LabelledCollection
|
||||||
|
from garg22_ATC.ATC_helper import (
|
||||||
|
find_ATC_threshold,
|
||||||
|
get_ATC_acc,
|
||||||
|
get_entropy,
|
||||||
|
get_max_conf,
|
||||||
|
)
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def kfcv(c_model: BaseEstimator, validation: LabelledCollection) -> Dict:
    """Cross-validate ``c_model`` on ``validation`` and average macro-F1.

    Returns:
        A dict with the single key ``"f1_score"`` holding the mean of the
        per-fold ``f1_macro`` test scores reported by ``cross_validate``.
    """
    cv_results = cross_validate(
        c_model,
        validation.X,
        validation.y,
        scoring=["f1_macro"],
    )
    fold_f1 = cv_results["test_f1_macro"]
    return {"f1_score": mean(fold_f1)}
|
||||||
|
|
||||||
|
|
||||||
|
def ATC_MC(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict_proba",
):
    """Estimate test accuracy with ATC using the maximum-confidence score.

    A threshold is fitted on the validation scores with
    ``find_ATC_threshold`` and the predicted accuracy is the percentage of
    test scores reaching it (``get_ATC_acc``).

    Args:
        c_model: fitted classifier exposing ``predict_method``.
        validation: labelled in-distribution data used to fit the threshold.
        test: labelled (possibly shifted) data whose accuracy is estimated;
            its labels are used only to report the true accuracy.
        predict_method: name of the ``c_model`` method returning per-class
            scores, one column per class.

    Returns:
        A dict with ``"true_acc"`` (measured accuracy on ``test``) and
        ``"pred_acc"`` (ATC estimate), both as percentages.
    """
    c_model_predict = getattr(c_model, predict_method)

    # In-distribution validation outputs and labels, then test outputs.
    val_probs, val_labels = c_model_predict(validation.X), validation.y
    test_probs = c_model_predict(test.X)

    # Score function: maximum confidence.
    val_scores = get_max_conf(val_probs)
    test_scores = get_max_conf(test_probs)

    val_preds = np.argmax(val_probs, axis=-1)
    test_preds = np.argmax(test_probs, axis=-1)
    # argmax yields column indices; scikit-learn orders the columns as in
    # ``classes_``, so translate indices to labels before comparing with y.
    # Without ``classes_`` the indices are used as labels directly.
    classes = getattr(c_model, "classes_", None)
    if classes is not None:
        classes = np.asarray(classes)
        val_preds, test_preds = classes[val_preds], classes[test_preds]

    _, ATC_thres = find_ATC_threshold(val_scores, val_labels == val_preds)
    ATC_accuracy = get_ATC_acc(ATC_thres, test_scores)

    return {
        "true_acc": 100 * np.mean(test_preds == test.y),
        "pred_acc": ATC_accuracy,
    }
|
||||||
|
|
||||||
|
def ATC_NE(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict_proba",
):
    """Estimate test accuracy with ATC using the ``get_entropy`` score.

    A threshold is fitted on the validation scores with
    ``find_ATC_threshold`` and the predicted accuracy is the percentage of
    test scores reaching it (``get_ATC_acc``).

    Args:
        c_model: fitted classifier exposing ``predict_method``.
        validation: labelled in-distribution data used to fit the threshold.
        test: labelled (possibly shifted) data whose accuracy is estimated;
            its labels are used only to report the true accuracy.
        predict_method: name of the ``c_model`` method returning per-class
            probabilities, one column per class.

    Returns:
        A dict with ``"true_acc"`` (measured accuracy on ``test``) and
        ``"pred_acc"`` (ATC estimate), both as percentages.
    """
    c_model_predict = getattr(c_model, predict_method)

    # In-distribution validation outputs and labels, then test outputs.
    val_probs, val_labels = c_model_predict(validation.X), validation.y
    test_probs = c_model_predict(test.X)

    # Score function: entropy-based score from get_entropy.
    val_scores = get_entropy(val_probs)
    test_scores = get_entropy(test_probs)

    val_preds = np.argmax(val_probs, axis=-1)
    test_preds = np.argmax(test_probs, axis=-1)
    # argmax yields column indices; scikit-learn orders the columns as in
    # ``classes_``, so translate indices to labels before comparing with y.
    # Without ``classes_`` the indices are used as labels directly.
    classes = getattr(c_model, "classes_", None)
    if classes is not None:
        classes = np.asarray(classes)
        val_preds, test_preds = classes[val_preds], classes[test_preds]

    _, ATC_thres = find_ATC_threshold(val_scores, val_labels == val_preds)
    ATC_accuracy = get_ATC_acc(ATC_thres, test_scores)

    return {
        "true_acc": 100 * np.mean(test_preds == test.y),
        "pred_acc": ATC_accuracy,
    }
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,35 @@
|
||||||
|
from typing import Tuple
|
||||||
|
import numpy as np
|
||||||
|
from quapy.data.base import LabelledCollection
|
||||||
import quapy as qp
|
import quapy as qp
|
||||||
|
from sklearn.conftest import fetch_rcv1
|
||||||
|
|
||||||
def get_imdb_traintest():
|
TRAIN_VAL_PROP = 0.5
|
||||||
return qp.datasets.fetch_reviews("imdb", tfidf=True).train_test
|
|
||||||
|
|
||||||
def get_spambase_traintest():
|
|
||||||
return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
|
def get_imdb() -> Tuple[LabelledCollection, LabelledCollection, LabelledCollection]:
    """Load the tf-idf "imdb" reviews dataset as train/validation/test.

    The dataset's own training portion is split with a stratified split,
    ``TRAIN_VAL_PROP`` being the proportion kept for training; the rest
    becomes the validation set. The dataset's test portion is returned
    unchanged.

    Returns:
        The ``(train, validation, test)`` collections, in that order.
    """
    train, test = qp.datasets.fetch_reviews("imdb", tfidf=True).train_test
    train, validation = train.split_stratified(train_prop=TRAIN_VAL_PROP)
    return train, validation, test
|
||||||
|
|
||||||
|
|
||||||
|
def get_spambase():
    """Load the UCI "spambase" dataset as ``(train, validation, test)``.

    The dataset's training portion is divided by a stratified split whose
    training proportion is ``TRAIN_VAL_PROP``; the test portion is returned
    as provided.
    """
    spambase = qp.datasets.fetch_UCIDataset("spambase", verbose=False)
    full_train, test = spambase.train_test
    train, validation = full_train.split_stratified(train_prop=TRAIN_VAL_PROP)
    return train, validation, test
|
||||||
|
|
||||||
|
|
||||||
|
def get_rcv1(sample_size=100):
    """Build one binary ``LabelledCollection`` per sufficiently large RCV1 target.

    Each column of the RCV1 target matrix is turned into a 0/1 label vector
    over the whole dataset. Only targets whose label vector sums to at
    least ``sample_size`` are kept.

    Args:
        sample_size: minimum sum of a target's label vector for the target
            to be included.

    Returns:
        A dict mapping each retained target name to a ``LabelledCollection``
        built from the full data matrix and that target's labels, with
        ``classes=[0, 1]``.
    """
    # sklearn.datasets is the public home of the loader; sklearn.conftest is
    # scikit-learn's pytest configuration module and not meant for import.
    from sklearn.datasets import fetch_rcv1

    dataset = fetch_rcv1()

    collections = {}
    for ind, target in enumerate(dataset.target_names):
        labels = dataset.target[:, ind].toarray().flatten()
        # Plain loop instead of filter(): filter passes each (name, labels)
        # pair as ONE argument, so a two-parameter predicate cannot work.
        if np.sum(labels) >= sample_size:
            collections[target] = LabelledCollection(
                dataset.data, labels, classes=[0, 1]
            )
    return collections
|
||||||
|
|
|
@ -9,7 +9,7 @@ from quacc.estimator import (
|
||||||
MulticlassAccuracyEstimator,
|
MulticlassAccuracyEstimator,
|
||||||
)
|
)
|
||||||
|
|
||||||
from quacc.dataset import get_imdb_traintest
|
from quacc.dataset import get_imdb
|
||||||
|
|
||||||
qp.environ["SAMPLE_SIZE"] = 100
|
qp.environ["SAMPLE_SIZE"] = 100
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ dataset_name = "imdb"
|
||||||
|
|
||||||
def estimate_multiclass():
|
def estimate_multiclass():
|
||||||
print(dataset_name)
|
print(dataset_name)
|
||||||
train, test = get_imdb_traintest(dataset_name)
|
train, validation, test = get_imdb(dataset_name)
|
||||||
|
|
||||||
model = LogisticRegression()
|
model = LogisticRegression()
|
||||||
|
|
||||||
|
@ -59,7 +59,7 @@ def estimate_multiclass():
|
||||||
|
|
||||||
def estimate_binary():
|
def estimate_binary():
|
||||||
print(dataset_name)
|
print(dataset_name)
|
||||||
train, test = get_imdb_traintest(dataset_name)
|
train, validation, test = get_imdb(dataset_name)
|
||||||
|
|
||||||
model = LogisticRegression()
|
model = LogisticRegression()
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
|
|
||||||
from sklearn.linear_model import LogisticRegression
|
from sklearn.linear_model import LogisticRegression
|
||||||
from quacc.baseline import kfcv
|
from quacc.baseline import kfcv
|
||||||
from quacc.dataset import get_spambase_traintest
|
from quacc.dataset import get_spambase
|
||||||
|
|
||||||
|
|
||||||
class TestBaseline:
    """Tests for the baseline estimators."""

    def test_kfcv(self):
        """kfcv on the spambase training split reports an ``f1_score`` entry."""
        spambase_train, _validation, _test = get_spambase()
        classifier = LogisticRegression()
        report = kfcv(classifier, spambase_train)
        assert "f1_score" in report
|
Loading…
Reference in New Issue