QuaPy/ClassifierAccuracy/main.py

77 lines
2.1 KiB
Python
Raw Normal View History

2024-02-23 16:55:14 +01:00
import os
from collections import defaultdict
from typing import Callable, Iterator, Tuple

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix

# NOTE: the project-local star imports below are order-sensitive (a later import may
# shadow names bound by an earlier one), so their original relative order is preserved.
from method.aggregative import PACC, EMQ
from utils import *

import quapy.data.datasets
import quapy as qp
from models_multiclass import *
from quapy.data import LabelledCollection
from quapy.protocol import UPP
from quapy.data.datasets import fetch_UCIMulticlassLabelledCollection, UCI_MULTICLASS_DATASETS
2024-02-23 16:55:14 +01:00
2024-02-23 18:19:00 +01:00
def split(data: LabelledCollection):
    """Partition ``data`` into train, validation and test collections.

    A first ``split_stratified`` call holds out roughly 34% of ``data`` as the test
    set. The remaining ~66% (the development part) is then halved into training
    and validation sets, roughly 33% of ``data`` each.

    NOTE(review): no random_state is passed to ``split_stratified``, so the
    partition may differ between runs — confirm whether that is intended.

    :param data: the labelled collection to partition.
    :return: the tuple ``(train, val, test)``.
    """
    # outer split: development part (train + val) vs. held-out test part
    dev, held_out = data.split_stratified(train_prop=0.66)
    # inner split: the development part is divided 50/50 into train and val
    tr, va = dev.split_stratified(train_prop=0.5)
    return tr, va, held_out
2024-02-23 16:55:14 +01:00
2024-02-23 18:19:00 +01:00
def gen_datasets() -> Iterator[
    Tuple[str, Tuple[LabelledCollection, LabelledCollection, LabelledCollection]]
]:
    """Yield every UCI multiclass dataset, already split for the experiment.

    Datasets are fetched lazily, one per iteration, with
    ``fetch_UCIMulticlassLabelledCollection`` and then partitioned by ``split``.

    :return: a generator of ``(dataset_name, (train, val, test))`` pairs, one per
        name in ``UCI_MULTICLASS_DATASETS``.
    """
    for dataset_name in UCI_MULTICLASS_DATASETS:
        dataset = fetch_UCIMulticlassLabelledCollection(dataset_name)
        yield dataset_name, split(dataset)
2024-02-23 16:55:14 +01:00
2024-02-23 18:19:00 +01:00
def gen_CAP(h, acc_fn) -> Iterator[Tuple[str, ClassifierAccuracyPrediction]]:
    """Yield the classifier-accuracy-prediction (CAP) methods to compare, by name.

    All methods wrap the same classifier ``h`` and accuracy function ``acc_fn``.
    They are only constructed here, not fitted.

    :param h: the classifier whose accuracy the methods will estimate.
    :param acc_fn: maps a contingency table to an accuracy score.
    :return: a generator of ``(method_name, method)`` pairs.
    """
    yield 'Naive', NaiveCAP(h, acc_fn)
    yield 'CT-PPS-PACC', ContTableTransferCAP(h, acc_fn, PACC(LogisticRegression()))
    # NOTE(review): this method receives the PACC *class* itself, whereas the one
    # above receives a PACC *instance*. Presumably ContTableWithHTransferCAP
    # instantiates the quantifier internally — confirm against its constructor.
    yield 'CT-PPSh-PACC', ContTableWithHTransferCAP(h, acc_fn, PACC)
2024-02-23 16:55:14 +01:00
2024-02-23 18:19:00 +01:00
def true_acc(h: BaseEstimator, acc_fn: Callable[[np.ndarray], float], U: LabelledCollection):
    """Compute the actual accuracy measure of classifier ``h`` on labelled data ``U``.

    :param h: a fitted classifier exposing ``predict``.
    :param acc_fn: maps the contingency table returned by
        ``sklearn.metrics.confusion_matrix`` to a score.
    :param U: labelled collection; ``U.X`` is classified and compared with ``U.y``.
    :return: ``acc_fn`` applied to the confusion matrix of ``h`` on ``U``.
    """
    y_pred = h.predict(U.X)
    y_true = U.y
    # Pinning labels to U.classes_ (presumably the full class list of the collection)
    # keeps one row/column per class even when a class is absent from this sample's
    # true or predicted labels. Without it, sklearn only uses the labels that appear.
    conf_table = confusion_matrix(y_true, y_pred=y_pred, labels=U.classes_)
    return acc_fn(conf_table)
2024-02-23 16:55:14 +01:00
2024-02-23 18:19:00 +01:00
def acc_fn(cont_table):
    """Return the accuracy encoded by a contingency (confusion) table.

    Accuracy is the fraction of the table's total mass lying on the main diagonal,
    i.e. the count of instances whose predicted class equals their true class.

    :param cont_table: a 2-D numpy array of counts, e.g. as returned by
        ``sklearn.metrics.confusion_matrix``.
    :return: ``diagonal sum / total sum``. For an all-zero table this is a 0/0
        division, which numpy turns into ``nan`` with a RuntimeWarning.
    """
    hits = np.diag(cont_table).sum()
    total = cont_table.sum()
    return hits / total
2024-02-23 16:55:14 +01:00
2024-02-23 18:19:00 +01:00
# Experiment: for every UCI multiclass dataset, compare the true accuracy of a
# classifier on UPP-generated test samples with the accuracy estimated by each CAP
# method, then plot true vs. estimated accuracies.

qp.environ['SAMPLE_SIZE'] = 100

# classifier whose accuracy is measured/estimated; refitted on each dataset's train split
h = LogisticRegression()

# true accuracies of h on every test sample, over all datasets (one flat list)
acc_trues = []
# method name -> estimated accuracies, index-aligned with acc_trues
acc_predicted = defaultdict(list)

for dataset_name, (L, V, U) in gen_datasets():
    print(dataset_name)

    h.fit(*L.Xy)

    test_prot = UPP(U, repeats=100, return_type='labelled_collection')
    # Draw the test samples once and reuse them. The true and estimated accuracies
    # must refer to exactly the same samples (their lists are compared
    # index-by-index), and re-running the protocol per method would also redo the
    # sampling work.
    test_samples = list(test_prot())

    acc_trues.extend(true_acc(h, acc_fn, Ui) for Ui in test_samples)

    for method_name, method in gen_CAP(h, acc_fn):
        # each CAP method is fitted on the validation split, then applied to the test samples
        method.fit(V)
        acc_predicted[method_name].extend(method.predict(Ui.X) for Ui in test_samples)

acc_predicted = list(acc_predicted.items())

# Make sure the output folder exists before plotting.
# NOTE(review): plot_diagonal is defined outside this file (presumably utils); it is
# not visible here whether it creates missing folders itself.
os.makedirs('./plots', exist_ok=True)
plot_diagonal('./plots/diagonal.png', acc_trues, acc_predicted)
2024-02-23 16:55:14 +01:00