forked from moreo/QuaPy
import bug fixed
This commit is contained in:
parent
2ec3400d15
commit
3e07feda3c
|
@ -1,6 +1,6 @@
|
||||||
|
import quapy as qp
|
||||||
from typing import Union, Callable, Iterable
|
from typing import Union, Callable, Iterable
|
||||||
from data import LabelledCollection
|
from data import LabelledCollection
|
||||||
from method.aggregative import AggregativeQuantifier, AggregativeProbabilisticQuantifier
|
|
||||||
from method.base import BaseQuantifier
|
from method.base import BaseQuantifier
|
||||||
from util import temp_seed
|
from util import temp_seed
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -38,14 +38,18 @@ def artificial_sampling_prediction(
|
||||||
with temp_seed(random_seed):
|
with temp_seed(random_seed):
|
||||||
indexes = list(test.artificial_sampling_index_generator(sample_size, n_prevpoints, n_repetitions))
|
indexes = list(test.artificial_sampling_index_generator(sample_size, n_prevpoints, n_repetitions))
|
||||||
|
|
||||||
if isinstance(model, AggregativeQuantifier):
|
if isinstance(model, qp.method.aggregative.AggregativeQuantifier):
|
||||||
|
print('\tinstance of aggregative-quantifier')
|
||||||
quantification_func = model.aggregate
|
quantification_func = model.aggregate
|
||||||
if isinstance(model, AggregativeProbabilisticQuantifier):
|
if isinstance(model, qp.method.aggregative.AggregativeProbabilisticQuantifier):
|
||||||
|
print('\t\tinstance of probabilitstic-aggregative-quantifier')
|
||||||
preclassified_instances = model.posterior_probabilities(test.instances)
|
preclassified_instances = model.posterior_probabilities(test.instances)
|
||||||
else:
|
else:
|
||||||
|
print('\t\tinstance of hard-aggregative-quantifier')
|
||||||
preclassified_instances = model.classify(test.instances)
|
preclassified_instances = model.classify(test.instances)
|
||||||
test = LabelledCollection(preclassified_instances, test.labels)
|
test = LabelledCollection(preclassified_instances, test.labels)
|
||||||
else:
|
else:
|
||||||
|
print('\t\tinstance of base-quantifier')
|
||||||
quantification_func = model.quantify
|
quantification_func = model.quantify
|
||||||
|
|
||||||
def _predict_prevalences(index):
|
def _predict_prevalences(index):
|
||||||
|
|
|
@ -44,9 +44,9 @@ def quantification_models():
|
||||||
__C_range = np.logspace(-4, 5, 10)
|
__C_range = np.logspace(-4, 5, 10)
|
||||||
lr_params = {'C': __C_range, 'class_weight': [None, 'balanced']}
|
lr_params = {'C': __C_range, 'class_weight': [None, 'balanced']}
|
||||||
#yield 'cc', qp.method.aggregative.CC(newLR()), lr_params
|
#yield 'cc', qp.method.aggregative.CC(newLR()), lr_params
|
||||||
yield 'acc', qp.method.aggregative.ACC(newLR()), lr_params
|
#yield 'acc', qp.method.aggregative.ACC(newLR()), lr_params
|
||||||
#yield 'pcc', qp.method.aggregative.PCC(newLR()), lr_params
|
#yield 'pcc', qp.method.aggregative.PCC(newLR()), lr_params
|
||||||
#yield 'pacc', qp.method.aggregative.PACC(newLR()), lr_params
|
yield 'pacc', qp.method.aggregative.PACC(newLR()), lr_params
|
||||||
|
|
||||||
|
|
||||||
def result_path(dataset_name, model_name, optim_metric):
|
def result_path(dataset_name, model_name, optim_metric):
|
||||||
|
@ -69,7 +69,7 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
|
|
||||||
for dataset_name in ['hcr']: # qp.datasets.TWITTER_SENTIMENT_DATASETS:
|
for dataset_name in ['sanders']: # qp.datasets.TWITTER_SENTIMENT_DATASETS:
|
||||||
|
|
||||||
benchmark_devel = qp.datasets.fetch_twitter(dataset_name, for_model_selection=True, min_df=5, pickle=True)
|
benchmark_devel = qp.datasets.fetch_twitter(dataset_name, for_model_selection=True, min_df=5, pickle=True)
|
||||||
benchmark_devel.stats()
|
benchmark_devel.stats()
|
||||||
|
|
Loading…
Reference in New Issue