diff --git a/.gitignore b/.gitignore
index 1ae9719..f450568 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,4 +11,5 @@ lipton_bbse/__pycache__/*
elsahar19_rca/__pycache__/*
*.coverage
.coverage
-scp_sync.py
\ No newline at end of file
+scp_sync.py
+out/*
\ No newline at end of file
diff --git a/TODO.html b/TODO.html
new file mode 100644
index 0000000..ddfdc17
--- /dev/null
+++ b/TODO.html
@@ -0,0 +1,6 @@
+add table averages
+plots: 3 types (notes + email + garg)
+fix kfcv baseline
+add a CC-based method besides SLD
+take the most populous rcv1 class, remove negatives until reaching 50/50; then subsample with 9 training prevalences (from 0.1-0.9 to 0.9-0.1)
+vary the recalibration parameter in SLD
\ No newline at end of file
diff --git a/TODO.md b/TODO.md
index a3a7321..2f6d846 100644
--- a/TODO.md
+++ b/TODO.md
@@ -1,8 +1,6 @@
-- add table averages
-- plots
-  - 3 types (see notes + garg)
-- fix kfcv baseline
-- add a CC-based method besides SLD
-- take the most populous rcv1 class, remove negatives until reaching 50/50
-  then subsample with 9 training prevalences (from 0.1-0.9 to 0.9-0.1)
-- vary the recalibration parameter in SLD
\ No newline at end of file
+- [ ] add table averages
+- [ ] plots: 3 types (notes + email + garg)
+- [ ] fix kfcv baseline
+- [ ] add a CC-based method besides SLD
+- [x] take the most populous rcv1 class, remove negatives until reaching 50/50; then subsample with 9 training prevalences (from 0.1-0.9 to 0.9-0.1)
+- [ ] vary the recalibration parameter in SLD
\ No newline at end of file
diff --git a/quacc/dataset.py b/quacc/dataset.py
index a4e1735..c63552f 100644
--- a/quacc/dataset.py
+++ b/quacc/dataset.py
@@ -1,26 +1,102 @@
-from typing import Tuple
+import math
+from typing import List
+
 import numpy as np
-from quapy.data.base import LabelledCollection
 import quapy as qp
+from quapy.data.base import LabelledCollection
 from sklearn.conftest import fetch_rcv1
 
 TRAIN_VAL_PROP = 0.5
 
-def get_imdb(**kwargs) -> Tuple[LabelledCollection]:
-    train, test = qp.datasets.fetch_reviews("imdb", tfidf=True).train_test
-    train, validation = train.split_stratified(
-        train_prop=TRAIN_VAL_PROP, random_state=0
-    )
-    return train, validation, test
+class DatasetSample:
+    def __init__(
+        self,
+        train: LabelledCollection,
+        validation: LabelledCollection,
+        test: LabelledCollection,
+    ):
+        self.train = train
+        self.validation = validation
+        self.test = test
+
+    @property
+    def train_prev(self):
+        return self.train.prevalence()
+
+    @property
+    def validation_prev(self):
+        return self.validation.prevalence()
+
+    @property
+    def prevs(self):
+        return {"train": self.train_prev, "validation": self.validation_prev}
 
-def get_spambase(**kwargs) -> Tuple[LabelledCollection]:
-    train, test = qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
-    train, validation = train.split_stratified(
-        train_prop=TRAIN_VAL_PROP, random_state=0
-    )
-    return train, validation, test
+class Dataset:
+    def __init__(self, name, n_prevalences=9, target=None):
+        self._name = name
+        self._target = target
+        self.n_prevs = n_prevalences
+
+    def __spambase(self):
+        return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test
+
+    def __imdb(self):
+        return qp.datasets.fetch_reviews("imdb", tfidf=True).train_test
+
+    def __rcv1(self):
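+        # sklearn's fetch_rcv1 returns the 23,149 chronologically-first
+        # documents as the training split (LYRL2004)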
+        n_train = 23149
+        available_targets = ["CCAT", "GCAT", "MCAT"]
+
+        if self._target is None or self._target not in available_targets:
+            raise ValueError("Invalid target")
+
+        dataset = fetch_rcv1()
+        target_index = np.where(dataset.target_names == self._target)[0]
+        all_train_d, test_d = dataset.data[:n_train, :], dataset.data[n_train:, :]
+        labels = dataset.target[:, target_index].toarray().flatten()
+        all_train_l, test_l = labels[:n_train], labels[n_train:]
+        all_train = LabelledCollection(all_train_d, all_train_l, classes=[0, 1])
+        test = LabelledCollection(test_d, test_l, classes=[0, 1])
+
+        return all_train, test
+
+    def get(self) -> List[DatasetSample]:
+        all_train, test = {
+            "spambase": self.__spambase,
+            "imdb": self.__imdb,
+            "rcv1": self.__rcv1,
+        }[self._name]()
+
+        # resample all_train set to have (0.5, 0.5) prevalence
+        at_positives = np.sum(all_train.y)
+        all_train = all_train.sampling(
+            min(at_positives, len(all_train) - at_positives) * 2, 0.5, random_state=0
+        )
+
+        # sample prevalences
+        prevalences = np.linspace(0.0, 1.0, num=self.n_prevs + 1, endpoint=False)[1:]
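+        # cap the sample size so that even the most skewed prevalence (0.9)
+        # can be drawn from the positives available in the balanced pool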
+        at_size = min(math.floor(len(all_train) * 0.5 / p) for p in prevalences)
+        datasets = []
+        for p in prevalences:
+            all_train_sampled = all_train.sampling(at_size, p, random_state=0)
+            train, validation = all_train_sampled.split_stratified(
+                train_prop=TRAIN_VAL_PROP, random_state=0
+            )
+            datasets.append(DatasetSample(train, validation, test))
+
+        return datasets
+
+    def __call__(self):
+        return self.get()
+
+    @property
+    def name(self):
+        if self._name == "rcv1":
+            return f"{self._name}_{self._target}"
+        else:
+            return self._name
# >>> fetch_rcv1().target_names
@@ -39,33 +115,30 @@ def get_spambase(**kwargs) -> Tuple[LabelledCollection]:
 #        'M142', 'M143', 'MCAT'], dtype=object)
 
-def get_rcv1(target = "default", **kwargs):
-    sample_size = qp.environ["SAMPLE_SIZE"]
-    n_train = 23149
+def rcv1_info():
     dataset = fetch_rcv1()
+    n_train = 23149
 
-    if target == "default":
-        target = "C12"
-
-    if target not in dataset.target_names:
-        raise ValueError("Invalid target")
-
-    def dataset_split(data, labels, classes=[0, 1]) -> Tuple[LabelledCollection]:
-        all_train_d, test_d = data[:n_train, :], data[n_train:, :]
-        all_train_l, test_l = labels[:n_train], labels[n_train:]
-        all_train = LabelledCollection(all_train_d, all_train_l, classes=classes)
-        test = LabelledCollection(test_d, test_l, classes=classes)
-        train, validation = all_train.split_stratified(
-            train_prop=TRAIN_VAL_PROP, random_state=0
+    targets = []
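+    # fetch_rcv1 exposes 103 categories; collect each one's (negative,
+    # positive) prevalence on the train and test splits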
+    for target in range(103):
+        train_t_prev = np.average(dataset.target[:n_train, target].toarray().flatten())
+        test_t_prev = np.average(dataset.target[n_train:, target].toarray().flatten())
+        targets.append(
+            (
+                dataset.target_names[target],
+                {
+                    "train": (1.0 - train_t_prev, train_t_prev),
+                    "test": (1.0 - test_t_prev, test_t_prev),
+                },
+            )
         )
-        return train, validation, test
 
-    target_index = np.where(dataset.target_names == target)[0]
-    target_labels = dataset.target[:, target_index].toarray().flatten()
+    targets.sort(key=lambda t: t[1]["train"][1])
+    for n, d in targets:
+        print(f"{n}:")
+        for k, (fp, tp) in d.items():
+            print(f"\t{k}: {fp:.4f}, {tp:.4f}")
 
-    if np.sum(target_labels[n_train:]) < sample_size:
-        raise ValueError("Target has too few positive samples")
-
-    d = dataset_split(dataset.data, target_labels, classes=[0, 1])
-
-    return d
+
+if __name__ == "__main__":
+    rcv1_info()
diff --git a/quacc/environ.py b/quacc/environ.py
new file mode 100644
index 0000000..cc2f13c
--- /dev/null
+++ b/quacc/environ.py
@@ -0,0 +1,31 @@
+from pathlib import Path
+
+default_env = {
+    "DATASET_NAME": "rcv1",
+    "DATASET_TARGET": "CCAT",
+    "COMP_ESTIMATORS": [
+        "OUR_BIN_SLD",
+        "OUR_MUL_SLD",
+        "KFCV",
+        "ATC_MC",
+        "ATC_NE",
+        "DOC_FEAT",
+        # "RCA",
+        # "RCA_STAR",
+    ],
+    "DATASET_N_PREVS": 9,
+    "OUT_DIR": Path("out"),
+    "PLOT_OUT_DIR": Path("out/plot"),
+    "PROTOCOL_N_PREVS": 21,
+    "PROTOCOL_REPEATS": 100,
+    "SAMPLE_SIZE": 1000,
+}
+
+
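+# Environ exposes each configuration key as an attribute (e.g. env.SAMPLE_SIZE)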
+class Environ:
+    def __init__(self, **kwargs):
+        for k, v in kwargs.items():
+            self.__setattr__(k, v)
+
+
+env = Environ(**default_env)
diff --git a/quacc/evaluation/baseline.py b/quacc/evaluation/baseline.py
index ce36045..f4e969d 100644
--- a/quacc/evaluation/baseline.py
+++ b/quacc/evaluation/baseline.py
@@ -1,29 +1,28 @@
from statistics import mean
-from typing import Dict
import numpy as np
-from quapy.data import LabelledCollection
-from sklearn.base import BaseEstimator
-from sklearn.model_selection import cross_validate
import sklearn.metrics as metrics
+from quapy.data import LabelledCollection
from quapy.protocol import (
    AbstractStochasticSeededProtocol,
    OnLabelledCollectionProtocol,
)
-
-from .report import EvaluationReport
+from sklearn.base import BaseEstimator
+from sklearn.model_selection import cross_validate
import elsahar19_rca.rca as rca
import garg22_ATC.ATC_helper as atc
import guillory21_doc.doc as doc
import jiang18_trustscore.trustscore as trustscore
+from .report import EvaluationReport
+
def kfcv(
-    c_model: BaseEstimator, 
+    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
-    predict_method="predict"
+    predict_method="predict",
):
    c_model_predict = getattr(c_model, predict_method)
@@ -42,12 +41,12 @@ def kfcv(
    meta_f1 = abs(f1_score - metrics.f1_score(test.y, test_preds))
    report.append_row(
        test.prevalence(),
-        acc_score=(1. - acc_score),
+        acc_score=(1.0 - acc_score),
        f1_score=f1_score,
        acc=meta_acc,
        f1=meta_f1,
    )
-    
+
    return report
@@ -63,7 +62,7 @@ def reference(
    test_probs = c_model_predict(test.X)
    test_preds = np.argmax(test_probs, axis=-1)
    report.append_row(
-        test.prevalence(), 
+        test.prevalence(),
        acc_score=(1 - metrics.accuracy_score(test.y, test_preds)),
        f1_score=metrics.f1_score(test.y, test_preds),
    )
diff --git a/quacc/evaluation/comp.py b/quacc/evaluation/comp.py
new file mode 100644
index 0000000..ccc4e18
--- /dev/null
+++ b/quacc/evaluation/comp.py
@@ -0,0 +1,91 @@
+import multiprocessing
+import time
+from typing import List
+
+import pandas as pd
+import quapy as qp
+from quapy.protocol import APP
+from sklearn.linear_model import LogisticRegression
+
+from quacc.dataset import Dataset
+from quacc.environ import env
+from quacc.evaluation import baseline, method
+from quacc.evaluation.report import DatasetReport, EvaluationReport
+
+qp.environ["SAMPLE_SIZE"] = env.SAMPLE_SIZE
+
+pd.set_option("display.float_format", "{:.4f}".format)
+
+
+class CompEstimator:
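+    # maps the estimator names accepted in env.COMP_ESTIMATORS to their
+    # evaluation functions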
+    __dict = {
+        "OUR_BIN_SLD": method.evaluate_bin_sld,
+        "OUR_MUL_SLD": method.evaluate_mul_sld,
+        "KFCV": baseline.kfcv,
+        "ATC_MC": baseline.atc_mc,
+        "ATC_NE": baseline.atc_ne,
+        "DOC_FEAT": baseline.doc_feat,
+        "RCA": baseline.rca_score,
+        "RCA_STAR": baseline.rca_star_score,
+    }
+
+    def __class_getitem__(cls, e: str | List[str]):
+        if isinstance(e, str):
+            try:
+                return cls.__dict[e]
+            except KeyError:
+                raise KeyError(f"Invalid estimator: estimator {e} does not exist")
+        elif isinstance(e, list):
+            try:
+                return [cls.__dict[est] for est in e]
+            except KeyError as ke:
+                raise KeyError(
+                    f"Invalid estimator: estimator {ke.args[0]} does not exist"
+                )
+
+
+CE = CompEstimator
+
+
+def fit_and_estimate(_estimate, train, validation, test):
+    model = LogisticRegression()
+
+    model.fit(*train.Xy)
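+    # APP (artificial prevalence protocol) draws repeated test samples across
+    # a grid of prevalence values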
+    protocol = APP(
+        test, n_prevalences=env.PROTOCOL_N_PREVS, repeats=env.PROTOCOL_REPEATS
+    )
+    start = time.time()
+    result = _estimate(model, validation, protocol)
+    end = time.time()
+    print(f"{_estimate.__name__}: {end-start:.2f}s")
+
+    return {
+        "name": _estimate.__name__,
+        "result": result,
+        "time": end - start,
+    }
+
+
+def evaluate_comparison(
+    dataset: Dataset, estimators=["OUR_BIN_SLD", "OUR_MUL_SLD"]
+) -> DatasetReport:
+    with multiprocessing.Pool(8) as pool:
+        dr = DatasetReport(dataset.name)
+        for d in dataset():
+            print(f"train prev.: {d.train_prev}")
+            start = time.time()
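+            # one worker per estimator: fit_and_estimate fits a fresh
+            # LogisticRegression and evaluates it under the protocol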
+            tasks = [(estim, d.train, d.validation, d.test) for estim in CE[estimators]]
+            results = [pool.apply_async(fit_and_estimate, t) for t in tasks]
+            results = list(map(lambda r: r.get(), results))
+            er = EvaluationReport.combine_reports(
+                *list(map(lambda r: r["result"], results)), name=dataset.name
+            )
+            times = {r["name"]: r["time"] for r in results}
+            end = time.time()
+            times["tot"] = end - start
+            er.times = times
+            er.train_prevs = d.prevs
+            dr.add(er)
+            print()
+
+    return dr
diff --git a/quacc/evaluation/method.py b/quacc/evaluation/method.py
index 0c69ba1..e42f203 100644
--- a/quacc/evaluation/method.py
+++ b/quacc/evaluation/method.py
@@ -1,20 +1,11 @@
-import multiprocessing
-import time
-
-import pandas as pd
-import quapy as qp
from quapy.data import LabelledCollection
from quapy.protocol import (
-    APP,
    AbstractStochasticSeededProtocol,
    OnLabelledCollectionProtocol,
)
from sklearn.base import BaseEstimator
-from sklearn.linear_model import LogisticRegression
import quacc.error as error
-import quacc.evaluation.baseline as baseline
-from quacc.dataset import get_imdb, get_rcv1, get_spambase
from quacc.evaluation.report import EvaluationReport
from ..estimator import (
@@ -23,13 +14,6 @@ from ..estimator import (
    MulticlassAccuracyEstimator,
)
-qp.environ["SAMPLE_SIZE"] = 100
-
-pd.set_option("display.float_format", "{:.4f}".format)
-
-n_prevalences = 21
-repreats = 100
-
def estimate(
    estimator: AccuracyEstimator,
@@ -61,11 +45,11 @@ def evaluation_report(
    acc_score = error.acc(estim_prev)
    f1_score = error.f1(estim_prev)
    report.append_row(
-        base_prev, 
-        acc_score=1. - acc_score,
-        acc = abs(error.acc(true_prev) - acc_score),
+        base_prev,
+        acc_score=1.0 - acc_score,
+        acc=abs(error.acc(true_prev) - acc_score),
        f1_score=f1_score,
-        f1=abs(error.f1(true_prev) - f1_score)
+        f1=abs(error.f1(true_prev) - f1_score),
    )
    return report
@@ -77,7 +61,7 @@ def evaluate(
    protocol: AbstractStochasticSeededProtocol,
    method: str,
):
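+    # select the accuracy estimator backed by binary ("bin") or
+    # multiclass ("mul") quantification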
-    estimator : AccuracyEstimator = {
+    estimator: AccuracyEstimator = {
        "bin": BinaryQuantifierAccuracyEstimator,
        "mul": MulticlassAccuracyEstimator,
    }[method](c_model)
@@ -85,65 +69,17 @@ def evaluate(
    return evaluation_report(estimator, protocol, method)
 
-def evaluate_binary(model, validation, protocol):
-    return evaluate(model, validation, protocol, "bin")
+def evaluate_bin_sld(
+    c_model: BaseEstimator,
+    validation: LabelledCollection,
+    protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+    return evaluate(c_model, validation, protocol, "bin")
-def evaluate_multiclass(model, validation, protocol):
-    return evaluate(model, validation, protocol, "mul")
-
-
-def fit_and_estimate(_estimate, train, validation, test):
-    model = LogisticRegression()
-
-    model.fit(*train.Xy)
-    protocol = APP(test, n_prevalences=n_prevalences, repeats=repreats)
-    start = time.time()
-    result = _estimate(model, validation, protocol)
-    end = time.time()
-
-    return {
-        "name": _estimate.__name__,
-        "result": result,
-        "time": end - start,
-    }
-
-
-def evaluate_comparison(dataset: str, **kwargs) -> EvaluationReport:
-    train, validation, test = {
-        "spambase": get_spambase,
-        "imdb": get_imdb,
-        "rcv1": get_rcv1,
-    }[dataset](**kwargs)
-
-    for k,v in kwargs.items():
-        print(k, ":", v)
-
-    prevs = {
-        "train": train.prevalence(),
-        "validation": validation.prevalence(),
-    }
-
-    start = time.time()
-    with multiprocessing.Pool(8) as pool:
-        estimators = [
-            evaluate_binary,
-            evaluate_multiclass,
-            baseline.kfcv,
-            baseline.atc_mc,
-            baseline.atc_ne,
-            baseline.doc_feat,
-            baseline.rca_score,
-            baseline.rca_star_score,
-        ]
-        tasks = [(estim, train, validation, test) for estim in estimators]
-        results = [pool.apply_async(fit_and_estimate, t) for t in tasks]
-        results = list(map(lambda r: r.get(), results))
-        er = EvaluationReport.combine_reports(*list(map(lambda r: r["result"], results)))
-        times = {r["name"]:r["time"] for r in results}
-        end = time.time()
-        times["tot"] = end - start
-        er.times = times
-        er.prevs = prevs
-
-    return er
+def evaluate_mul_sld(
+    c_model: BaseEstimator,
+    validation: LabelledCollection,
+    protocol: AbstractStochasticSeededProtocol,
+) -> EvaluationReport:
+    return evaluate(c_model, validation, protocol, "mul")
diff --git a/quacc/evaluation/report.py b/quacc/evaluation/report.py
index 696bc85..0236b98 100644
--- a/quacc/evaluation/report.py
+++ b/quacc/evaluation/report.py
@@ -1,143 +1,122 @@
-from typing import Tuple
+import math
import statistics as stats
+from typing import List, Tuple
+
import numpy as np
import pandas as pd
-
-def _fmt_line(s):
-    return f"> {s} \n"
+from quacc import plot
+from quacc.utils import fmt_line_md
class EvaluationReport:
    def __init__(self, prefix=None):
-        self.base = []
-        self.dict = {}
-        self._grouped = False
-        self._grouped_base = []
-        self._grouped_dict = {}
-        self._dataframe = None
-        self.prefix = prefix if prefix is not None else "default"
-        self._times = {}
-        self._prevs = {}
-        self._target = "default"
+        self._prevs = []
+        self._dict = {}
+        self._g_prevs = None
+        self._g_dict = None
+        self.name = prefix if prefix is not None else "default"
+        self.times = {}
+        self.train_prevs = {}
+        self.target = "default"
    def append_row(self, base: np.ndarray | Tuple, **row):
        if isinstance(base, np.ndarray):
            base = tuple(base.tolist())
-        self.base.append(base)
+        self._prevs.append(base)
        for k, v in row.items():
-            if (k, self.prefix) in self.dict:
-                self.dict[(k, self.prefix)].append(v)
+            if (k, self.name) in self._dict:
+                self._dict[(k, self.name)].append(v)
            else:
-                self.dict[(k, self.prefix)] = [v]
-        self._grouped = False
-        self._dataframe = None
+                self._dict[(k, self.name)] = [v]
+        # reset both cached groupings so they are recomputed on the next access
+        self._g_prevs = None
+        self._g_dict = None
    @property
    def columns(self):
-        return self.dict.keys()
+        return self._dict.keys()
-    @property
-    def grouped(self):
-        if self._grouped:
-            return self._grouped_dict
+    def groupby_prevs(self, metric: str = None):
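+        # consecutive rows sharing the same base prevalence (as emitted by the
+        # sampling protocol's repeats) are averaged into a single entry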
+        if self._g_dict is None:
+            self._g_prevs = []
+            self._g_dict = {k: [] for k in self._dict.keys()}
 
-        self._grouped_base = []
-        self._grouped_dict = {k: [] for k in self.dict.keys()}
+            last_end = 0
+            for ind, bp in enumerate(self._prevs):
+                if ind < (len(self._prevs) - 1) and bp == self._prevs[ind + 1]:
+                    continue
 
-        last_end = 0
-        for ind, bp in enumerate(self.base):
-            if ind < (len(self.base) - 1) and bp == self.base[ind + 1]:
-                continue
+                self._g_prevs.append(bp)
+                for col in self._dict.keys():
+                    self._g_dict[col].append(
+                        stats.mean(self._dict[col][last_end : ind + 1])
+                    )
 
-            self._grouped_base.append(bp)
-            for col in self.dict.keys():
-                self._grouped_dict[col].append(
-                    stats.mean(self.dict[col][last_end : ind + 1])
-                )
+                last_end = ind + 1
 
-            last_end = ind + 1
+        filtered_g_dict = self._g_dict
+        if metric is not None:
+            filtered_g_dict = {
+                c1: ls for ((c0, c1), ls) in self._g_dict.items() if c0 == metric
+            }
 
-        self._grouped = True
-        return self._grouped_dict
+        return self._g_prevs, filtered_g_dict
-    @property
-    def gbase(self):
-        self.grouped
-        return self._grouped_base
-
-    def get_dataframe(self, metrics=None):
-        if self._dataframe is None:
-            self_columns = sorted(self.columns, key=lambda c: c[0])
-            self._dataframe = pd.DataFrame(
-                self.grouped,
-                index=self.gbase,
-                columns=pd.MultiIndex.from_tuples(self_columns),
-            )
-
-        df = pd.DataFrame(self._dataframe)
-        if metrics is not None:
-            df = df.drop(
-                [(c0, c1) for (c0, c1) in df.columns if c0 not in metrics], axis=1
-            )
-
-        if len(set(k0 for k0, k1 in df.columns)) == 1:
-            df = df.droplevel(0, axis=1)
-
-        return df
-
-    def merge(self, other):
-        if not all(v1 == v2 for v1, v2 in zip(self.base, other.base)):
-            raise ValueError("other has not same base prevalences of self")
-
-        if len(set(self.dict.keys()).intersection(set(other.dict.keys()))) > 0:
-            raise ValueError("self and other have matching keys")
-
-        report = EvaluationReport()
-        report.base = self.base
-        report.dict = self.dict | other.dict
-        return report
-
-    @property
-    def times(self):
-        return self._times
-
-    @times.setter
-    def times(self, val):
-        self._times = val
-
-    @property
-    def prevs(self):
-        return self._prevs
-
-    @prevs.setter
-    def prevs(self, val):
-        self._prevs = val
-
-    @property
-    def target(self):
-        return self._target
-
-    @target.setter
-    def target(self, val):
-        self._target = val
+    def get_dataframe(self, metric="acc"):
+        g_prevs, g_dict = self.groupby_prevs(metric=metric)
+        return pd.DataFrame(
+            g_dict,
+            index=g_prevs,
+            columns=g_dict.keys(),
+        )
    def to_md(self, *metrics):
-        res = _fmt_line("target: " + self.target)
-        for k, v in self.prevs.items():
-            res += _fmt_line(f"{k}: {str(v)}")
+        res = ""
+        for k, v in self.train_prevs.items():
+            res += fmt_line_md(f"{k}: {str(v)}")
        for k, v in self.times.items():
-            res += _fmt_line(f"{k}: {v:.3f}s")
+            res += fmt_line_md(f"{k}: {v:.3f}s")
        res += "\n"
        for m in metrics:
-            res += self.get_dataframe(metrics=m).to_html() + "\n\n"
+            res += self.get_dataframe(metric=m).to_html() + "\n\n"
        return res
+    def merge(self, other):
+        if not all(v1 == v2 for v1, v2 in zip(self._prevs, other._prevs)):
+            raise ValueError("other does not have the same base prevalences as self")
+
+        if len(set(self._dict.keys()).intersection(set(other._dict.keys()))) > 0:
+            raise ValueError("self and other have matching keys")
+
+        report = EvaluationReport()
+        report._prevs = self._prevs
+        report._dict = self._dict | other._dict
+        return report
+
    @staticmethod
-    def combine_reports(*args):
+    def combine_reports(*args, name="default"):
        er = args[0]
        for r in args[1:]:
            er = er.merge(r)
+        er.name = name
        return er
+
+
+class DatasetReport:
+    def __init__(self, name):
+        self.name = name
+        self.ers: List[EvaluationReport] = []
+
+    def add(self, er: EvaluationReport):
+        self.ers.append(er)
+
+    def to_md(self, *metrics):
+        res = f"{self.name}\n\n"
+        for er in self.ers:
+            res += f"{er.to_md(*metrics)}\n\n"
+
+        return res
+
+    def __iter__(self):
+        return (er for er in self.ers)
diff --git a/quacc/main.py b/quacc/main.py
index fea3800..c900a98 100644
--- a/quacc/main.py
+++ b/quacc/main.py
@@ -1,33 +1,16 @@
-import traceback
-import quacc.evaluation.method as method
+import quacc.evaluation.comp as comp
+from quacc.dataset import Dataset
+from quacc.environ import env
-DATASET = "imdb"
-OUTPUT_FILE = "out_" + DATASET + ".md"
-TARGETS = {
-    "rcv1" : [
-        'C12',
-        'C13', 'C15', 'C151', 'C1511', 'C152', 'C17', 'C172',
-        'C18', 'C181', 'C21', 'C24', 'C31', 'C42', 'CCAT'
-        'E11', 'E12', 'E21', 'E211', 'E212', 'E41', 'E51', 'ECAT',
-        'G15', 'GCAT', 'GCRIM', 'GDIP', 'GPOL', 'GVIO', 'GVOTE', 'GWEA',
-        'GWELF', 'M11', 'M12', 'M13', 'M131', 'M132', 'M14', 'M141',
-        'M142', 'M143', 'MCAT'
-    ],
-    "spambase": ["default"],
-    "imdb": ["default"],
-}
def estimate_comparison():
-    open(OUTPUT_FILE, "w").close()
-    targets = TARGETS[DATASET]
-    for target in targets:
-        try:
-            er = method.evaluate_comparison(DATASET, target=target)
-            er.target = target
-            with open(OUTPUT_FILE, "a") as f:
-                f.write(er.to_md(["acc"], ["f1"]))
-        except Exception:
-            traceback.print_exc()
+    dataset = Dataset(
+        env.DATASET_NAME, target=env.DATASET_TARGET, n_prevalences=env.DATASET_N_PREVS
+    )
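+    # one markdown report per dataset, written under env.OUT_DIR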
+    output_path = env.OUT_DIR / f"{dataset.name}.md"
+    with open(output_path, "w") as f:
+        dr = comp.evaluate_comparison(dataset, estimators=env.COMP_ESTIMATORS)
+        f.write(dr.to_md("acc"))
# print(df.to_latex(float_format="{:.4f}".format))
# print(utils.avg_group_report(df).to_latex(float_format="{:.4f}".format))
diff --git a/quacc/utils.py b/quacc/utils.py
index d38b9f6..d2b61f0 100644
--- a/quacc/utils.py
+++ b/quacc/utils.py
@@ -1,7 +1,8 @@
-
import functools
+
import pandas as pd
+
def combine_dataframes(dfs, df_index=[]) -> pd.DataFrame:
    if len(dfs) < 1:
        raise ValueError
@@ -10,15 +11,13 @@ def combine_dataframes(dfs, df_index=[]) -> pd.DataFrame:
    df = dfs[0]
    for ndf in dfs[1:]:
        df = df.join(ndf.set_index(df_index), on=df_index)
-    
+
    return df
def avg_group_report(df: pd.DataFrame) -> pd.DataFrame:
    def _reduce_func(s1, s2):
-        return {
-            (n1, n2): v + s2[(n1, n2)] for ((n1, n2), v) in s1.items()
-        }
+        return {(n1, n2): v + s2[(n1, n2)] for ((n1, n2), v) in s1.items()}
    lst = df.to_dict(orient="records")[1:-1]
    summed_series = functools.reduce(_reduce_func, lst)
@@ -28,4 +27,8 @@ def avg_group_report(df: pd.DataFrame) -> pd.DataFrame:
        for ((n1, n2), v) in summed_series.items()
        if n1 != "base"
    }
-    return pd.DataFrame([avg_report], columns=idx)
\ No newline at end of file
+    return pd.DataFrame([avg_report], columns=idx)
+
+
+def fmt_line_md(s):
+    return f"> {s} \n"