import bug fixed

This commit is contained in:
Alejandro Moreo Fernandez 2021-01-12 09:35:49 +01:00
parent 2ec3400d15
commit 3e07feda3c
2 changed files with 10 additions and 6 deletions

View File

@ -1,6 +1,6 @@
import quapy as qp
from typing import Union, Callable, Iterable
from data import LabelledCollection
from method.aggregative import AggregativeQuantifier, AggregativeProbabilisticQuantifier
from method.base import BaseQuantifier
from util import temp_seed
import numpy as np
@ -38,14 +38,18 @@ def artificial_sampling_prediction(
with temp_seed(random_seed):
indexes = list(test.artificial_sampling_index_generator(sample_size, n_prevpoints, n_repetitions))
if isinstance(model, AggregativeQuantifier):
if isinstance(model, qp.method.aggregative.AggregativeQuantifier):
print('\tinstance of aggregative-quantifier')
quantification_func = model.aggregate
if isinstance(model, AggregativeProbabilisticQuantifier):
if isinstance(model, qp.method.aggregative.AggregativeProbabilisticQuantifier):
print('\t\tinstance of probabilitstic-aggregative-quantifier')
preclassified_instances = model.posterior_probabilities(test.instances)
else:
print('\t\tinstance of hard-aggregative-quantifier')
preclassified_instances = model.classify(test.instances)
test = LabelledCollection(preclassified_instances, test.labels)
else:
print('\t\tinstance of base-quantifier')
quantification_func = model.quantify
def _predict_prevalences(index):

View File

@ -44,9 +44,9 @@ def quantification_models():
__C_range = np.logspace(-4, 5, 10)
lr_params = {'C': __C_range, 'class_weight': [None, 'balanced']}
#yield 'cc', qp.method.aggregative.CC(newLR()), lr_params
yield 'acc', qp.method.aggregative.ACC(newLR()), lr_params
#yield 'acc', qp.method.aggregative.ACC(newLR()), lr_params
#yield 'pcc', qp.method.aggregative.PCC(newLR()), lr_params
#yield 'pacc', qp.method.aggregative.PACC(newLR()), lr_params
yield 'pacc', qp.method.aggregative.PACC(newLR()), lr_params
def result_path(dataset_name, model_name, optim_metric):
@ -69,7 +69,7 @@ if __name__ == '__main__':
np.random.seed(0)
for dataset_name in ['hcr']: # qp.datasets.TWITTER_SENTIMENT_DATASETS:
for dataset_name in ['sanders']: # qp.datasets.TWITTER_SENTIMENT_DATASETS:
benchmark_devel = qp.datasets.fetch_twitter(dataset_name, for_model_selection=True, min_df=5, pickle=True)
benchmark_devel.stats()