QuAcc/quacc/main.py

137 lines
3.6 KiB
Python
Raw Normal View History

import pandas as pd
import quapy as qp
from quapy.protocol import APP
from sklearn.linear_model import LogisticRegression
2023-05-20 20:23:17 +02:00
import quacc.evaluation as eval
2023-09-24 02:21:18 +02:00
import quacc.baseline as baseline
from quacc.estimator import (
BinaryQuantifierAccuracyEstimator,
MulticlassAccuracyEstimator,
)
2023-05-20 20:23:17 +02:00
2023-09-24 02:21:18 +02:00
from quacc.dataset import get_imdb, get_spambase
2023-05-20 20:23:17 +02:00
# Show floats in printed DataFrames with four decimal places.
pd.options.display.float_format = "{:.4f}".format
# Global quapy setting; presumably the default sample size drawn by the
# evaluation protocols (e.g. APP) — TODO confirm against quapy docs.
qp.environ["SAMPLE_SIZE"] = 100
# Dataset name printed (and loaded) by the single-estimator experiments.
dataset_name = "imdb"
def estimate_multiclass():
    """Evaluate a ``MulticlassAccuracyEstimator`` and print an aggregated report.

    The dataset is chosen by the module-level ``dataset_name``. A
    ``LogisticRegression`` classifier is fit on the training split, the
    accuracy estimator is fit on the held-out validation split, and the
    estimator is evaluated with quapy's ``APP`` protocol over the test split.
    The aggregated report from ``eval.evaluation_report`` is printed.

    Raises:
        KeyError: if ``dataset_name`` is not one of the known datasets.
    """
    print(dataset_name)
    # Load the dataset that was actually printed above, instead of always
    # loading imdb regardless of ``dataset_name``.
    loaders = {"imdb": get_imdb, "spambase": get_spambase}
    train, validation, test = loaders[dataset_name]()

    model = LogisticRegression()
    print(f"fitting model {model.__class__.__name__}...", end=" ", flush=True)
    model.fit(*train.Xy)
    print("fit")

    estimator = MulticlassAccuracyEstimator(model)
    print(
        f"fitting qmodel {estimator.q_model.__class__.__name__}...", end=" ", flush=True
    )
    # Fit on held-out data: the classifier was trained on ``train``, so its
    # predictions there are optimistic. This matches ``estimate_comparison``,
    # which fits the same kind of estimator on ``validation``.
    estimator.fit(validation)
    print("fit")

    n_prevalences = 21
    repeats = 1000
    protocol = APP(test, n_prevalences=n_prevalences, repeats=repeats)
    print(
        "Tests:\n"
        f"protocol={protocol.__class__.__name__}\n"
        f"n_prevalences={n_prevalences}\n"
        f"repeats={repeats}\n"
        "executing...\n"
    )

    df = eval.evaluation_report(
        estimator,
        protocol,
        aggregate=True,
    )
    print(df.to_string())
    print()
def estimate_binary():
    """Evaluate a ``BinaryQuantifierAccuracyEstimator`` and print an aggregated report.

    The dataset is chosen by the module-level ``dataset_name``. A
    ``LogisticRegression`` classifier is fit on the training split, the
    accuracy estimator is fit on the held-out validation split, and the
    estimator is evaluated with quapy's ``APP`` protocol over the test split.
    The aggregated report from ``eval.evaluation_report`` is printed.

    Raises:
        KeyError: if ``dataset_name`` is not one of the known datasets.
    """
    print(dataset_name)
    # Load the dataset that was actually printed above, instead of always
    # loading imdb regardless of ``dataset_name``.
    loaders = {"imdb": get_imdb, "spambase": get_spambase}
    train, validation, test = loaders[dataset_name]()

    model = LogisticRegression()
    print(f"fitting model {model.__class__.__name__}...", end=" ", flush=True)
    model.fit(*train.Xy)
    print("fit")

    estimator = BinaryQuantifierAccuracyEstimator(model)
    print(
        f"fitting qmodel {estimator.q_model_0.__class__.__name__}...",
        end=" ",
        flush=True,
    )
    # Fit on held-out data: the classifier was trained on ``train``, so its
    # predictions there are optimistic. This matches ``estimate_comparison``,
    # which fits this same estimator class on ``validation``.
    estimator.fit(validation)
    print("fit")

    n_prevalences = 21
    repeats = 1000
    protocol = APP(test, n_prevalences=n_prevalences, repeats=repeats)
    print(
        "Tests:\n"
        f"protocol={protocol.__class__.__name__}\n"
        f"n_prevalences={n_prevalences}\n"
        f"repeats={repeats}\n"
        "executing...\n"
    )

    df = eval.evaluation_report(
        estimator,
        protocol,
        aggregate=True,
    )
    print(df.to_string())
    print()
def estimate_comparison():
    """Compare ``BinaryQuantifierAccuracyEstimator`` against the baseline methods.

    On the spambase dataset, a ``LogisticRegression`` classifier is fit on the
    training split. The accuracy estimator and every baseline are then
    computed using the validation split and quapy's ``APP`` protocol over the
    test split. The baseline results are joined onto the estimator's report,
    and the combined table is printed.
    """
    train, validation, test = get_spambase()
    model = LogisticRegression()
    model.fit(*train.Xy)

    n_prevalences = 21
    repeats = 1000
    protocol = APP(test, n_prevalences=n_prevalences, repeats=repeats)

    estimator = BinaryQuantifierAccuracyEstimator(model)
    estimator.fit(validation)
    df = eval.evaluation_report(estimator, protocol)

    # Join key: these columns exist in the estimator report and in every
    # baseline frame. Presumably they hold each sample's base prevalences
    # (F/T) — TODO confirm against quacc.evaluation / quacc.baseline.
    df_index = [("base", "F"), ("base", "T")]
    baseline_methods = (
        baseline.atc_mc,
        baseline.atc_ne,
        baseline.doc_feat,
        baseline.rca_score,
        baseline.rca_star_score,
        baseline.bbse_score,
    )
    # Run every baseline first, in the original order, then join them in turn.
    baseline_dfs = [
        method(model, validation, protocol) for method in baseline_methods
    ]
    for baseline_df in baseline_dfs:
        df = df.join(baseline_df.set_index(df_index), on=df_index)

    print(df.to_string())
def main():
    """Script entry point: run the estimator-vs-baselines comparison."""
    estimate_comparison()
if __name__ == "__main__":
    # Only run the experiment when executed as a script, not when imported.
    main()