QuAcc/quacc/method/model_selection.py

301 lines
9.9 KiB
Python

import itertools
from copy import deepcopy
from time import time
from typing import Callable, Union
import quapy as qp
from quapy.data import LabelledCollection
from quapy.model_selection import GridSearchQ
from quapy.protocol import UPP, AbstractProtocol, OnLabelledCollectionProtocol
from sklearn.base import BaseEstimator
import quacc as qc
import quacc.error
from quacc.data import ExtendedCollection
from quacc.evaluation import evaluate
from quacc.logger import SubLogger
from quacc.method.base import (
BaseAccuracyEstimator,
BinaryQuantifierAccuracyEstimator,
MultiClassAccuracyEstimator,
)
class GridSearchAE(BaseAccuracyEstimator):
    """Exhaustive grid search over the hyper-parameters of an accuracy estimator.

    Every combination of the values in ``param_grid`` is applied to a deep copy
    of ``model``; the copy is fitted on the training data and scored with
    :func:`quacc.evaluation.evaluate` over the samples generated by
    ``protocol``. The configuration with the lowest error is retained and,
    when ``refit`` is set, re-trained on the training data plus the labelled
    collection held by the protocol.

    :param model: the accuracy estimator whose hyper-parameters are explored
    :param param_grid: dictionary mapping parameter names to the list of values
        to explore; keys prefixed with ``q__`` are rewritten to ``quantifier__``
    :param protocol: protocol generating the validation samples
    :param error: an error function, or the name of one known to
        ``quacc.error``
    :param refit: whether to refit the best model once the search is over
    :param verbose: whether to print progress messages to standard output
    """

    def __init__(
        self,
        model: BaseAccuracyEstimator,
        param_grid: dict,
        protocol: AbstractProtocol,
        error: Union[Callable, str] = qc.error.maccd,
        refit=True,
        verbose=False,
    ):
        # NOTE: the timeout / n_jobs options are currently disabled, so they are
        # neither accepted nor stored here.
        self.model = model
        self.param_grid = self.__normalize_params(param_grid)
        self.protocol = protocol
        self.refit = refit
        self.verbose = verbose
        self.__check_error(error)
        assert isinstance(protocol, AbstractProtocol), "unknown protocol"

    def _sout(self, msg):
        """Print ``msg`` prefixed with the class name, only when verbose."""
        if self.verbose:
            print(f"[{self.__class__.__name__}]: {msg}")

    def __normalize_params(self, params):
        """Return a copy of ``params`` with ``q__*`` keys renamed to ``quantifier__*``.

        Keys without the ``q__`` prefix are kept unchanged, so the operation is
        idempotent.
        """
        normalized = {}
        for key, value in params.items():
            prefix, delim, sub_key = key.partition("__")
            if delim and prefix == "q":
                key = f"quantifier__{sub_key}"
            normalized[key] = value
        return normalized

    def __check_error(self, error):
        """Resolve ``error`` into a callable and store it in ``self.error``.

        :raises ValueError: if ``error`` is neither a known accuracy error, the
            name of an error function, nor a callable
        """
        if error in qc.error.ACCURACY_ERROR:
            self.error = error
        elif isinstance(error, str):
            self.error = qc.error.from_name(error)
        elif callable(error):
            self.error = error
        else:
            raise ValueError(
                f"unexpected error type; must either be a callable function or a str representing\n"
                f"the name of an error function in {qc.error.ACCURACY_ERROR_NAMES}"
            )

    def fit(self, training: LabelledCollection):
        """Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
        the error metric.

        :param training: the training set on which to optimize the hyperparameters
        :return: self
        :raises TimeoutError: if no configuration could be evaluated
        :raises RuntimeWarning: if ``refit`` is requested but the protocol does not
            implement ``OnLabelledCollectionProtocol``
        """
        params_keys = list(self.param_grid.keys())
        params_values = list(self.param_grid.values())
        protocol = self.protocol

        self.param_scores_ = {}
        self.best_score_ = None

        tinit = time()
        self._sout("starting model selection")

        # Candidates are evaluated one at a time and only the best fitted model
        # is retained, instead of keeping every fitted copy alive at once.
        param_scores = {}
        best_score, best_params, best_model = None, None, None
        for values in itertools.product(*params_values):
            params = dict(zip(params_keys, values))
            _, score, model = self.__params_eval(params, training)
            if score is None:
                # historical label: a missing score means the configuration failed
                param_scores[str(params)] = "timeout"
                continue
            param_scores[str(params)] = score
            if best_score is None or score < best_score:
                best_score, best_params, best_model = score, params, model

        self.param_scores_ = param_scores
        tend = time() - tinit

        if best_score is None:
            raise TimeoutError("no combination of hyperparameters seem to work")

        self.best_score_ = best_score
        self.best_params_ = best_params
        self.best_model_ = best_model

        summary = (
            f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) "
            f"[took {tend:.4f}s]"
        )
        self._sout(summary)

        log = SubLogger.logger()
        log.debug(f"[{self.model.__class__.__name__}] " + summary)

        if self.refit:
            if isinstance(protocol, OnLabelledCollectionProtocol):
                self._sout("refitting on the whole development set")
                self.best_model_.fit(training + protocol.get_labelled_collection())
            else:
                raise RuntimeWarning(
                    f'"refit" was requested, but the protocol does not '
                    f"implement the {OnLabelledCollectionProtocol.__name__} interface"
                )

        return self

    def __params_eval(self, params, training):
        """Fit and score a copy of the model configured with ``params``.

        :param params: dictionary with one value per hyper-parameter
        :param training: the collection the candidate model is fitted on
        :return: a tuple ``(params, score, model)``; ``score`` is ``None`` when the
            configuration failed, and ``model`` is ``None`` when the model could
            not even be copied
        :raises ValueError: re-raised when the configuration is reported invalid
        """
        protocol = self.protocol
        error = self.error

        tinit = time()

        # Bound before the try block so that a failure while copying the model
        # is skipped like any other failure instead of raising UnboundLocalError
        # at the return statement.
        model = None
        try:
            model = deepcopy(self.model)
            # overrides default parameters with the parameters being explored at this iteration
            model.set_params(**params)
            model.fit(training)
            score = evaluate(model, protocol=protocol, error_metric=error)

            ttime = time() - tinit
            self._sout(
                f"hyperparams={params}\t got score {score:.5f} [took {ttime:.4f}s]"
            )
        except ValueError:
            self._sout(f"the combination of hyperparameters {params} is invalid")
            raise
        except Exception as e:
            # best effort: any other failure only discards this configuration
            self._sout(f"something went wrong for config {params}; skipping:")
            self._sout(f"\tException: {e}")
            score = None

        return params, score, model

    def extend(self, coll: LabelledCollection, pred_proba=None) -> ExtendedCollection:
        """Extend ``coll`` by delegating to the best model found by :meth:`fit`.

        :param coll: the collection to extend
        :param pred_proba: forwarded unchanged to the best model's ``extend``
        :return: the extended collection produced by the best model
        """
        assert hasattr(self, "best_model_"), "extend called before fit"
        return self.best_model().extend(coll, pred_proba=pred_proba)

    def estimate(self, instances, ext=False):
        """Estimate class prevalence values using the best model found after calling the :meth:`fit` method.

        :param instances: sample containing the instances
        :param ext: forwarded unchanged to the best model's ``estimate``
        :return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found
            by the model selection process.
        """
        assert hasattr(self, "best_model_"), "estimate called before fit"
        return self.best_model().estimate(instances, ext=ext)

    def set_params(self, **parameters):
        """Sets the hyper-parameters to explore.

        Keys are normalized exactly as in the constructor (``q__*`` becomes
        ``quantifier__*``).

        :param parameters: a dictionary with keys the parameter names and values the list of values to explore
        """
        self.param_grid = self.__normalize_params(parameters)

    def get_params(self, deep=True):
        """Returns the dictionary of hyper-parameters to explore (`param_grid`)

        :param deep: Unused
        :return: the dictionary `param_grid`
        """
        return self.param_grid

    def best_model(self):
        """
        Returns the best model found after calling the :meth:`fit` method, i.e., the one trained on the combination
        of hyper-parameters that minimized the error function.

        :return: the fitted best model
        :raises ValueError: if called before :meth:`fit`
        """
        if hasattr(self, "best_model_"):
            return self.best_model_
        raise ValueError("best_model called before fit")
class MCAEgsq(MultiClassAccuracyEstimator):
    """Multiclass accuracy estimator whose quantifier is tuned with ``GridSearchQ``.

    :param classifier: the classifier handed to the base estimator
    :param quantifier: the quantifier whose hyper-parameters are explored
    :param param_grid: dictionary mapping parameter names to the values to explore
    :param error: error function (or its name) passed to ``GridSearchQ``
    :param refit: passed to ``GridSearchQ``
    :param timeout: passed to ``GridSearchQ``
    :param n_jobs: passed to ``GridSearchQ``
    :param verbose: passed to ``GridSearchQ``
    """

    def __init__(
        self,
        classifier: BaseEstimator,
        quantifier: BaseAccuracyEstimator,
        param_grid: dict,
        error: Union[Callable, str] = qp.error.mae,
        refit=True,
        timeout=-1,
        n_jobs=None,
        verbose=False,
    ):
        self.param_grid = param_grid
        self.refit = refit
        self.timeout = timeout
        self.n_jobs = n_jobs
        self.verbose = verbose
        self.error = error
        super().__init__(classifier, quantifier)

    def fit(self, train: LabelledCollection):
        """Extend ``train`` and run the grid search on a held-out split of it.

        The extended collection is split with ``split_stratified(0.6,
        random_state=0)``; the first part trains each candidate and the second
        part feeds the ``UPP`` validation protocol (100 repeats).

        :param train: the labelled training collection
        :return: self
        """
        self.e_train = self.extend(train)
        t_train, t_val = self.e_train.split_stratified(0.6, random_state=0)
        # Candidates are fitted on t_train only: fitting on the whole extended
        # collection would leak the validation samples of t_val into training.
        # NOTE(review): with refit enabled GridSearchQ is expected to retrain the
        # winner on t_train plus the protocol's collection - confirm in quapy.
        self.quantifier = GridSearchQ(
            model=deepcopy(self.quantifier),
            param_grid=self.param_grid,
            protocol=UPP(t_val, repeats=100),
            error=self.error,
            refit=self.refit,
            timeout=self.timeout,
            n_jobs=self.n_jobs,
            verbose=self.verbose,
        ).fit(t_train)

        return self
class BQAEgsq(BinaryQuantifierAccuracyEstimator):
    """Binary-quantifier accuracy estimator tuned with one ``GridSearchQ`` per subset.

    The extended training collection is partitioned with ``split_by_pred`` and
    an independent grid search is run on each resulting subset.

    :param classifier: the classifier handed to the base estimator
    :param quantifier: the quantifier whose hyper-parameters are explored
    :param param_grid: dictionary mapping parameter names to the values to explore
    :param error: error function (or its name) passed to ``GridSearchQ``
    :param refit: passed to ``GridSearchQ``
    :param timeout: passed to ``GridSearchQ``
    :param n_jobs: passed to ``GridSearchQ``
    :param verbose: passed to ``GridSearchQ``
    """

    def __init__(
        self,
        classifier: BaseEstimator,
        quantifier: BaseAccuracyEstimator,
        param_grid: dict,
        error: Union[Callable, str] = qp.error.mae,
        refit=True,
        timeout=-1,
        n_jobs=None,
        verbose=False,
    ):
        # search configuration is stored before the base class is initialised
        self.param_grid, self.error = param_grid, error
        self.refit, self.timeout = refit, timeout
        self.n_jobs, self.verbose = n_jobs, verbose
        super().__init__(classifier=classifier, quantifier=quantifier)

    def fit(self, train: LabelledCollection):
        """Extend ``train``, split it by prediction and tune one quantifier per subset.

        Each subset is split with ``split_stratified(0.6, random_state=0)``: the
        first part is used to fit the candidates, the second one feeds the
        ``UPP`` validation protocol (100 repeats).

        :param train: the labelled training collection
        :return: self
        """
        self.e_train = self.extend(train)
        self.n_classes = self.e_train.n_classes
        # presumably one subset per predicted class - verify in ExtendedCollection
        self.e_trains = self.e_train.split_by_pred()

        # options that are identical for every per-subset search
        shared_options = dict(
            param_grid=self.param_grid,
            error=self.error,
            refit=self.refit,
            timeout=self.timeout,
            n_jobs=self.n_jobs,
            verbose=self.verbose,
        )

        self.quantifiers = []
        for subset in self.e_trains:
            fit_part, val_part = subset.split_stratified(0.6, random_state=0)
            search = GridSearchQ(
                model=deepcopy(self.quantifier),
                protocol=UPP(val_part, repeats=100),
                **shared_options,
            )
            self.quantifiers.append(search.fit(fit_part))

        return self