adding uci_experiments to examples folder
This commit is contained in:
parent
4904475d26
commit
67906f6f2d
3
TODO.txt
3
TODO.txt
|
@ -1,3 +1,6 @@
|
|||
ensembles seem to be broken; they have an internal model selection which takes the parameters, but since quapy now
|
||||
works with protocols it would need to know the validation set in order to pass something like
|
||||
"protocol: APP(val, etc.)"
|
||||
sample_size should not be mandatory when qp.environ['SAMPLE_SIZE'] has been specified
|
||||
clean all the cumbersome methods that have to be implemented for new quantifiers (e.g., n_classes_ prop, etc.)
|
||||
make truly parallel the GridSearchQ
|
||||
|
|
|
@ -0,0 +1,152 @@
|
|||
from copy import deepcopy
|
||||
|
||||
import quapy as qp
|
||||
from sklearn.calibration import CalibratedClassifierCV
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from classification.methods import LowRankLogisticRegression
|
||||
from quapy.method.meta import QuaNet
|
||||
from quapy.protocol import APP
|
||||
from quapy.method.aggregative import CC, ACC, PCC, PACC, MAX, MS, MS2, EMQ, HDy, newSVMAE
|
||||
from quapy.method.meta import EHDy
|
||||
import numpy as np
|
||||
import os
|
||||
import pickle
|
||||
import itertools
|
||||
import argparse
|
||||
import torch
|
||||
import shutil
|
||||
|
||||
|
||||
N_JOBS = -1
|
||||
CUDA_N_JOBS = 2
|
||||
ENSEMBLE_N_JOBS = -1
|
||||
|
||||
qp.environ['SAMPLE_SIZE'] = 100
|
||||
|
||||
|
||||
def newLR():
|
||||
return LogisticRegression(max_iter=1000, solver='lbfgs', n_jobs=-1)
|
||||
|
||||
|
||||
def calibratedLR():
|
||||
return CalibratedClassifierCV(LogisticRegression(max_iter=1000, solver='lbfgs', n_jobs=-1))
|
||||
|
||||
|
||||
__C_range = np.logspace(-3, 3, 7)
|
||||
lr_params = {'classifier__C': __C_range, 'classifier__class_weight': [None, 'balanced']}
|
||||
svmperf_params = {'classifier__C': __C_range}
|
||||
|
||||
|
||||
def quantification_models():
|
||||
yield 'cc', CC(newLR()), lr_params
|
||||
yield 'acc', ACC(newLR()), lr_params
|
||||
yield 'pcc', PCC(newLR()), lr_params
|
||||
yield 'pacc', PACC(newLR()), lr_params
|
||||
yield 'MAX', MAX(newLR()), lr_params
|
||||
yield 'MS', MS(newLR()), lr_params
|
||||
yield 'MS2', MS2(newLR()), lr_params
|
||||
yield 'sldc', EMQ(newLR(), recalib='platt'), lr_params
|
||||
yield 'svmmae', newSVMAE(), svmperf_params
|
||||
yield 'hdy', HDy(newLR()), lr_params
|
||||
|
||||
|
||||
def quantification_cuda_models():
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
print(f'Running QuaNet in {device}')
|
||||
learner = LowRankLogisticRegression()
|
||||
yield 'quanet', QuaNet(learner, checkpointdir=args.checkpointdir, device=device), lr_params
|
||||
|
||||
|
||||
def evaluate_experiment(true_prevalences, estim_prevalences):
|
||||
print('\nEvaluation Metrics:\n' + '=' * 22)
|
||||
for eval_measure in [qp.error.mae, qp.error.mrae]:
|
||||
err = eval_measure(true_prevalences, estim_prevalences)
|
||||
print(f'\t{eval_measure.__name__}={err:.4f}')
|
||||
print()
|
||||
|
||||
|
||||
def result_path(path, dataset_name, model_name, run, optim_loss):
|
||||
return os.path.join(path, f'{dataset_name}-{model_name}-run{run}-{optim_loss}.pkl')
|
||||
|
||||
|
||||
def is_already_computed(dataset_name, model_name, run, optim_loss):
|
||||
return os.path.exists(result_path(args.results, dataset_name, model_name, run, optim_loss))
|
||||
|
||||
|
||||
def save_results(dataset_name, model_name, run, optim_loss, *results):
|
||||
rpath = result_path(args.results, dataset_name, model_name, run, optim_loss)
|
||||
qp.util.create_parent_dir(rpath)
|
||||
with open(rpath, 'wb') as foo:
|
||||
pickle.dump(tuple(results), foo, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
|
||||
def run(experiment):
|
||||
optim_loss, dataset_name, (model_name, model, hyperparams) = experiment
|
||||
if dataset_name in ['acute.a', 'acute.b', 'iris.1']: return
|
||||
|
||||
collection = qp.datasets.fetch_UCILabelledCollection(dataset_name)
|
||||
for run, data in enumerate(qp.data.Dataset.kFCV(collection, nfolds=5, nrepeats=1)):
|
||||
if is_already_computed(dataset_name, model_name, run=run, optim_loss=optim_loss):
|
||||
print(f'result for dataset={dataset_name} model={model_name} loss={optim_loss} run={run+1}/5 already computed.')
|
||||
continue
|
||||
|
||||
print(f'running dataset={dataset_name} model={model_name} loss={optim_loss} run={run+1}/5')
|
||||
# model selection (hyperparameter optimization for a quantification-oriented loss)
|
||||
train, test = data.train_test
|
||||
train, val = train.split_stratified()
|
||||
if hyperparams is not None:
|
||||
model_selection = qp.model_selection.GridSearchQ(
|
||||
deepcopy(model),
|
||||
param_grid=hyperparams,
|
||||
protocol=APP(val, n_prevalences=21, repeats=25),
|
||||
error=optim_loss,
|
||||
refit=True,
|
||||
timeout=60*60,
|
||||
verbose=True
|
||||
)
|
||||
model_selection.fit(data.training)
|
||||
model = model_selection.best_model()
|
||||
best_params = model_selection.best_params_
|
||||
else:
|
||||
model.fit(data.training)
|
||||
best_params = {}
|
||||
|
||||
# model evaluation
|
||||
true_prevalences, estim_prevalences = qp.evaluation.prediction(
|
||||
model,
|
||||
protocol=APP(test, n_prevalences=21, repeats=100)
|
||||
)
|
||||
test_true_prevalence = data.test.prevalence()
|
||||
|
||||
evaluate_experiment(true_prevalences, estim_prevalences)
|
||||
save_results(dataset_name, model_name, run, optim_loss,
|
||||
true_prevalences, estim_prevalences,
|
||||
data.training.prevalence(), test_true_prevalence,
|
||||
best_params)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Run experiments for Tweeter Sentiment Quantification')
|
||||
parser.add_argument('results', metavar='RESULT_PATH', type=str,
|
||||
help='path to the directory where to store the results')
|
||||
parser.add_argument('--svmperfpath', metavar='SVMPERF_PATH', type=str, default='../svm_perf_quantification',
|
||||
help='path to the directory with svmperf')
|
||||
parser.add_argument('--checkpointdir', metavar='PATH', type=str, default='./checkpoint',
|
||||
help='path to the directory where to dump QuaNet checkpoints')
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f'Result folder: {args.results}')
|
||||
np.random.seed(0)
|
||||
|
||||
qp.environ['SVMPERF_HOME'] = args.svmperfpath
|
||||
|
||||
optim_losses = ['mae']
|
||||
datasets = qp.datasets.UCI_DATASETS[:4]
|
||||
|
||||
models = quantification_models()
|
||||
qp.util.parallel(run, itertools.product(optim_losses, datasets, models), n_jobs=N_JOBS)
|
||||
|
||||
models = quantification_cuda_models()
|
||||
qp.util.parallel(run, itertools.product(optim_losses, datasets, models), n_jobs=CUDA_N_JOBS)
|
||||
|
||||
shutil.rmtree(args.checkpointdir, ignore_errors=True)
|
|
@ -19,17 +19,16 @@ class LowRankLogisticRegression(BaseEstimator):
|
|||
|
||||
def __init__(self, n_components=100, **kwargs):
|
||||
self.n_components = n_components
|
||||
self.learner = LogisticRegression(**kwargs)
|
||||
self.classifier = LogisticRegression(**kwargs)
|
||||
|
||||
def get_params(self, deep=True):
|
||||
def get_params(self):
|
||||
"""
|
||||
Get hyper-parameters for this estimator.
|
||||
|
||||
:param deep: compatibility with sklearn
|
||||
:return: a dictionary with parameter names mapped to their values
|
||||
"""
|
||||
params = {'n_components': self.n_components}
|
||||
params.update(self.learner.get_params(deep))
|
||||
params.update(self.classifier.get_params())
|
||||
return params
|
||||
|
||||
def set_params(self, **params):
|
||||
|
@ -44,7 +43,7 @@ class LowRankLogisticRegression(BaseEstimator):
|
|||
if 'n_components' in params_:
|
||||
self.n_components = params_['n_components']
|
||||
del params_['n_components']
|
||||
self.learner.set_params(**params_)
|
||||
self.classifier.set_params(**params_)
|
||||
|
||||
def fit(self, X, y):
|
||||
"""
|
||||
|
@ -60,8 +59,8 @@ class LowRankLogisticRegression(BaseEstimator):
|
|||
if nF > self.n_components:
|
||||
self.pca = TruncatedSVD(self.n_components).fit(X)
|
||||
X = self.transform(X)
|
||||
self.learner.fit(X, y)
|
||||
self.classes_ = self.learner.classes_
|
||||
self.classifier.fit(X, y)
|
||||
self.classes_ = self.classifier.classes_
|
||||
return self
|
||||
|
||||
def predict(self, X):
|
||||
|
@ -73,7 +72,7 @@ class LowRankLogisticRegression(BaseEstimator):
|
|||
instances in `X`
|
||||
"""
|
||||
X = self.transform(X)
|
||||
return self.learner.predict(X)
|
||||
return self.classifier.predict(X)
|
||||
|
||||
def predict_proba(self, X):
|
||||
"""
|
||||
|
@ -83,7 +82,7 @@ class LowRankLogisticRegression(BaseEstimator):
|
|||
:return: array-like of shape `(n_samples, n_classes)` with the posterior probabilities
|
||||
"""
|
||||
X = self.transform(X)
|
||||
return self.learner.predict_proba(X)
|
||||
return self.classifier.predict_proba(X)
|
||||
|
||||
def transform(self, X):
|
||||
"""
|
||||
|
|
|
@ -207,7 +207,7 @@ def fetch_UCIDataset(dataset_name, data_home=None, test_split=0.3, verbose=False
|
|||
return Dataset(*data.split_stratified(1 - test_split, random_state=0))
|
||||
|
||||
|
||||
def fetch_UCILabelledCollection(dataset_name, data_home=None, verbose=False) -> Dataset:
|
||||
def fetch_UCILabelledCollection(dataset_name, data_home=None, verbose=False) -> LabelledCollection:
|
||||
"""
|
||||
Loads a UCI collection as an instance of :class:`quapy.data.base.LabelledCollection`, as used in
|
||||
`Pérez-Gállego, P., Quevedo, J. R., & del Coz, J. J. (2017).
|
||||
|
@ -223,7 +223,7 @@ def fetch_UCILabelledCollection(dataset_name, data_home=None, verbose=False) ->
|
|||
|
||||
>>> import quapy as qp
|
||||
>>> collection = qp.datasets.fetch_UCILabelledCollection("yeast")
|
||||
>>> for data in qp.data.Dataset.kFCV(collection, nfolds=5, nrepeats=2):
|
||||
>>> for data in qp.domains.Dataset.kFCV(collection, nfolds=5, nrepeats=2):
|
||||
>>> ...
|
||||
|
||||
The list of valid dataset names can be accessed in `quapy.data.datasets.UCI_DATASETS`
|
||||
|
@ -233,7 +233,7 @@ def fetch_UCILabelledCollection(dataset_name, data_home=None, verbose=False) ->
|
|||
~/quay_data/ directory)
|
||||
:param test_split: proportion of documents to be included in the test set. The rest conforms the training set
|
||||
:param verbose: set to True (default is False) to get information (from the UCI ML repository) about the datasets
|
||||
:return: a :class:`quapy.data.base.Dataset` instance
|
||||
:return: a :class:`quapy.data.base.LabelledCollection` instance
|
||||
"""
|
||||
|
||||
assert dataset_name in UCI_DATASETS, \
|
||||
|
|
|
@ -444,24 +444,28 @@ class EMQ(AggregativeProbabilisticQuantifier):
|
|||
|
||||
def __init__(self, classifier: BaseEstimator, exact_train_prev=True, recalib=None):
|
||||
self.classifier = classifier
|
||||
self.non_calibrated = classifier
|
||||
self.exact_train_prev = exact_train_prev
|
||||
self.recalib = recalib
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_classifier=True):
|
||||
if self.recalib is not None:
|
||||
if self.recalib == 'nbvs':
|
||||
self.classifier = NBVSCalibration(self.classifier)
|
||||
self.classifier = NBVSCalibration(self.non_calibrated)
|
||||
elif self.recalib == 'bcts':
|
||||
self.classifier = BCTSCalibration(self.classifier)
|
||||
self.classifier = BCTSCalibration(self.non_calibrated)
|
||||
elif self.recalib == 'ts':
|
||||
self.classifier = TSCalibration(self.classifier)
|
||||
self.classifier = TSCalibration(self.non_calibrated)
|
||||
elif self.recalib == 'vs':
|
||||
self.classifier = VSCalibration(self.classifier)
|
||||
self.classifier = VSCalibration(self.non_calibrated)
|
||||
elif self.recalib == 'platt':
|
||||
self.classifier = CalibratedClassifierCV(self.classifier, ensemble=False)
|
||||
else:
|
||||
raise ValueError('invalid param argument for recalibration method; available ones are '
|
||||
'"nbvs", "bcts", "ts", and "vs".')
|
||||
self.recalib = None
|
||||
else:
|
||||
self.classifier = self.non_calibrated
|
||||
self.classifier, _ = _training_helper(self.classifier, data, fit_classifier, ensure_probabilistic=True)
|
||||
if self.exact_train_prev:
|
||||
self.train_prevalence = F.prevalence_from_labels(data.labels, self.classes_)
|
||||
|
|
|
@ -9,6 +9,7 @@ from torch.nn.functional import relu
|
|||
from quapy.protocol import UPP
|
||||
from quapy.method.aggregative import *
|
||||
from quapy.util import EarlyStop
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class QuaNetTrainer(BaseQuantifier):
|
||||
|
@ -28,7 +29,7 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
>>>
|
||||
>>> # load the kindle dataset as text, and convert words to numerical indexes
|
||||
>>> dataset = qp.datasets.fetch_reviews('kindle', pickle=True)
|
||||
>>> qp.data.preprocessing.index(dataset, min_df=5, inplace=True)
|
||||
>>> qp.domains.preprocessing.index(dataset, min_df=5, inplace=True)
|
||||
>>>
|
||||
>>> # the text classifier is a CNN trained by NeuralClassifierTrainer
|
||||
>>> cnn = CNNnet(dataset.vocabulary_size, dataset.n_classes)
|
||||
|
@ -263,15 +264,19 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
f'patience={early_stop.patience}/{early_stop.PATIENCE_LIMIT}')
|
||||
|
||||
def get_params(self, deep=True):
|
||||
return {**self.classifier.get_params(), **self.quanet_params}
|
||||
classifier_params = self.classifier.get_params()
|
||||
classifier_params = {'classifier__'+k:v for k,v in classifier_params.items()}
|
||||
return {**classifier_params, **self.quanet_params}
|
||||
|
||||
def set_params(self, **parameters):
|
||||
learner_params = {}
|
||||
for key, val in parameters.items():
|
||||
if key in self.quanet_params:
|
||||
self.quanet_params[key] = val
|
||||
elif key.startswith('classifier__'):
|
||||
learner_params[key.replace('classifier__', '')] = val
|
||||
else:
|
||||
learner_params[key] = val
|
||||
raise ValueError('unknown parameter ', key)
|
||||
self.classifier.set_params(**learner_params)
|
||||
|
||||
def __check_params_colision(self, quanet_params, learner_params):
|
||||
|
|
|
@ -56,7 +56,7 @@ class GridSearchQ(BaseQuantifier):
|
|||
|
||||
def _sout(self, msg):
|
||||
if self.verbose:
|
||||
print(f'[{self.__class__.__name__}]: {msg}')
|
||||
print(f'[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}')
|
||||
|
||||
def __check_error(self, error):
|
||||
if error in qp.error.QUANTIFICATION_ERROR:
|
||||
|
|
Loading…
Reference in New Issue