# QuAcc/quacc/main.py

import traceback
import quacc.evaluation.method as method
# Dataset evaluated by estimate_comparison(); must be a key of TARGETS.
DATASET = "imdb"
# Markdown report produced for DATASET (computed once, at import time).
OUTPUT_FILE = "out_" + DATASET + ".md"
# Targets evaluated for each dataset; each entry is passed to
# method.evaluate_comparison(dataset, target=...). "spambase" and "imdb" use a
# single placeholder target named "default".
# NOTE: every entry needs its trailing comma -- adjacent string literals are
# implicitly concatenated by Python, silently merging two targets into one.
TARGETS = {
    "rcv1": [
        "C12",
        "C13", "C15", "C151", "C1511", "C152", "C17", "C172",
        "C18", "C181", "C21", "C24", "C31", "C42", "CCAT",
        "E11", "E12", "E21", "E211", "E212", "E41", "E51", "ECAT",
        "G15", "GCAT", "GCRIM", "GDIP", "GPOL", "GVIO", "GVOTE", "GWEA",
        "GWELF", "M11", "M12", "M13", "M131", "M132", "M14", "M141",
        "M142", "M143", "MCAT",
    ],
    "spambase": ["default"],
    "imdb": ["default"],
}


def estimate_comparison(dataset=None, output_file=None):
    """Evaluate every target of a dataset and write the results as Markdown.

    The output file is truncated first. Each target's report (produced by
    ``er.to_md(["acc"], ["f1"])``) is then appended as soon as it is ready, so
    results from finished targets stay on disk if a later target fails or the
    run is interrupted.

    Args:
        dataset: Key into ``TARGETS``; also the dataset name passed to
            ``method.evaluate_comparison``. Defaults to the module-level
            ``DATASET``.
        output_file: Path of the Markdown report. Defaults to ``OUTPUT_FILE``
            when ``dataset`` is omitted, otherwise to ``"out_<dataset>.md"``.

    Raises:
        KeyError: If ``dataset`` has no entry in ``TARGETS``. This is raised
            before the output file is touched.

    An exception raised while evaluating or writing a single target does not
    stop the run. The failing target's name and traceback are printed to
    stderr, and the remaining targets are still processed.
    """
    import sys

    if dataset is None:
        dataset = DATASET
        if output_file is None:
            output_file = OUTPUT_FILE
    elif output_file is None:
        # Same naming scheme as the module-level OUTPUT_FILE.
        output_file = "out_" + dataset + ".md"

    # Resolve targets before truncating, so an unknown dataset cannot wipe an
    # existing report.
    targets = TARGETS[dataset]

    # Start from an empty report; results are appended per target below.
    with open(output_file, "w", encoding="utf-8"):
        pass

    for target in targets:
        try:
            er = method.evaluate_comparison(dataset, target=target)
            er.target = target
            report = er.to_md(["acc"], ["f1"])
            # Open/close per target so each finished result is flushed to disk.
            with open(output_file, "a", encoding="utf-8") as f:
                f.write(report)
        except Exception:
            # Deliberate best effort: one failing target must not abort the
            # whole run, but say which target failed before the traceback.
            print(f"evaluation failed for target {target!r}:", file=sys.stderr)
            traceback.print_exc()


def main():
    """Script entry point: runs estimate_comparison() with its defaults."""
    return estimate_comparison()


if __name__ == "__main__":  # only when executed directly, never on import
    main()