Full example of training, model selection, and evaluation on the LeQua 2022 dataset with the new protocols

Alejandro Moreo Fernandez 2022-11-04 15:04:36 +01:00
parent ecd0ad7ec7
commit f2550fdb82
4 changed files with 48 additions and 7 deletions

View File

@@ -0,0 +1,26 @@
import numpy as np
from sklearn.linear_model import LogisticRegression

import quapy as qp
from quapy.data.datasets import LEQUA2022_SAMPLE_SIZE, fetch_lequa2022
from quapy.evaluation import evaluation_report
from quapy.method.aggregative import EMQ
from quapy.model_selection import GridSearchQ

task = 'T1A'

# the LeQua 2022 protocols draw samples of a fixed, task-dependent size
qp.environ['SAMPLE_SIZE'] = LEQUA2022_SAMPLE_SIZE[task]

# one training set plus two sample generators: validation (for model selection) and test (for evaluation)
training, val_generator, test_generator = fetch_lequa2022(task=task)

# define the quantifier
quantifier = EMQ(learner=LogisticRegression())

# model selection: explore the hyperparameters of the underlying classifier against the validation protocol
param_grid = {'C': np.logspace(-3, 3, 7), 'class_weight': ['balanced', None]}
model_selection = GridSearchQ(quantifier, param_grid, protocol=val_generator, n_jobs=-1, refit=False, verbose=True)
quantifier = model_selection.fit(training)

# evaluation: one report row per test sample
report = evaluation_report(quantifier, protocol=test_generator, error_metrics=['mae', 'mrae'], verbose=True)
print(report)
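
For reference, evaluation_report returns a pandas DataFrame with one row per test sample; a quick way to summarize it (a sketch, assuming the error columns are named after the requested metrics):

# average each error metric across all test samples
print(report[['mae', 'mrae']].mean())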

View File

@@ -12,6 +12,7 @@ from quapy.data.preprocessing import text2tfidf, reduce_columns
from quapy.data.reader import *
from quapy.util import download_file_if_not_exists, download_file, get_quapy_home, pickled_resource
REVIEWS_SENTIMENT_DATASETS = ['hp', 'kindle', 'imdb']
TWITTER_SENTIMENT_DATASETS_TEST = ['gasp', 'hcr', 'omd', 'sanders',
                                   'semeval13', 'semeval14', 'semeval15', 'semeval16',
@@ -45,6 +46,20 @@ UCI_DATASETS = ['acute.a', 'acute.b',
LEQUA2022_TASKS = ['T1A', 'T1B', 'T2A', 'T2B']

_TXA_SAMPLE_SIZE = 250
_TXB_SAMPLE_SIZE = 1000

LEQUA2022_SAMPLE_SIZE = {
    'TXA': _TXA_SAMPLE_SIZE,
    'TXB': _TXB_SAMPLE_SIZE,
    'T1A': _TXA_SAMPLE_SIZE,
    'T1B': _TXB_SAMPLE_SIZE,
    'T2A': _TXA_SAMPLE_SIZE,
    'T2B': _TXB_SAMPLE_SIZE,
    'binary': _TXA_SAMPLE_SIZE,
    'multiclass': _TXB_SAMPLE_SIZE
}
def fetch_reviews(dataset_name, tfidf=False, min_df=None, data_home=None, pickle=False) -> Dataset:
"""
@@ -578,7 +593,7 @@ def fetch_lequa2022(task, data_home=None):
     val_true_prev_path = join(lequa_dir, task, 'public', 'dev_prevalences.txt')
     val_gen = SamplesFromDir(val_samples_path, val_true_prev_path, load_fn=load_fn)
-    test_samples_path = join(lequa_dir, task, 'public', 'dev_samples')
+    test_samples_path = join(lequa_dir, task, 'public', 'test_samples')
     test_true_prev_path = join(lequa_dir, task, 'public', 'test_prevalences.txt')
     test_gen = SamplesFromDir(test_samples_path, test_true_prev_path, load_fn=load_fn)
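
For context, fetch_lequa2022 returns the training set plus two sample generators (validation and test). A minimal sketch of consuming a generator, assuming the calling convention used by the evaluation helper below (each call yields (instances, prevalence) pairs):

from quapy.data.datasets import fetch_lequa2022

training, val_gen, test_gen = fetch_lequa2022(task='T1A')
for instances, true_prev in test_gen():
    pass  # one test sample at a time, paired with its true prevalence vector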

View File

@@ -11,11 +11,11 @@ def from_name(err_name):
     """
     assert err_name in ERROR_NAMES, f'unknown error {err_name}'
     callable_error = globals()[err_name]
-    if err_name in QUANTIFICATION_ERROR_SMOOTH_NAMES:
-        eps = __check_eps()
-        def bound_callable_error(y_true, y_pred):
-            return callable_error(y_true, y_pred, eps)
-        return bound_callable_error
+    # if err_name in QUANTIFICATION_ERROR_SMOOTH_NAMES:
+    #     eps = __check_eps()
+    #     def bound_callable_error(y_true, y_pred):
+    #         return callable_error(y_true, y_pred, eps)
+    #     return bound_callable_error
     return callable_error
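
For reference, from_name maps a metric name to the corresponding function in quapy.error; with the eps-binding block commented out, the raw callable is returned. A usage sketch:

import numpy as np
import quapy as qp

mae = qp.error.from_name('mae')
# mean absolute error between a true and an estimated prevalence vector
print(mae(np.asarray([[0.8, 0.2]]), np.asarray([[0.7, 0.3]])))  # 0.1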

View File

@@ -41,7 +41,7 @@ def prediction(model: BaseQuantifier, protocol: AbstractProtocol, aggr_speedup='
 def __prediction_helper(quantification_fn, protocol: AbstractProtocol, verbose=False):
     true_prevs, estim_prevs = [], []
-    for sample_instances, sample_prev in tqdm(protocol(), total=protocol.total()) if verbose else protocol():
+    for sample_instances, sample_prev in tqdm(protocol(), total=protocol.total(), desc='predicting') if verbose else protocol():
         estim_prevs.append(quantification_fn(sample_instances))
         true_prevs.append(sample_prev)
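
This helper backs quapy.evaluation.prediction, which collects the true and the estimated prevalences over all samples produced by a protocol. A sketch of its typical use with a fitted quantifier (names taken from the example above):

from quapy.evaluation import prediction

true_prevs, estim_prevs = prediction(quantifier, protocol=test_generator, verbose=True)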