1
0
Fork 0

import bug fixed

This commit is contained in:
Alejandro Moreo Fernandez 2021-01-12 09:35:49 +01:00
parent 2ec3400d15
commit 3e07feda3c
2 changed files with 10 additions and 6 deletions

View File

@ -1,6 +1,6 @@
import quapy as qp
from typing import Union, Callable, Iterable from typing import Union, Callable, Iterable
from data import LabelledCollection from data import LabelledCollection
from method.aggregative import AggregativeQuantifier, AggregativeProbabilisticQuantifier
from method.base import BaseQuantifier from method.base import BaseQuantifier
from util import temp_seed from util import temp_seed
import numpy as np import numpy as np
@ -38,14 +38,18 @@ def artificial_sampling_prediction(
with temp_seed(random_seed): with temp_seed(random_seed):
indexes = list(test.artificial_sampling_index_generator(sample_size, n_prevpoints, n_repetitions)) indexes = list(test.artificial_sampling_index_generator(sample_size, n_prevpoints, n_repetitions))
if isinstance(model, AggregativeQuantifier): if isinstance(model, qp.method.aggregative.AggregativeQuantifier):
print('\tinstance of aggregative-quantifier')
quantification_func = model.aggregate quantification_func = model.aggregate
if isinstance(model, AggregativeProbabilisticQuantifier): if isinstance(model, qp.method.aggregative.AggregativeProbabilisticQuantifier):
print('\t\tinstance of probabilitstic-aggregative-quantifier')
preclassified_instances = model.posterior_probabilities(test.instances) preclassified_instances = model.posterior_probabilities(test.instances)
else: else:
print('\t\tinstance of hard-aggregative-quantifier')
preclassified_instances = model.classify(test.instances) preclassified_instances = model.classify(test.instances)
test = LabelledCollection(preclassified_instances, test.labels) test = LabelledCollection(preclassified_instances, test.labels)
else: else:
print('\t\tinstance of base-quantifier')
quantification_func = model.quantify quantification_func = model.quantify
def _predict_prevalences(index): def _predict_prevalences(index):

View File

@ -44,9 +44,9 @@ def quantification_models():
__C_range = np.logspace(-4, 5, 10) __C_range = np.logspace(-4, 5, 10)
lr_params = {'C': __C_range, 'class_weight': [None, 'balanced']} lr_params = {'C': __C_range, 'class_weight': [None, 'balanced']}
#yield 'cc', qp.method.aggregative.CC(newLR()), lr_params #yield 'cc', qp.method.aggregative.CC(newLR()), lr_params
yield 'acc', qp.method.aggregative.ACC(newLR()), lr_params #yield 'acc', qp.method.aggregative.ACC(newLR()), lr_params
#yield 'pcc', qp.method.aggregative.PCC(newLR()), lr_params #yield 'pcc', qp.method.aggregative.PCC(newLR()), lr_params
#yield 'pacc', qp.method.aggregative.PACC(newLR()), lr_params yield 'pacc', qp.method.aggregative.PACC(newLR()), lr_params
def result_path(dataset_name, model_name, optim_metric): def result_path(dataset_name, model_name, optim_metric):
@ -69,7 +69,7 @@ if __name__ == '__main__':
np.random.seed(0) np.random.seed(0)
for dataset_name in ['hcr']: # qp.datasets.TWITTER_SENTIMENT_DATASETS: for dataset_name in ['sanders']: # qp.datasets.TWITTER_SENTIMENT_DATASETS:
benchmark_devel = qp.datasets.fetch_twitter(dataset_name, for_model_selection=True, min_df=5, pickle=True) benchmark_devel = qp.datasets.fetch_twitter(dataset_name, for_model_selection=True, min_df=5, pickle=True)
benchmark_devel.stats() benchmark_devel.stats()