diff --git a/quacc/main.py b/quacc/main.py index 16caeb7..93af1d8 100644 --- a/quacc/main.py +++ b/quacc/main.py @@ -2,7 +2,7 @@ import pandas as pd import quapy as qp from quapy.method.aggregative import SLD from quapy.protocol import APP -from sklearn.linear_model import LogisticRegression +from sklearn.svm import SVC import quacc.evaluation as eval from quacc.estimator import AccuracyEstimator @@ -17,23 +17,24 @@ pd.set_option("display.float_format", "{:.4f}".format) def test_2(dataset_name): train, test = get_dataset(dataset_name) - model = LogisticRegression() + model = SVC(probability=True) - print(f"fitting model {model.__class__.__name__}...", end=" ") + print(f"fitting model {model.__class__.__name__}...", end=" ", flush=True) model.fit(*train.Xy) print("fit") - qmodel = SLD(LogisticRegression()) + qmodel = SLD(SVC(probability=True)) estimator = AccuracyEstimator(model, qmodel) - print(f"fitting qmodel {qmodel.__class__.__name__}...", end=" ") + print(f"fitting qmodel {qmodel.__class__.__name__}...", end=" ", flush=True) estimator.fit(train) print("fit") n_prevalences = 21 repreats = 1000 protocol = APP(test, n_prevalences=n_prevalences, repeats=repreats) - print( f"Tests:\n\ + print( + f"Tests:\n\ protocol={protocol.__class__.__name__}\n\ n_prevalences={n_prevalences}\n\ repreats={repreats}\n\ @@ -49,9 +50,9 @@ def test_2(dataset_name): def main(): for dataset_name in [ - "hp", "imdb", - "spambase", + # "hp", + # "spambase", ]: print(dataset_name) test_2(dataset_name)