forked from moreo/QuaPy
full example of training, model selection, and evaluation using the lequa2022 dataset with the new protocols
This commit is contained in:
parent
ecd0ad7ec7
commit
f2550fdb82
|
@ -0,0 +1,26 @@
|
|||
import numpy as np
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
import quapy as qp
|
||||
from data.datasets import LEQUA2022_SAMPLE_SIZE, fetch_lequa2022
|
||||
from evaluation import evaluation_report
|
||||
from method.aggregative import EMQ
|
||||
from model_selection import GridSearchQ
|
||||
|
||||
|
||||
task = 'T1A'
|
||||
|
||||
qp.environ['SAMPLE_SIZE']=LEQUA2022_SAMPLE_SIZE[task]
|
||||
training, val_generator, test_generator = fetch_lequa2022(task=task)
|
||||
|
||||
# define the quantifier
|
||||
quantifier = EMQ(learner=LogisticRegression())
|
||||
|
||||
# model selection
|
||||
param_grid = {'C': np.logspace(-3, 3, 7), 'class_weight': ['balanced', None]}
|
||||
model_selection = GridSearchQ(quantifier, param_grid, protocol=val_generator, n_jobs=-1, refit=False, verbose=True)
|
||||
quantifier = model_selection.fit(training)
|
||||
|
||||
# evaluation
|
||||
report = evaluation_report(quantifier, protocol=test_generator, error_metrics=['mae', 'mrae'], verbose=True)
|
||||
|
||||
print(report)
|
|
@ -12,6 +12,7 @@ from quapy.data.preprocessing import text2tfidf, reduce_columns
|
|||
from quapy.data.reader import *
|
||||
from quapy.util import download_file_if_not_exists, download_file, get_quapy_home, pickled_resource
|
||||
|
||||
|
||||
REVIEWS_SENTIMENT_DATASETS = ['hp', 'kindle', 'imdb']
|
||||
TWITTER_SENTIMENT_DATASETS_TEST = ['gasp', 'hcr', 'omd', 'sanders',
|
||||
'semeval13', 'semeval14', 'semeval15', 'semeval16',
|
||||
|
@ -45,6 +46,20 @@ UCI_DATASETS = ['acute.a', 'acute.b',
|
|||
|
||||
LEQUA2022_TASKS = ['T1A', 'T1B', 'T2A', 'T2B']
|
||||
|
||||
_TXA_SAMPLE_SIZE = 250
|
||||
_TXB_SAMPLE_SIZE = 1000
|
||||
|
||||
LEQUA2022_SAMPLE_SIZE = {
|
||||
'TXA': _TXA_SAMPLE_SIZE,
|
||||
'TXB': _TXB_SAMPLE_SIZE,
|
||||
'T1A': _TXA_SAMPLE_SIZE,
|
||||
'T1B': _TXB_SAMPLE_SIZE,
|
||||
'T2A': _TXA_SAMPLE_SIZE,
|
||||
'T2B': _TXB_SAMPLE_SIZE,
|
||||
'binary': _TXA_SAMPLE_SIZE,
|
||||
'multiclass': _TXB_SAMPLE_SIZE
|
||||
}
|
||||
|
||||
|
||||
def fetch_reviews(dataset_name, tfidf=False, min_df=None, data_home=None, pickle=False) -> Dataset:
|
||||
"""
|
||||
|
@ -578,7 +593,7 @@ def fetch_lequa2022(task, data_home=None):
|
|||
val_true_prev_path = join(lequa_dir, task, 'public', 'dev_prevalences.txt')
|
||||
val_gen = SamplesFromDir(val_samples_path, val_true_prev_path, load_fn=load_fn)
|
||||
|
||||
test_samples_path = join(lequa_dir, task, 'public', 'dev_samples')
|
||||
test_samples_path = join(lequa_dir, task, 'public', 'test_samples')
|
||||
test_true_prev_path = join(lequa_dir, task, 'public', 'test_prevalences.txt')
|
||||
test_gen = SamplesFromDir(test_samples_path, test_true_prev_path, load_fn=load_fn)
|
||||
|
||||
|
|
|
@ -11,11 +11,11 @@ def from_name(err_name):
|
|||
"""
|
||||
assert err_name in ERROR_NAMES, f'unknown error {err_name}'
|
||||
callable_error = globals()[err_name]
|
||||
if err_name in QUANTIFICATION_ERROR_SMOOTH_NAMES:
|
||||
eps = __check_eps()
|
||||
def bound_callable_error(y_true, y_pred):
|
||||
return callable_error(y_true, y_pred, eps)
|
||||
return bound_callable_error
|
||||
# if err_name in QUANTIFICATION_ERROR_SMOOTH_NAMES:
|
||||
# eps = __check_eps()
|
||||
# def bound_callable_error(y_true, y_pred):
|
||||
# return callable_error(y_true, y_pred, eps)
|
||||
# return bound_callable_error
|
||||
return callable_error
|
||||
|
||||
|
||||
|
|
|
@ -41,7 +41,7 @@ def prediction(model: BaseQuantifier, protocol: AbstractProtocol, aggr_speedup='
|
|||
|
||||
def __prediction_helper(quantification_fn, protocol: AbstractProtocol, verbose=False):
|
||||
true_prevs, estim_prevs = [], []
|
||||
for sample_instances, sample_prev in tqdm(protocol(), total=protocol.total()) if verbose else protocol():
|
||||
for sample_instances, sample_prev in tqdm(protocol(), total=protocol.total(), desc='predicting') if verbose else protocol():
|
||||
estim_prevs.append(quantification_fn(sample_instances))
|
||||
true_prevs.append(sample_prev)
|
||||
|
||||
|
|
Loading…
Reference in New Issue