import bug fixed
This commit is contained in:
parent
2ec3400d15
commit
3e07feda3c
|
@ -1,6 +1,6 @@
|
|||
import quapy as qp
|
||||
from typing import Union, Callable, Iterable
|
||||
from data import LabelledCollection
|
||||
from method.aggregative import AggregativeQuantifier, AggregativeProbabilisticQuantifier
|
||||
from method.base import BaseQuantifier
|
||||
from util import temp_seed
|
||||
import numpy as np
|
||||
|
@ -38,14 +38,18 @@ def artificial_sampling_prediction(
|
|||
with temp_seed(random_seed):
|
||||
indexes = list(test.artificial_sampling_index_generator(sample_size, n_prevpoints, n_repetitions))
|
||||
|
||||
if isinstance(model, AggregativeQuantifier):
|
||||
if isinstance(model, qp.method.aggregative.AggregativeQuantifier):
|
||||
print('\tinstance of aggregative-quantifier')
|
||||
quantification_func = model.aggregate
|
||||
if isinstance(model, AggregativeProbabilisticQuantifier):
|
||||
if isinstance(model, qp.method.aggregative.AggregativeProbabilisticQuantifier):
|
||||
print('\t\tinstance of probabilitstic-aggregative-quantifier')
|
||||
preclassified_instances = model.posterior_probabilities(test.instances)
|
||||
else:
|
||||
print('\t\tinstance of hard-aggregative-quantifier')
|
||||
preclassified_instances = model.classify(test.instances)
|
||||
test = LabelledCollection(preclassified_instances, test.labels)
|
||||
else:
|
||||
print('\t\tinstance of base-quantifier')
|
||||
quantification_func = model.quantify
|
||||
|
||||
def _predict_prevalences(index):
|
||||
|
|
|
@ -44,9 +44,9 @@ def quantification_models():
|
|||
__C_range = np.logspace(-4, 5, 10)
|
||||
lr_params = {'C': __C_range, 'class_weight': [None, 'balanced']}
|
||||
#yield 'cc', qp.method.aggregative.CC(newLR()), lr_params
|
||||
yield 'acc', qp.method.aggregative.ACC(newLR()), lr_params
|
||||
#yield 'acc', qp.method.aggregative.ACC(newLR()), lr_params
|
||||
#yield 'pcc', qp.method.aggregative.PCC(newLR()), lr_params
|
||||
#yield 'pacc', qp.method.aggregative.PACC(newLR()), lr_params
|
||||
yield 'pacc', qp.method.aggregative.PACC(newLR()), lr_params
|
||||
|
||||
|
||||
def result_path(dataset_name, model_name, optim_metric):
|
||||
|
@ -69,7 +69,7 @@ if __name__ == '__main__':
|
|||
|
||||
np.random.seed(0)
|
||||
|
||||
for dataset_name in ['hcr']: # qp.datasets.TWITTER_SENTIMENT_DATASETS:
|
||||
for dataset_name in ['sanders']: # qp.datasets.TWITTER_SENTIMENT_DATASETS:
|
||||
|
||||
benchmark_devel = qp.datasets.fetch_twitter(dataset_name, for_model_selection=True, min_df=5, pickle=True)
|
||||
benchmark_devel.stats()
|
||||
|
|
Loading…
Reference in New Issue