full example of training, model selection, and evaluation using the lequa2022 dataset with the new protocols
parent ecd0ad7ec7
commit f2550fdb82
@@ -0,0 +1,26 @@
+import numpy as np
+from sklearn.linear_model import LogisticRegression
+
+import quapy as qp
+from quapy.data.datasets import LEQUA2022_SAMPLE_SIZE, fetch_lequa2022
+from quapy.evaluation import evaluation_report
+from quapy.method.aggregative import EMQ
+from quapy.model_selection import GridSearchQ
+
+task = 'T1A'
+
+qp.environ['SAMPLE_SIZE'] = LEQUA2022_SAMPLE_SIZE[task]
+training, val_generator, test_generator = fetch_lequa2022(task=task)
+
+# define the quantifier
+quantifier = EMQ(learner=LogisticRegression())
+
+# model selection
+param_grid = {'C': np.logspace(-3, 3, 7), 'class_weight': ['balanced', None]}
+model_selection = GridSearchQ(quantifier, param_grid, protocol=val_generator, n_jobs=-1, refit=False, verbose=True)
+quantifier = model_selection.fit(training)
+
+# evaluation
+report = evaluation_report(quantifier, protocol=test_generator, error_metrics=['mae', 'mrae'], verbose=True)
+
+print(report)
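For context, evaluation_report runs the fitted quantifier over every sample produced by the test protocol and scores it with the requested error metrics. A minimal sketch of how the result might be summarized, assuming (as the print(report) above suggests) that it is a pandas DataFrame with one row per test sample and one column per error metric:

import pandas as pd

def summarize(report: pd.DataFrame, metrics=('mae', 'mrae')) -> pd.Series:
    # average each per-sample error column into a single score per metric
    return report[list(metrics)].mean()

# e.g. summarize(report) -> mean AE and mean RAE over all test samples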
@@ -12,6 +12,7 @@ from quapy.data.preprocessing import text2tfidf, reduce_columns
 from quapy.data.reader import *
 from quapy.util import download_file_if_not_exists, download_file, get_quapy_home, pickled_resource

 REVIEWS_SENTIMENT_DATASETS = ['hp', 'kindle', 'imdb']
 TWITTER_SENTIMENT_DATASETS_TEST = ['gasp', 'hcr', 'omd', 'sanders',
                                    'semeval13', 'semeval14', 'semeval15', 'semeval16',
@@ -45,6 +46,20 @@ UCI_DATASETS = ['acute.a', 'acute.b',

 LEQUA2022_TASKS = ['T1A', 'T1B', 'T2A', 'T2B']
+
+_TXA_SAMPLE_SIZE = 250
+_TXB_SAMPLE_SIZE = 1000
+
+LEQUA2022_SAMPLE_SIZE = {
+    'TXA': _TXA_SAMPLE_SIZE,
+    'TXB': _TXB_SAMPLE_SIZE,
+    'T1A': _TXA_SAMPLE_SIZE,
+    'T1B': _TXB_SAMPLE_SIZE,
+    'T2A': _TXA_SAMPLE_SIZE,
+    'T2B': _TXB_SAMPLE_SIZE,
+    'binary': _TXA_SAMPLE_SIZE,
+    'multiclass': _TXB_SAMPLE_SIZE
+}

 def fetch_reviews(dataset_name, tfidf=False, min_df=None, data_home=None, pickle=False) -> Dataset:
     """
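The new mapping ties each LeQua task to the sample size its protocols draw: 250 instances for the binary tasks (T1A, T2A) and 1000 for the multiclass ones (T1B, T2B). A small usage sketch, mirroring the example script above:

import quapy as qp
from quapy.data.datasets import LEQUA2022_SAMPLE_SIZE

# binary tasks use samples of 250 instances; multiclass tasks use 1000
assert LEQUA2022_SAMPLE_SIZE['T1A'] == 250
assert LEQUA2022_SAMPLE_SIZE['T1B'] == 1000

# the environment's sample size drives the sampling protocols
qp.environ['SAMPLE_SIZE'] = LEQUA2022_SAMPLE_SIZE['T1B']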
@@ -578,7 +593,7 @@ def fetch_lequa2022(task, data_home=None):
     val_true_prev_path = join(lequa_dir, task, 'public', 'dev_prevalences.txt')
     val_gen = SamplesFromDir(val_samples_path, val_true_prev_path, load_fn=load_fn)

-    test_samples_path = join(lequa_dir, task, 'public', 'dev_samples')
+    test_samples_path = join(lequa_dir, task, 'public', 'test_samples')
     test_true_prev_path = join(lequa_dir, task, 'public', 'test_prevalences.txt')
     test_gen = SamplesFromDir(test_samples_path, test_true_prev_path, load_fn=load_fn)
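The fix above points the test generator at the test_samples directory rather than the dev samples. For illustration, a hypothetical helper (not part of QuaPy) that checks the on-disk layout these paths assume, with directory and file names taken from the diff:

from os.path import join, isdir, isfile

def check_lequa_layout(lequa_dir, task):
    # verify the four resources fetch_lequa2022 reads under <task>/public
    public = join(lequa_dir, task, 'public')
    return (isdir(join(public, 'dev_samples')) and
            isfile(join(public, 'dev_prevalences.txt')) and
            isdir(join(public, 'test_samples')) and
            isfile(join(public, 'test_prevalences.txt')))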
@@ -11,11 +11,11 @@ def from_name(err_name):
     """
     assert err_name in ERROR_NAMES, f'unknown error {err_name}'
     callable_error = globals()[err_name]
-    if err_name in QUANTIFICATION_ERROR_SMOOTH_NAMES:
-        eps = __check_eps()
-        def bound_callable_error(y_true, y_pred):
-            return callable_error(y_true, y_pred, eps)
-        return bound_callable_error
+    # if err_name in QUANTIFICATION_ERROR_SMOOTH_NAMES:
+    #     eps = __check_eps()
+    #     def bound_callable_error(y_true, y_pred):
+    #         return callable_error(y_true, y_pred, eps)
+    #     return bound_callable_error
     return callable_error
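With the smoothing wrapper commented out, from_name now returns the raw error function, so callers that want smoothed errors presumably have to bind the smoothing parameter themselves. A minimal sketch, assuming the smooth metrics (e.g. mrae) keep an explicit eps argument and the conventional eps = 1/(2*sample_size):

from functools import partial
import quapy as qp

mrae = qp.error.from_name('mrae')
# assumption: smoothing factor tied to the environment's sample size
eps = 1. / (2. * qp.environ['SAMPLE_SIZE'])
smoothed_mrae = partial(mrae, eps=eps)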
@@ -41,7 +41,7 @@ def prediction(model: BaseQuantifier, protocol: AbstractProtocol, aggr_speedup=

 def __prediction_helper(quantification_fn, protocol: AbstractProtocol, verbose=False):
     true_prevs, estim_prevs = [], []
-    for sample_instances, sample_prev in tqdm(protocol(), total=protocol.total()) if verbose else protocol():
+    for sample_instances, sample_prev in tqdm(protocol(), total=protocol.total(), desc='predicting') if verbose else protocol():
         estim_prevs.append(quantification_fn(sample_instances))
         true_prevs.append(sample_prev)
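The loop header uses a compact idiom: wrap the protocol's generator in a tqdm progress bar only when verbose output is requested. A standalone sketch of the same pattern, with generic names (not QuaPy code):

from tqdm import tqdm

def iterate(samples, n_samples, verbose=False):
    # attach a labelled progress bar only in verbose mode
    loop = tqdm(samples, total=n_samples, desc='predicting') if verbose else samples
    for sample in loop:
        yield sample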