2023-09-28 17:47:39 +02:00
|
|
|
import numpy as np
|
|
|
|
import pandas as pd
|
2023-12-11 16:43:45 +01:00
|
|
|
from distribution_matching.method.kdex import KDExML
|
|
|
|
from distribution_matching.method.method_kdey import KDEy
|
|
|
|
from distribution_matching.method.method_kdey_closed_efficient_correct import KDEyclosed_efficient_corr
|
|
|
|
from distribution_matching.method.kdey import KDEyCS, KDEyHD, KDEyML
|
2023-09-28 17:47:39 +02:00
|
|
|
from quapy.method.aggregative import EMQ, CC, PCC, DistributionMatching, PACC, HDy, OneVsAllAggregative, ACC
|
2023-12-11 16:43:45 +01:00
|
|
|
from distribution_matching.method.dirichlety import DIRy
|
2023-09-28 17:47:39 +02:00
|
|
|
from sklearn.linear_model import LogisticRegression
|
2023-12-06 16:55:06 +01:00
|
|
|
|
2023-12-11 16:43:45 +01:00
|
|
|
# Choose which set of methods to run:
#   True  -> the full list of methods tested in the paper (reported in the appendix)
#   False -> the reduced list (shown in the body of the paper)
FULL_METHOD_LIST = False

# Each group below has a "full" and a "reduced" variant, selected by the flag above.
ADJUSTMENT_METHODS = ['ACC', 'PACC'] if FULL_METHOD_LIST else ['PACC']
DISTR_MATCH_METHODS = (['HDy-OvA'] if FULL_METHOD_LIST else []) + ['DM-T', 'DM-HD', 'KDEy-HD', 'DM-CS', 'KDEy-CS']
MAX_LIKE_METHODS = ['DIR', 'EMQ', 'EMQ-BCTS', 'KDEy-ML'] if FULL_METHOD_LIST else ['EMQ', 'KDEy-ML']

# list of methods to consider (adjustment, then distribution matching, then maximum likelihood)
METHODS = [*ADJUSTMENT_METHODS, *DISTR_MATCH_METHODS, *MAX_LIKE_METHODS]

# same names with any '-OvA' marker removed (presumably the names used for binary problems)
BIN_METHODS = [name.replace('-OvA', '') for name in METHODS]
|
2023-09-28 17:47:39 +02:00
|
|
|
|
2023-12-11 16:43:45 +01:00
|
|
|
# common hyperparameter grids, shared by several methods

# logistic-regression grid; keys use the nested 'classifier__' notation
# (presumably routed to the classifier wrapped by the quantifier)
hyper_LR = dict(
    classifier__C=np.logspace(-3, 3, 7),  # 1e-3, 1e-2, ..., 1e3
    classifier__class_weight=['balanced', None],
)

# kernel bandwidths: 20 evenly spaced values in [0.01, 0.2]
hyper_kde = dict(bandwidth=np.linspace(0.01, 0.2, 20))

# histogram sizes: every value in 2..10, the even values in 12..32, and 64
nbins_range = list(range(2, 11)) + list(range(12, 33, 2)) + [64]
|
|
|
|
|
2023-09-28 17:47:39 +02:00
|
|
|
|
2023-12-11 16:43:45 +01:00
|
|
|
# instantiates a new quantifier from its string name
def new_method(method, **lr_kwargs):
    """
    Instantiate a quantifier from its string name, together with its model-selection grid.

    :param method: name of the quantification method; one of 'CC', 'PCC', 'ACC',
        'PACC', 'KDEy-HD', 'KDEy-CS', 'KDEy-ML', 'KDEx-ML', 'DIR', 'EMQ',
        'EMQ-BCTS', 'HDy', 'HDy-OvA', 'DM-T', 'DM-HD', 'DM-CS'
    :param lr_kwargs: keyword arguments passed to the `LogisticRegression`
        constructor (the classifier is not used by 'KDEx-ML')
    :return: a tuple `(param_grid, quantifier)`, where `param_grid` maps each
        hyperparameter name to the candidate values to explore
    :raises NotImplementedError: if `method` is not one of the accepted names
    """
    # the classifier is created before the name is looked up, so invalid
    # lr_kwargs are reported even for unknown method names
    lr = LogisticRegression(**lr_kwargs)

    def lr_plus(method_params):
        # method-specific hyperparameters first, then the classifier ones
        return {**method_params, **hyper_LR}

    def dm_setup(divergence):
        # DistributionMatching grid: histogram sizes for one fixed divergence
        grid = lr_plus({
            'nbins': nbins_range,
            'val_split': [10],
            'divergence': [divergence],
        })
        return grid, DistributionMatching(lr)

    # each entry lazily builds (param_grid, quantifier), so only the requested
    # quantifier is ever constructed
    builders = {
        'CC': lambda: (hyper_LR, CC(lr)),
        'PCC': lambda: (hyper_LR, PCC(lr)),
        'ACC': lambda: (hyper_LR, ACC(lr)),
        'PACC': lambda: (hyper_LR, PACC(lr)),
        'KDEy-HD': lambda: (lr_plus(hyper_kde), KDEyHD(lr)),
        'KDEy-CS': lambda: (lr_plus(hyper_kde), KDEyCS(lr)),
        'KDEy-ML': lambda: (lr_plus(hyper_kde), KDEyML(lr)),
        'KDEx-ML': lambda: ({'bandwidth': np.linspace(0.001, 2, 501)}, KDExML()),
        'DIR': lambda: (hyper_LR, DIRy(lr)),
        'EMQ': lambda: (hyper_LR, EMQ(lr)),
        'EMQ-BCTS': lambda: (
            lr_plus({'exact_train_prev': [False], 'recalib': ['bcts']}),
            EMQ(lr),
        ),
        'HDy': lambda: (hyper_LR, HDy(lr)),
        # the one-vs-all wrapper exposes the classifier grid under 'binary_quantifier__'
        'HDy-OvA': lambda: (
            dict(('binary_quantifier__' + name, values) for name, values in hyper_LR.items()),
            OneVsAllAggregative(HDy(lr)),
        ),
        'DM-T': lambda: dm_setup('topsoe'),
        'DM-HD': lambda: dm_setup('HD'),
        'DM-CS': lambda: dm_setup('CS'),
    }

    # non-string names can never match any of the accepted names
    build = builders.get(method) if isinstance(method, str) else None
    if build is None:
        raise NotImplementedError('unknown method', method)
    return build()
|
|
|
|
|
|
|
|
|
|
|
|
def show_results(result_path):
    """
    Print (and return) a Dataset x Method table of the MAE and MRAE results.

    :param result_path: path of the results file *without* the '.csv' extension
        (a str or a path-like object). The file is read as tab-separated and must
        contain the columns 'Dataset', 'Method', 'MAE' and 'MRAE'.
    :return: the printed pivot table. Rows are datasets and columns are
        (measure, method) pairs. Repeated (Dataset, Method) entries are averaged,
        which is pandas' default `pivot_table` aggregation.
    """
    # str() also accepts pathlib.Path objects, which cannot be concatenated with a str
    df = pd.read_csv(str(result_path) + '.csv', sep='\t')
    pv = df.pivot_table(index='Dataset', columns='Method', values=['MAE', 'MRAE'])
    # Show the whole table, but only for this print. Calling pd.set_option here
    # would leak the unlimited-display settings into the rest of the process.
    with pd.option_context('display.max_columns', None, 'display.max_rows', None):
        print(pv)
    return pv
|
|
|
|
|