From 7b2d3cb7f1bc6ff69a60e15b156c9bc69730964e Mon Sep 17 00:00:00 2001 From: Alejandro Moreo Date: Sat, 11 Feb 2023 10:08:31 +0100 Subject: [PATCH] example using svmperf --- examples/one_vs_all_svmperf.py | 54 +++++++++++++++++++++++++++++++++ quapy/classification/svmperf.py | 3 ++ 2 files changed, 57 insertions(+) create mode 100644 examples/one_vs_all_svmperf.py diff --git a/examples/one_vs_all_svmperf.py b/examples/one_vs_all_svmperf.py new file mode 100644 index 0000000..8bf38bd --- /dev/null +++ b/examples/one_vs_all_svmperf.py @@ -0,0 +1,54 @@ +import quapy as qp +from quapy.method.aggregative import MS2, OneVsAllAggregative, OneVsAllGeneric, SVMQ +from quapy.method.base import getOneVsAll +from quapy.model_selection import GridSearchQ +from quapy.protocol import USimplexPP +from sklearn.linear_model import LogisticRegression +import numpy as np + +""" +In this example, we will create a quantifier for tweet sentiment analysis considering three classes: negative, neutral, +and positive. We will use a one-vs-all approach using a binary quantifier for demonstration purposes. +""" + +qp.environ['SAMPLE_SIZE'] = 100 +qp.environ['N_JOBS'] = -1 +qp.environ['SVMPERF_HOME'] = '../svm_perf_quantification' + +""" +Any binary quantifier can be turned into a single-label quantifier by means of getOneVsAll function. +This function returns an instance of OneVsAll quantifier. Actually, it either returns the subclass OneVsAllGeneric +when the quantifier is an instance of BaseQuantifier, and it returns OneVsAllAggregative when the quantifier is +an instance of AggregativeQuantifier. Although OneVsAllGeneric works in all cases, using OneVsAllAggregative has +some additional advantages (namely, all the advantages that AggregativeQuantifiers enjoy, i.e., faster predictions +during evaluation). +""" +quantifier = getOneVsAll(SVMQ()) +print(f'the quantifier is an instance of {quantifier.__class__.__name__}') + +# load a ternary dataset +train_modsel, val = qp.datasets.fetch_twitter('hcr', for_model_selection=True, pickle=True).train_test + +""" +model selection: for this example, we are relying on the USimplexPP protocol, i.e., a variant of the +artificial-prevalence protocol that generates random samples (100 in this case) for randomly picked priors +from the unit simplex. The priors are sampled using the Kraemer algorithm. Note this is in contrast to the +standard APP protocol, that instead explores a prefixed grid of prevalence values. +""" +param_grid = { + 'binary_quantifier__classifier__C': np.logspace(-2,2,5), # classifier-dependent hyperparameter +} +print('starting model selection') +model_selection = GridSearchQ(quantifier, param_grid, protocol=USimplexPP(val), verbose=True, refit=False) +quantifier = model_selection.fit(train_modsel).best_model() + +print('training on the whole training set') +train, test = qp.datasets.fetch_twitter('hcr', for_model_selection=False, pickle=True).train_test +quantifier.fit(train) + +# evaluation +mae = qp.evaluation.evaluate(quantifier, protocol=USimplexPP(test), error_metric='mae') + +print(f'MAE = {mae:.4f}') + + diff --git a/quapy/classification/svmperf.py b/quapy/classification/svmperf.py index 176b102..7921725 100644 --- a/quapy/classification/svmperf.py +++ b/quapy/classification/svmperf.py @@ -44,6 +44,9 @@ class SVMperf(BaseEstimator, ClassifierMixin): assert list(parameters.keys()) == ['C'], 'currently, only the C parameter is supported' self.C = parameters['C'] + def get_params(self, deep=True): + return {'C': self.C} + def fit(self, X, y): """ Trains the SVM for the multivariate performance loss