added QuAcc and a method that allows quantification training and prediction with empty classes
This commit is contained in:
parent
e69496246e
commit
b9fed349f0
@@ -1,28 +1,21 @@
import itertools
import json
import os
from collections import defaultdict

from sklearn.base import BaseEstimator
from sklearn.linear_model import LogisticRegression
import numpy as np
from glob import glob
from os import makedirs
from os.path import join
from pathlib import Path
from time import time
from sklearn.metrics import confusion_matrix
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC, LinearSVC

from method.aggregative import PACC, EMQ, ACC
from utils import *
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_rcv1

import quapy.data.datasets
import quapy as qp
from quapy.method.aggregative import EMQ, ACC
from models_multiclass import *
from quapy.data import LabelledCollection
from quapy.protocol import UPP
from quapy.data.datasets import fetch_UCIMulticlassLabelledCollection, UCI_MULTICLASS_DATASETS


def split(data: LabelledCollection):
    train_val, test = data.split_stratified(train_prop=0.66, random_state=0)
    train, val = train_val.split_stratified(train_prop=0.5, random_state=0)
    return train, val, test
from quapy.data.datasets import fetch_reviews


def gen_classifiers():
@@ -32,30 +25,103 @@ def gen_classifiers():
    #yield 'SVM(linear)', LinearSVC()


def gen_datasets()-> [str,[LabelledCollection,LabelledCollection,LabelledCollection]]:
def gen_multi_datasets(only_names=False)-> [str,[LabelledCollection,LabelledCollection,LabelledCollection]]:
    for dataset_name in UCI_MULTICLASS_DATASETS:
        dataset = fetch_UCIMulticlassLabelledCollection(dataset_name)
        yield dataset_name, split(dataset)
        if only_names:
            yield dataset_name, None
        else:
            dataset = fetch_UCIMulticlassLabelledCollection(dataset_name)
            yield dataset_name, split(dataset)


def gen_bin_datasets(only_names=False) -> [str,[LabelledCollection,LabelledCollection,LabelledCollection]]:
    if only_names:
        for dataset_name in ['imdb', 'CCAT', 'GCAT', 'MCAT']:
            yield dataset_name, None
    else:
        train, U = fetch_reviews('imdb', tfidf=True, min_df=10, pickle=True).train_test
        L, V = train.split_stratified(0.5, random_state=0)
        yield 'imdb', (L, V, U)

        training = fetch_rcv1(subset='train')
        test = fetch_rcv1(subset='test')
        class_names = training.target_names.tolist()
        for cat in ['CCAT', 'GCAT', 'MCAT']:
            class_idx = class_names.index(cat)
            tr_labels = training.target[:,class_idx].toarray().flatten()
            te_labels = test.target[:,class_idx].toarray().flatten()
            tr = LabelledCollection(training.data, tr_labels)
            U = LabelledCollection(test.data, te_labels)
            L, V = tr.split_stratified(train_prop=0.5, random_state=0)
            yield cat, (L, V, U)


def gen_CAP(h, acc_fn)->[str, ClassifierAccuracyPrediction]:
    yield 'SebCAP', SebastianiCAP(h, acc_fn, ACC)
    yield 'SebCAPweight', SebastianiCAP(h, acc_fn, ACC, alpha=0)
    yield 'PabCAP', PabloCAP(h, acc_fn, ACC)
    #yield 'SebCAP', SebastianiCAP(h, acc_fn, ACC)
    yield 'SebCAP-SLD', SebastianiCAP(h, acc_fn, EMQ)
    #yield 'SebCAPweight', SebastianiCAP(h, acc_fn, ACC, alpha=0)
    #yield 'PabCAP', PabloCAP(h, acc_fn, ACC)
    yield 'PabCAP-SLD-median', PabloCAP(h, acc_fn, EMQ, aggr='median')


def gen_CAP_cont_table(h)->[str,CAPContingencyTable]:
    acc_fn = None
    # yield 'Naive', NaiveCAP(h, acc_fn)
    yield 'Naive', NaiveCAP(h, acc_fn)
    yield 'CT-PPS-EMQ', ContTableTransferCAP(h, acc_fn, EMQ(LogisticRegression()))
    #yield 'CT-PPSh-ACC', ContTableWithHTransferCAP(h, acc_fn, ACC)
    yield 'Equations-ACCh', NsquaredEquationsCAP(h, acc_fn, ACC, reuse_h=True)
    yield 'QuAcc(EMQ)nxn', QuAccNxN(h, acc_fn, EMQ(LogisticRegression()))
    #yield 'QuAcc(EMQ)1xn2', QuAcc1xN2(h, acc_fn, EMQ(LogisticRegression()))
    yield 'QuAcc(EMQ)1xn2', QuAcc1xN2(h, acc_fn, EMQ(LogisticRegression()))
    #yield 'CT-PPSh-EMQ', ContTableTransferCAP(h, acc_fn, EMQ(LogisticRegression()), reuse_h=True)
    #yield 'Equations-ACCh', NsquaredEquationsCAP(h, acc_fn, ACC, reuse_h=True)
    # yield 'Equations-ACC', NsquaredEquationsCAP(h, acc_fn, ACC)
    yield 'Equations-SLD', NsquaredEquationsCAP(h, acc_fn, EMQ)
    #yield 'Equations-SLD', NsquaredEquationsCAP(h, acc_fn, EMQ)


def get_method_names():
    mock_h = LogisticRegression()
    return [m for m, _ in gen_CAP(mock_h, None)] + [m for m, _ in gen_CAP_cont_table(mock_h)]


def gen_acc_measure():
    yield 'vanilla_accuracy', vanilla_acc_fn
    yield 'macro-F1', macrof1
    #yield 'macro-F1', macrof1


def split(data: LabelledCollection):
    train_val, test = data.split_stratified(train_prop=0.66, random_state=0)
    train, val = train_val.split_stratified(train_prop=0.5, random_state=0)
    return train, val, test


def fit_method(method, V):
    tinit = time()
    method.fit(V)
    t_train = time() - tinit
    return method, t_train


def predictionsCAP(method, test_prot):
    tinit = time()
    estim_accs = [method.predict(Ui.X) for Ui in test_prot()]
    t_test_ave = (time() - tinit) / test_prot.total()
    return estim_accs, t_test_ave


def predictionsCAPcont_table(method, test_prot, gen_acc_measure):
    estim_accs_dict = {}
    tinit = time()
    estim_tables = [method.predict_ct(Ui.X) for Ui in test_prot()]
    for acc_name, acc_fn in gen_acc_measure():
        estim_accs_dict[acc_name] = [acc_fn(cont_table) for cont_table in estim_tables]
    t_test_ave = (time() - tinit) / test_prot.total()
    return estim_accs_dict, t_test_ave


def any_missing(basedir, cls_name, dataset_name, method_name):
    for acc_name, _ in gen_acc_measure():
        if not os.path.exists(getpath(basedir, cls_name, acc_name, dataset_name, method_name)):
            return True
    return False


def true_acc(h:BaseEstimator, acc_fn: callable, U: LabelledCollection):
@@ -115,4 +181,159 @@ def cap_errors(true_acc, estim_acc):
    true_acc = np.asarray(true_acc)
    estim_acc = np.asarray(estim_acc)
    #return (true_acc - estim_acc)**2
    return np.abs(true_acc - estim_acc)


def plot_diagonal(cls_name, measure_name, results, base_dir='plots'):

    makedirs(base_dir, exist_ok=True)
    makedirs(join(base_dir, measure_name), exist_ok=True)

    # Create scatter plot
    plt.figure(figsize=(10, 10))
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.plot([0, 1], [0, 1], color='black', linestyle='--')

    for method_name in results.keys():
        xs = results[method_name]['true_acc']
        ys = results[method_name]['estim_acc']
        err = cap_errors(xs, ys).mean()
        #pear_cor, _ = 0, 0 #pearsonr(xs, ys)
        plt.scatter(xs, ys, label=f'{method_name} {err:.3f}', alpha=0.6)

    plt.legend()

    # Add labels and title
    plt.xlabel(f'True {measure_name}')
    plt.ylabel(f'Estimated {measure_name}')

    # Display the plot
    # plt.show()
    plt.savefig(join(base_dir, measure_name, 'diagonal_'+cls_name+'.png'))


def getpath(basedir, cls_name, acc_name, dataset_name, method_name):
    return f"results/{basedir}/{cls_name}/{acc_name}/{dataset_name}/{method_name}.json"


def open_results(basedir, cls_name, acc_name, dataset_name='*', method_name='*'):
    results = defaultdict(lambda : {'true_acc':[], 'estim_acc':[]})
    if isinstance(method_name, str):
        method_name = [method_name]
    if isinstance(dataset_name, str):
        dataset_name = [dataset_name]
    for dataset_, method_ in itertools.product(dataset_name, method_name):
        path = getpath(basedir, cls_name, acc_name, dataset_, method_)
        for file in glob(path):
            #print(file)
            method = Path(file).name.replace('.json','')
            result = json.load(open(file, 'r'))
            results[method]['true_acc'].extend(result['true_acc'])
            results[method]['estim_acc'].extend(result['estim_acc'])
    return results


def save_json_file(path, data):
    os.makedirs(Path(path).parent, exist_ok=True)
    with open(path, 'w') as f:
        json.dump(data, f)


def save_json_result(path, true_accs, estim_accs, t_train, t_test):
    result = {
        't_train': t_train,
        't_test_ave': t_test,
        'true_acc': true_accs,
        'estim_acc': estim_accs
    }
    save_json_file(path, result)


def get_dataset_stats(path, test_prot, L, V):
    test_prevs = [Ui.prevalence() for Ui in test_prot()]
    shifts = [qp.error.ae(L.prevalence(), Ui_prev) for Ui_prev in test_prevs]
    info = {
        'n_classes': L.n_classes,
        'n_train': len(L),
        'n_val': len(V),
        'train_prev': L.prevalence().tolist(),
        'val_prev': V.prevalence().tolist(),
        'test_prevs': [x.tolist() for x in test_prevs],
        'shifts': [x.tolist() for x in shifts],
        'sample_size': test_prot.sample_size,
        'num_samples': test_prot.total()
    }
    save_json_file(path, info)


def gen_tables(basedir, datasets):
    from tabular import Table

    mock_h = LogisticRegression()
    methods = [method for method, _ in gen_CAP(mock_h, None)] + [method for method, _ in gen_CAP_cont_table(mock_h)]
    classifiers = [classifier for classifier, _ in gen_classifiers()]
    measures = [measure for measure, _ in gen_acc_measure()]

    os.makedirs('tables', exist_ok=True)

    tex_doc = """
\\documentclass[10pt,a4paper]{article}
\\usepackage[utf8]{inputenc}
\\usepackage{amsmath}
\\usepackage{amsfonts}
\\usepackage{amssymb}
\\usepackage{graphicx}
\\usepackage{tabularx}
\\usepackage{color}
\\usepackage{colortbl}
\\usepackage{xcolor}
\\begin{document}
"""

    classifier = classifiers[0]
    metric = "vanilla_accuracy"

    table = Table(datasets, methods)
    for method, dataset in itertools.product(methods, datasets):
        path = getpath(basedir, classifier, metric, dataset, method)
        if not os.path.exists(path):
            print('missing ', path)
            continue
        results = json.load(open(path, 'r'))
        true_acc = results['true_acc']
        estim_acc = np.asarray(results['estim_acc'])
        if any(np.isnan(estim_acc)):
            print(f'nan values found in {method=} {dataset=}')
            continue
        if any(estim_acc>1.00001):
            print(f'values >1 found in {method=} {dataset=} [max={estim_acc.max()}]')
            continue
        if any(estim_acc<-0.00001):
            print(f'values <0 found in {method=} {dataset=} [min={estim_acc.min()}]')
            continue
        errors = cap_errors(true_acc, estim_acc)
        table.add(dataset, method, errors)

    tex = table.latexTabular()
    table_name = f'{basedir}_{classifier}_{metric}.tex'
    with open(f'./tables/{table_name}', 'wt') as foo:
        foo.write('\\resizebox{\\textwidth}{!}{%\n')
        foo.write('\\begin{tabular}{c|'+('c'*len(methods))+'}\n')
        foo.write(tex)
        foo.write('\\end{tabular}%\n')
        foo.write('}\n')

    tex_doc += "\input{" + table_name + "}\n"

    tex_doc += """
\\end{document}
"""
    with open(f'./tables/main.tex', 'wt') as foo:
        foo.write(tex_doc)

print("[Tables Done] runing latex")
|
||||
    os.chdir('./tables/')
    os.system('pdflatex main.tex')
    os.system('rm main.aux main.log')

@@ -1,46 +1,16 @@
import itertools
import os.path
from collections import defaultdict
from time import time
from utils import *
from models_multiclass import *
from quapy.protocol import UPP
from commons import *

PROBLEM = 'multiclass'
basedir = PROBLEM

def fit_method(method, V):
    tinit = time()
    method.fit(V)
    t_train = time() - tinit
    return method, t_train


def predictionsCAP(method, test_prot):
    tinit = time()
    estim_accs = [method.predict(Ui.X) for Ui in test_prot()]
    t_test_ave = (time() - tinit) / test_prot.total()
    return estim_accs, t_test_ave


def predictionsCAPcont_table(method, test_prot, gen_acc_measure):
    estim_accs_dict = {}
    tinit = time()
    estim_tables = [method.predict_ct(Ui.X) for Ui in test_prot()]
    for acc_name, acc_fn in gen_acc_measure():
        estim_accs_dict[acc_name] = [acc_fn(cont_table) for cont_table in estim_tables]
    t_test_ave = (time() - tinit) / test_prot.total()
    return estim_accs_dict, t_test_ave


def any_missing(cls_name, dataset_name, method_name):
    for acc_name, _ in gen_acc_measure():
        if not os.path.exists(getpath(cls_name, acc_name, dataset_name, method_name)):
            return True
    return False


qp.environ['SAMPLE_SIZE'] = 250
NUM_TEST = 100
if PROBLEM == 'binary':
    qp.environ['SAMPLE_SIZE'] = 1000
    NUM_TEST = 1000
    gen_datasets = gen_bin_datasets
elif PROBLEM == 'multiclass':
    qp.environ['SAMPLE_SIZE'] = 250
    NUM_TEST = 100
    gen_datasets = gen_multi_datasets


for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(gen_classifiers(), gen_datasets()):
@@ -62,7 +32,7 @@ for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(gen_classifier
    # must be nested in the acc-for
    for acc_name, acc_fn in gen_acc_measure():
        for (method_name, method) in gen_CAP(h, acc_fn):
            result_path = getpath(cls_name, acc_name, dataset_name, method_name)
            result_path = getpath(basedir, cls_name, acc_name, dataset_name, method_name)
            if os.path.exists(result_path):
                print(f'\t{method_name}-{acc_name} exists, skipping')
                continue
@@ -75,7 +45,7 @@ for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(gen_classifier
    # instances of CAPContingencyTable instead are generic, and the evaluation measure can
    # be nested to the predictions to speed up things
    for (method_name, method) in gen_CAP_cont_table(h):
        if not any_missing(cls_name, dataset_name, method_name):
        if not any_missing(basedir, cls_name, dataset_name, method_name):
            print(f'\tmethod {method_name} has all results already computed. Skipping.')
            continue
@@ -84,14 +54,22 @@ for (cls_name, h), (dataset_name, (L, V, U)) in itertools.product(gen_classifier
        method, t_train = fit_method(method, V)
        estim_accs_dict, t_test_ave = predictionsCAPcont_table(method, test_prot, gen_acc_measure)
        for acc_name in estim_accs_dict.keys():
            result_path = getpath(cls_name, acc_name, dataset_name, method_name)
            result_path = getpath(basedir, cls_name, acc_name, dataset_name, method_name)
            save_json_result(result_path, true_accs[acc_name], estim_accs_dict[acc_name], t_train, t_test_ave)

    print()

# generate diagonal plots
print('generating plots')
for (cls_name, _), (acc_name, _) in itertools.product(gen_classifiers(), gen_acc_measure()):
    results = open_results(cls_name, acc_name)
    plot_diagonal(cls_name, acc_name, results)
    methods = get_method_names()
    results = open_results(basedir, cls_name, acc_name, method_name=methods)
    plot_diagonal(cls_name, acc_name, results, base_dir=f'plots/{basedir}/all')
    for dataset_name, _ in gen_datasets(only_names=True):
        results = open_results(basedir, cls_name, acc_name, dataset_name=dataset_name, method_name=methods)
        plot_diagonal(cls_name, acc_name, results, base_dir=f'plots/{basedir}/{dataset_name}')

print('generating tables')
gen_tables(basedir, datasets=[d for d,_ in gen_datasets(only_names=True)])

@@ -1,3 +1,3 @@
from utils import gen_tables
from commons import gen_tables

gen_tables()
@@ -1,3 +1,5 @@
from copy import deepcopy

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.linear_model import LogisticRegression
@@ -14,7 +16,7 @@ from sklearn.model_selection import cross_val_predict

from quapy.protocol import UPP
from quapy.method.base import BaseQuantifier
from quapy.method.aggregative import PACC
from quapy.method.aggregative import PACC, AggregativeQuantifier
import quapy.functional as F

@@ -102,20 +104,38 @@ class NaiveCAP(CAPContingencyTable):
        return self.cont_table


class ContTableTransferCAP(CAPContingencyTable):
class CAPContingencyTableQ(CAPContingencyTable):

    def __init__(self, h: BaseEstimator, acc: callable, q_class: AggregativeQuantifier, reuse_h=False):
        super().__init__(h, acc)
        self.reuse_h = reuse_h
        if reuse_h:
            assert isinstance(q_class, AggregativeQuantifier), f'quantifier {q_class} is not of type aggregative'
            self.q = deepcopy(q_class)
            self.q.set_params(classifier=h)
        else:
            self.q = q_class

    def quantifier_fit(self, val: LabelledCollection):
        if self.reuse_h:
            self.q.fit(val, fit_classifier=False, val_split=val)
        else:
            self.q.fit(val)


class ContTableTransferCAP(CAPContingencyTableQ):
    """

    """
    def __init__(self, h: BaseEstimator, acc: callable, q: BaseQuantifier):
        super().__init__(h, acc)
        self.q = q
    def __init__(self, h: BaseEstimator, acc: callable, q_class, reuse_h=False):
        super().__init__(h, acc, q_class, reuse_h)

    def fit(self, val: LabelledCollection):
        y_hat = self.h.predict(val.X)
        y_true = val.y
        self.cont_table = confusion_matrix(y_true, y_pred=y_hat, labels=val.classes_)
        self.train_prev = val.prevalence()
        self.q.fit(val)
        self.quantifier_fit(val)
        return self

    def predict_ct(self, test):
@@ -128,52 +148,18 @@ class ContTableTransferCAP(CAPContingencyTable):
        return self.cont_table * adjustment[:, np.newaxis]


class ContTableWithHTransferCAP(CAPContingencyTable):
    """

    """
    def __init__(self, h: BaseEstimator, acc: callable, q_class):
        super().__init__(h, acc)
        self.q = q_class(classifier=h)

    def fit(self, val: LabelledCollection):
        y_hat = self.h.predict(val.X)
        y_true = val.y
        self.cont_table = confusion_matrix(y_true, y_pred=y_hat, labels=val.classes_)
        self.train_prev = val.prevalence()
        self.q.fit(val, fit_classifier=False, val_split=val)
        return self

    def predict_ct(self, test):
        """
        :param test: test collection (ignored)
        :return: a confusion matrix in the return format of `sklearn.metrics.confusion_matrix`
        """
        test_prev_estim = self.q.quantify(test)
        adjustment = test_prev_estim / self.train_prev
        return self.cont_table * adjustment[:, np.newaxis]


class NsquaredEquationsCAP(CAPContingencyTable):
class NsquaredEquationsCAP(CAPContingencyTableQ):
    """

    """
    def __init__(self, h: BaseEstimator, acc: callable, q_class, reuse_h=False):
        super().__init__(h, acc)
        self.reuse_h = reuse_h
        if reuse_h:
            self.q = q_class(classifier=h)
        else:
            self.q = q_class(classifier=LogisticRegression())
        super().__init__(h, acc, q_class, reuse_h)

    def fit(self, val: LabelledCollection):
        y_hat = self.h.predict(val.X)
        y_true = val.y
        self.cont_table = confusion_matrix(y_true, y_pred=y_hat, labels=val.classes_)
        if self.reuse_h:
            self.q.fit(val, fit_classifier=False, val_split=val)
        else:
            self.q.fit(val)
        self.quantifier_fit(val)
        self.A, self.partial_b = self._construct_equations()
        return self
@@ -247,8 +233,22 @@ class NsquaredEquationsCAP(CAPContingencyTable):
        b[-2*(n-1):-(n-1)] = cc_prev_estim[1:]
        b[-(n-1):] = q_prev_estim[1:]

        # try the fast solution (may not be valid)
        x = np.linalg.solve(A, b)

if any(x<0) or any(x>0) or not np.isclose(x.sum(), 1):
|
||||

            print('L', end='')

            # try the iterative solution
            def loss(x):
                return np.linalg.norm(A @ x - b, ord=2)

            x = F.optim_minimize(loss, n_classes=n**2)

        else:
            print('.', end='')

        cont_table_test = x.reshape(n,n)
        return cont_table_test
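
The check above accepts the direct solve only when x is a valid probability distribution; otherwise the system is re-solved as a least-squares problem constrained to the simplex. A minimal standalone sketch of this same solve-then-fallback pattern, using scipy.optimize directly in place of quapy's F.optim_minimize (the toy system and its values are illustrative, not taken from the experiments):

# Sketch of the solve-then-fallback pattern, with scipy in place of F.optim_minimize.
import numpy as np
from scipy.optimize import minimize

A = np.array([[1.0, 1.0], [0.3, 0.8]])    # toy constraint system
b = np.array([1.0, 0.5])

x = np.linalg.solve(A, b)                 # fast exact solution
if any(x < 0) or any(x > 1) or not np.isclose(x.sum(), 1):
    # fall back to constrained least squares on the probability simplex
    n = len(b)
    x0 = np.full(n, 1.0 / n)              # uniform starting point
    sol = minimize(lambda z: np.linalg.norm(A @ z - b), x0,
                   bounds=[(0, 1)] * n,
                   constraints=({'type': 'eq', 'fun': lambda z: z.sum() - 1},))
    x = sol.x
print(x)                                  # here [0.6 0.4], already valid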
@@ -334,3 +334,118 @@ class PabloCAP(ClassifierAccuracyPrediction):
            raise ValueError('unknown aggregation function')


class QuAcc:
    def _get_X_dot(self, X):
        h = self.h
        if hasattr(h, 'predict_proba'):
            P = h.predict_proba(X)[:, 1:]
        else:
            n_classes = len(h.classes_)
            P = h.decision_function(X).reshape(-1, n_classes)

        X_dot = safehstack(X, P)
        return X_dot


class QuAcc1xN2(CAPContingencyTableQ, QuAcc):

    def __init__(self, h: BaseEstimator, acc: callable, q_class: AggregativeQuantifier):
        self.h = h
        self.acc = acc
        self.q = EmptySaveQuantifier(q_class)

    def fit(self, val: LabelledCollection):
        pred_labels = self.h.predict(val.X)
        true_labels = val.y

        n = val.n_classes
        classes_dot = np.arange(n**2)
        ct_class_idx = classes_dot.reshape(n, n)

        X_dot = self._get_X_dot(val.X)
        y_dot = ct_class_idx[true_labels, pred_labels]
        val_dot = LabelledCollection(X_dot, y_dot, classes=classes_dot)
        self.q.fit(val_dot)

    def predict_ct(self, X):
        X_dot = self._get_X_dot(X)
        return self.q.quantify(X_dot)
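
To make the 1xN2 encoding concrete: every (true, predicted) pair is mapped to one of n**2 single labels, so quantifying over the augmented instances directly estimates the n x n contingency table. A small self-contained check with made-up labels:

# Mini-check of the (true, pred) -> single-label encoding used by QuAcc1xN2.
import numpy as np

n = 2
ct_class_idx = np.arange(n**2).reshape(n, n)   # [[0, 1], [2, 3]]
true_labels = np.array([0, 1, 1, 0])
pred_labels = np.array([0, 1, 0, 0])
y_dot = ct_class_idx[true_labels, pred_labels]
print(y_dot)   # [0 3 2 0]; e.g. (true=1, pred=0) maps to label 2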


class QuAccNxN(CAPContingencyTableQ, QuAcc):

    def __init__(self, h: BaseEstimator, acc: callable, q_class: AggregativeQuantifier):
        self.h = h
        self.acc = acc
        self.q_class = q_class

    def fit(self, val: LabelledCollection):
        pred_labels = self.h.predict(val.X)
        true_labels = val.y
        X_dot = self._get_X_dot(val.X)

        self.q = []
        for class_i in self.h.classes_:
            X_dot_i = X_dot[pred_labels==class_i]
            y_i = true_labels[pred_labels==class_i]
            data_i = LabelledCollection(X_dot_i, y_i, classes=val.classes_)

            q_i = EmptySaveQuantifier(deepcopy(self.q_class))
            q_i.fit(data_i)
            self.q.append(q_i)

    def predict_ct(self, X):
        classes = self.h.classes_
        pred_labels = self.h.predict(X)
        X_dot = self._get_X_dot(X)
        pred_prev = F.prevalence_from_labels(pred_labels, classes)
        cont_table = []
        for class_i, q_i, p_i in zip(classes, self.q, pred_prev):
            X_dot_i = X_dot[pred_labels==class_i]
            classcond_cond_table_prevs = q_i.quantify(X_dot_i)
            cond_table_prevs = p_i * classcond_cond_table_prevs
            cont_table.append(cond_table_prevs)
        cont_table = np.vstack(cont_table)
        return cont_table
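
In this NxN variant the table is built row by row as cont_table[i, j] ~ P(pred=i) * P(true=j | pred=i): the i-th quantifier estimates the distribution of true labels among the instances the classifier assigns to class i, and the row is rescaled by the observed proportion of such predictions.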


def safehstack(X, P):
    if issparse(X) or issparse(P):
        XP = scipy.sparse.hstack([X, P])
        XP = csr_matrix(XP)
    else:
        XP = np.hstack([X,P])
    return XP


class EmptySaveQuantifier(BaseQuantifier):
    def __init__(self, surrogate_quantifier: BaseQuantifier):
        self.surrogate = surrogate_quantifier

    def fit(self, data: LabelledCollection):
        self.n_classes = data.n_classes
        class_compact_data, self.old_class_idx = data.compact_classes()
        if self.num_non_empty_classes() > 1:
            self.surrogate.fit(class_compact_data)
        return self

    def quantify(self, instances):
        num_instances = instances.shape[0]
        if self.num_non_empty_classes() == 0 or num_instances==0:
            # returns the uniform prevalence vector
            uniform = np.full(fill_value=1./self.n_classes, shape=self.n_classes, dtype=float)
            return uniform
        elif self.num_non_empty_classes() == 1:
            # returns a prevalence vector with 100% of the mass in the only non empty class
            prev_vector = np.full(fill_value=0., shape=self.n_classes, dtype=float)
            prev_vector[self.old_class_idx[0]] = 1
            return prev_vector
        else:
            class_compact_prev = self.surrogate.quantify(instances)
            prev_vector = np.full(fill_value=0., shape=self.n_classes, dtype=float)
            prev_vector[self.old_class_idx] = class_compact_prev
            return prev_vector

    def num_non_empty_classes(self):
        return len(self.old_class_idx)
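
A hedged usage sketch of the wrapper above, assuming the compact_classes support added to LabelledCollection in this commit (the data values are made up):

# Wrapping a quantifier so it tolerates classes with no training instances.
import numpy as np
from sklearn.linear_model import LogisticRegression
from quapy.data import LabelledCollection
from quapy.method.aggregative import PACC

X = np.random.rand(100, 5)
y = np.random.choice([0, 2], size=100)              # class 1 never occurs
data = LabelledCollection(X, y, classes=[0, 1, 2])  # ...so it is an empty class

q = EmptySaveQuantifier(PACC(LogisticRegression()))
q.fit(data)             # the surrogate is fit on the compacted classes {0, 2}
print(q.quantify(X))    # class 1 is assigned prevalence 0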

@@ -1,164 +0,0 @@
import itertools
import os
from collections import defaultdict

import matplotlib.pyplot as plt
from pathlib import Path
from os import makedirs
from os.path import join
import numpy as np
import json
from scipy.stats import pearsonr
from sklearn.linear_model import LogisticRegression
from time import time
import quapy as qp
from glob import glob

from commons import cap_errors
from models_multiclass import ClassifierAccuracyPrediction, CAPContingencyTable


def plot_diagonal(cls_name, measure_name, results, base_dir='plots'):

    makedirs(base_dir, exist_ok=True)
    makedirs(join(base_dir, measure_name), exist_ok=True)

    # Create scatter plot
    plt.figure(figsize=(10, 10))
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.plot([0, 1], [0, 1], color='black', linestyle='--')

    for method_name in results.keys():
        print(method_name, measure_name)
        xs = results[method_name]['true_acc']
        ys = results[method_name]['estim_acc']
        print('max xs', np.max(xs))
        print('max ys', np.max(ys))
        err = cap_errors(xs, ys).mean()
        #pear_cor, _ = 0, 0 #pearsonr(xs, ys)
        plt.scatter(xs, ys, label=f'{method_name} {err:.3f}', alpha=0.6)

    plt.legend()

    # Add labels and title
    plt.xlabel(f'True {measure_name}')
    plt.ylabel(f'Estimated {measure_name}')

    # Display the plot
    # plt.show()
    plt.savefig(join(base_dir, measure_name, 'diagonal_'+cls_name+'.png'))


def getpath(cls_name, acc_name, dataset_name, method_name):
    return f"results/{cls_name}/{acc_name}/{dataset_name}/{method_name}.json"


def open_results(cls_name, acc_name, dataset_name='*', method_name='*'):
    path = getpath(cls_name, acc_name, dataset_name, method_name)
    results = defaultdict(lambda : {'true_acc':[], 'estim_acc':[]})
    for file in glob(path):
        #print(file)
        method = Path(file).name.replace('.json','')
        result = json.load(open(file, 'r'))
        results[method]['true_acc'].extend(result['true_acc'])
        results[method]['estim_acc'].extend(result['estim_acc'])
    return results


def save_json_file(path, data):
    os.makedirs(Path(path).parent, exist_ok=True)
    with open(path, 'w') as f:
        json.dump(data, f)


def save_json_result(path, true_accs, estim_accs, t_train, t_test):
    result = {
        't_train': t_train,
        't_test_ave': t_test,
        'true_acc': true_accs,
        'estim_acc': estim_accs
    }
    save_json_file(path, result)


def get_dataset_stats(path, test_prot, L, V):
    test_prevs = [Ui.prevalence() for Ui in test_prot()]
    shifts = [qp.error.ae(L.prevalence(), Ui_prev) for Ui_prev in test_prevs]
    info = {
        'n_classes': L.n_classes,
        'n_train': len(L),
        'n_val': len(V),
        'train_prev': L.prevalence().tolist(),
        'val_prev': V.prevalence().tolist(),
        'test_prevs': [x.tolist() for x in test_prevs],
        'shifts': [x.tolist() for x in shifts],
        'sample_size': test_prot.sample_size,
        'num_samples': test_prot.total()
    }
    save_json_file(path, info)


def gen_tables():
    from commons import gen_datasets, gen_classifiers, gen_acc_measure, gen_CAP, gen_CAP_cont_table
    from tabular import Table

    mock_h = LogisticRegression(),
    methods = [method for method, _ in gen_CAP(mock_h, None)] + [method for method, _ in gen_CAP_cont_table(mock_h)]
    datasets = [dataset for dataset, _ in gen_datasets()]
    classifiers = [classifier for classifier, _ in gen_classifiers()]
    measures = [measure for measure, _ in gen_acc_measure()]

    os.makedirs('tables', exist_ok=True)

    tex_doc = """
\\documentclass[10pt,a4paper]{article}
\\usepackage[utf8]{inputenc}
\\usepackage{amsmath}
\\usepackage{amsfonts}
\\usepackage{amssymb}
\\usepackage{graphicx}
\\usepackage{tabularx}
\\usepackage{color}
\\usepackage{colortbl}
\\usepackage{xcolor}
\\begin{document}
"""

    classifier = classifiers[0]
    metric = "vanilla_accuracy"

    table = Table(datasets, methods)
    for method, dataset in itertools.product(methods, datasets):
        path = f'results/{classifier}/{metric}/{dataset}/{method}.json'
        results = json.load(open(path, 'r'))
        true_acc = results['true_acc']
        estim_acc = np.asarray(results['estim_acc'])
        if any(np.isnan(estim_acc)) or any(estim_acc>1) or any(estim_acc<0):
            print(f'error in {method=} {dataset=}')
            continue
        errors = cap_errors(true_acc, estim_acc)
        table.add(dataset, method, errors)

    tex = table.latexTabular()
    table_name = f'{classifier}_{metric}.tex'
    with open(f'./tables/{table_name}', 'wt') as foo:
        foo.write('\\resizebox{\\textwidth}{!}{%\n')
        foo.write('\\begin{tabular}{c|'+('c'*len(methods))+'}\n')
        foo.write(tex)
        foo.write('\\end{tabular}%\n')
        foo.write('}\n')

    tex_doc += "\input{" + table_name + "}\n"

    tex_doc += """
\\end{document}
"""
    with open(f'./tables/main.tex', 'wt') as foo:
        foo.write(tex_doc)

    print("[Tables Done] runing latex")
    os.chdir('./tables/')
    os.system('pdflatex main.tex')
    os.system('rm main.aux main.bbl main.blg main.log main.out main.dvi')

@@ -232,11 +232,24 @@ class LabelledCollection:
        :return: two instances of :class:`LabelledCollection`, the first one with `train_prop` elements, and the
            second one with `1-train_prop` elements
        """
        instances = self.instances
        labels = self.labels
        remainder = None
        for idx in np.argwhere(self.counts()==1):
            class_with_1 = self.classes_[idx.item()]
            if remainder is None:
                remainder = LabelledCollection(instances[labels==class_with_1], [class_with_1], classes=self.classes_)
            else:
                remainder += LabelledCollection(instances[labels==class_with_1], [class_with_1], classes=self.classes_)
            instances = instances[labels!=class_with_1]
            labels = labels[labels!=class_with_1]
        tr_docs, te_docs, tr_labels, te_labels = train_test_split(
            self.instances, self.labels, train_size=train_prop, stratify=self.labels, random_state=random_state
            instances, labels, train_size=train_prop, stratify=labels, random_state=random_state
        )
        training = LabelledCollection(tr_docs, tr_labels, classes=self.classes_)
        test = LabelledCollection(te_docs, te_labels, classes=self.classes_)
        if remainder is not None:
            training += remainder
        return training, test
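
The new loop sets aside classes with exactly one instance before calling train_test_split, since sklearn's stratified split cannot place a single instance in both splits, and adds them back to the training part afterwards, so no class silently disappears from training.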

    def split_random(self, train_prop=0.6, random_state=None):
@@ -414,6 +427,47 @@ class LabelledCollection:
        test = self.sampling_from_index(test_index)
        yield train, test

    def empty_classes(self):
        """
        Returns a np.ndarray of empty classes (classes present in self.classes_ but with
        no positive instance). In case there is none, then an empty np.ndarray is returned

        :return: np.ndarray
        """
        idx = np.argwhere(self.counts()==0).flatten()
        return self.classes_[idx]

    def non_empty_classes(self):
        """
        Returns a np.ndarray of non-empty classes (classes present in self.classes_ but with
        at least one positive instance). In case there is none, then an empty np.ndarray is returned

        :return: np.ndarray
        """
        idx = np.argwhere(self.counts() > 0).flatten()
        return self.classes_[idx]

    def has_empty_classes(self):
        """
        Checks whether the collection has empty classes

        :return: boolean
        """
        return len(self.empty_classes()) > 0

    def compact_classes(self):
        """
        Generates a new LabelledCollection object with no empty classes. It also returns a np.ndarray with
        the positions that the retained classes occupied in the original self.classes_.

        :return: (LabelledCollection, np.ndarray,)
        """
        non_empty = self.non_empty_classes()
        all_classes = self.classes_
        old_pos = np.searchsorted(all_classes, non_empty)
        non_empty_collection = LabelledCollection(*self.Xy, classes=non_empty)
        return non_empty_collection, old_pos
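
A hedged example of the new helpers (toy values, for illustration only):

# Demonstrating the empty-class helpers added in this commit.
import numpy as np
from quapy.data import LabelledCollection

data = LabelledCollection(np.random.rand(4, 2), [0, 0, 2, 2], classes=[0, 1, 2])
print(data.empty_classes())       # [1]
print(data.has_empty_classes())   # True
compact, old_pos = data.compact_classes()
print(compact.classes_)           # [0 2]
print(old_pos)                    # [0 2]: positions of the kept classes in the original classes_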


class Dataset:
    """