main updated

This commit is contained in:
Lorenzo Volpi 2023-06-08 15:20:11 +02:00
parent da827943d6
commit b969234244
1 changed files with 9 additions and 8 deletions

View File

@ -2,7 +2,7 @@ import pandas as pd
import quapy as qp import quapy as qp
from quapy.method.aggregative import SLD from quapy.method.aggregative import SLD
from quapy.protocol import APP from quapy.protocol import APP
from sklearn.linear_model import LogisticRegression from sklearn.svm import SVC
import quacc.evaluation as eval import quacc.evaluation as eval
from quacc.estimator import AccuracyEstimator from quacc.estimator import AccuracyEstimator
@ -17,23 +17,24 @@ pd.set_option("display.float_format", "{:.4f}".format)
def test_2(dataset_name): def test_2(dataset_name):
train, test = get_dataset(dataset_name) train, test = get_dataset(dataset_name)
model = LogisticRegression() model = SVC(probability=True)
print(f"fitting model {model.__class__.__name__}...", end=" ") print(f"fitting model {model.__class__.__name__}...", end=" ", flush=True)
model.fit(*train.Xy) model.fit(*train.Xy)
print("fit") print("fit")
qmodel = SLD(LogisticRegression()) qmodel = SLD(SVC(probability=True))
estimator = AccuracyEstimator(model, qmodel) estimator = AccuracyEstimator(model, qmodel)
print(f"fitting qmodel {qmodel.__class__.__name__}...", end=" ") print(f"fitting qmodel {qmodel.__class__.__name__}...", end=" ", flush=True)
estimator.fit(train) estimator.fit(train)
print("fit") print("fit")
n_prevalences = 21 n_prevalences = 21
repreats = 1000 repreats = 1000
protocol = APP(test, n_prevalences=n_prevalences, repeats=repreats) protocol = APP(test, n_prevalences=n_prevalences, repeats=repreats)
print( f"Tests:\n\ print(
f"Tests:\n\
protocol={protocol.__class__.__name__}\n\ protocol={protocol.__class__.__name__}\n\
n_prevalences={n_prevalences}\n\ n_prevalences={n_prevalences}\n\
repreats={repreats}\n\ repreats={repreats}\n\
@ -49,9 +50,9 @@ def test_2(dataset_name):
def main(): def main():
for dataset_name in [ for dataset_name in [
"hp",
"imdb", "imdb",
"spambase", # "hp",
# "spambase",
]: ]:
print(dataset_name) print(dataset_name)
test_2(dataset_name) test_2(dataset_name)