Binary quantifier added, support added and tested.

This commit is contained in:
Lorenzo Volpi 2023-07-26 00:38:23 +02:00
parent b969234244
commit 1347ac3c9d
12 changed files with 371 additions and 112 deletions

0
.editorconfig Normal file
View File

1
.gitignore vendored
View File

@ -2,3 +2,4 @@
quavenv/*
*.pdf
quacc/__pycache__/*
tests/__pycache__/*

16
.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,16 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "main",
"type": "python",
"request": "launch",
"program": "C:\\Users\\Lorenzo Volpi\\source\\tesi\\quacc\\main.py",
"console": "integratedTerminal",
"justMyCode": true
}
]
}

129
poetry.lock generated
View File

@ -1,10 +1,9 @@
# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
[[package]]
name = "abstention"
version = "0.1.3.1"
description = "Functions for abstention, calibration and label shift domain adaptation"
category = "main"
optional = false
python-versions = "*"
files = [
@ -20,7 +19,6 @@ scipy = ">=1.1.0"
name = "colorama"
version = "0.4.6"
description = "Cross-platform colored terminal text."
category = "main"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
files = [
@ -32,7 +30,6 @@ files = [
name = "contourpy"
version = "1.0.7"
description = "Python library for calculating contours of 2D quadrilateral grids"
category = "main"
optional = false
python-versions = ">=3.8"
files = [
@ -107,7 +104,6 @@ test-no-images = ["pytest"]
name = "cycler"
version = "0.11.0"
description = "Composable style cycles"
category = "main"
optional = false
python-versions = ">=3.6"
files = [
@ -119,7 +115,6 @@ files = [
name = "fonttools"
version = "4.39.4"
description = "Tools to manipulate font files"
category = "main"
optional = false
python-versions = ">=3.8"
files = [
@ -141,11 +136,21 @@ ufo = ["fs (>=2.2.0,<3)"]
unicode = ["unicodedata2 (>=15.0.0)"]
woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"]
[[package]]
name = "iniconfig"
version = "2.0.0"
description = "brain-dead simple config-ini parsing"
optional = false
python-versions = ">=3.7"
files = [
{file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
]
[[package]]
name = "joblib"
version = "1.2.0"
description = "Lightweight pipelining with Python functions"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -157,7 +162,6 @@ files = [
name = "kiwisolver"
version = "1.4.4"
description = "A fast implementation of the Cassowary constraint solver"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -235,7 +239,6 @@ files = [
name = "matplotlib"
version = "3.7.1"
description = "Python plotting package"
category = "main"
optional = false
python-versions = ">=3.8"
files = [
@ -297,7 +300,6 @@ python-dateutil = ">=2.7"
name = "numpy"
version = "1.24.3"
description = "Fundamental package for array computing in Python"
category = "main"
optional = false
python-versions = ">=3.8"
files = [
@ -335,7 +337,6 @@ files = [
name = "packaging"
version = "23.1"
description = "Core utilities for Python packages"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -345,37 +346,36 @@ files = [
[[package]]
name = "pandas"
version = "2.0.1"
version = "2.0.3"
description = "Powerful data structures for data analysis, time series, and statistics"
category = "main"
optional = false
python-versions = ">=3.8"
files = [
{file = "pandas-2.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:70a996a1d2432dadedbb638fe7d921c88b0cc4dd90374eab51bb33dc6c0c2a12"},
{file = "pandas-2.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:909a72b52175590debbf1d0c9e3e6bce2f1833c80c76d80bd1aa09188be768e5"},
{file = "pandas-2.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe7914d8ddb2d54b900cec264c090b88d141a1eed605c9539a187dbc2547f022"},
{file = "pandas-2.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a514ae436b23a92366fbad8365807fc0eed15ca219690b3445dcfa33597a5cc"},
{file = "pandas-2.0.1-cp310-cp310-win32.whl", hash = "sha256:12bd6618e3cc737c5200ecabbbb5eaba8ab645a4b0db508ceeb4004bb10b060e"},
{file = "pandas-2.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:2b6fe5f7ce1cba0e74188c8473c9091ead9b293ef0a6794939f8cc7947057abd"},
{file = "pandas-2.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:00959a04a1d7bbc63d75a768540fb20ecc9e65fd80744c930e23768345a362a7"},
{file = "pandas-2.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:af2449e9e984dfad39276b885271ba31c5e0204ffd9f21f287a245980b0e4091"},
{file = "pandas-2.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:910df06feaf9935d05247db6de452f6d59820e432c18a2919a92ffcd98f8f79b"},
{file = "pandas-2.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fa0067f2419f933101bdc6001bcea1d50812afbd367b30943417d67fbb99678"},
{file = "pandas-2.0.1-cp311-cp311-win32.whl", hash = "sha256:7b8395d335b08bc8b050590da264f94a439b4770ff16bb51798527f1dd840388"},
{file = "pandas-2.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:8db5a644d184a38e6ed40feeb12d410d7fcc36648443defe4707022da127fc35"},
{file = "pandas-2.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7bbf173d364130334e0159a9a034f573e8b44a05320995127cf676b85fd8ce86"},
{file = "pandas-2.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6c0853d487b6c868bf107a4b270a823746175b1932093b537b9b76c639fc6f7e"},
{file = "pandas-2.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f25e23a03f7ad7211ffa30cb181c3e5f6d96a8e4cb22898af462a7333f8a74eb"},
{file = "pandas-2.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e09a53a4fe8d6ae2149959a2d02e1ef2f4d2ceb285ac48f74b79798507e468b4"},
{file = "pandas-2.0.1-cp38-cp38-win32.whl", hash = "sha256:a2564629b3a47b6aa303e024e3d84e850d36746f7e804347f64229f8c87416ea"},
{file = "pandas-2.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:03e677c6bc9cfb7f93a8b617d44f6091613a5671ef2944818469be7b42114a00"},
{file = "pandas-2.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3d099ecaa5b9e977b55cd43cf842ec13b14afa1cfa51b7e1179d90b38c53ce6a"},
{file = "pandas-2.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a37ee35a3eb6ce523b2c064af6286c45ea1c7ff882d46e10d0945dbda7572753"},
{file = "pandas-2.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:320b180d125c3842c5da5889183b9a43da4ebba375ab2ef938f57bf267a3c684"},
{file = "pandas-2.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18d22cb9043b6c6804529810f492ab09d638ddf625c5dea8529239607295cb59"},
{file = "pandas-2.0.1-cp39-cp39-win32.whl", hash = "sha256:90d1d365d77d287063c5e339f49b27bd99ef06d10a8843cf00b1a49326d492c1"},
{file = "pandas-2.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:99f7192d8b0e6daf8e0d0fd93baa40056684e4b4aaaef9ea78dff34168e1f2f0"},
{file = "pandas-2.0.1.tar.gz", hash = "sha256:19b8e5270da32b41ebf12f0e7165efa7024492e9513fb46fb631c5022ae5709d"},
{file = "pandas-2.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e4c7c9f27a4185304c7caf96dc7d91bc60bc162221152de697c98eb0b2648dd8"},
{file = "pandas-2.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f167beed68918d62bffb6ec64f2e1d8a7d297a038f86d4aed056b9493fca407f"},
{file = "pandas-2.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce0c6f76a0f1ba361551f3e6dceaff06bde7514a374aa43e33b588ec10420183"},
{file = "pandas-2.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba619e410a21d8c387a1ea6e8a0e49bb42216474436245718d7f2e88a2f8d7c0"},
{file = "pandas-2.0.3-cp310-cp310-win32.whl", hash = "sha256:3ef285093b4fe5058eefd756100a367f27029913760773c8bf1d2d8bebe5d210"},
{file = "pandas-2.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:9ee1a69328d5c36c98d8e74db06f4ad518a1840e8ccb94a4ba86920986bb617e"},
{file = "pandas-2.0.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b084b91d8d66ab19f5bb3256cbd5ea661848338301940e17f4492b2ce0801fe8"},
{file = "pandas-2.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:37673e3bdf1551b95bf5d4ce372b37770f9529743d2498032439371fc7b7eb26"},
{file = "pandas-2.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9cb1e14fdb546396b7e1b923ffaeeac24e4cedd14266c3497216dd4448e4f2d"},
{file = "pandas-2.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d9cd88488cceb7635aebb84809d087468eb33551097d600c6dad13602029c2df"},
{file = "pandas-2.0.3-cp311-cp311-win32.whl", hash = "sha256:694888a81198786f0e164ee3a581df7d505024fbb1f15202fc7db88a71d84ebd"},
{file = "pandas-2.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:6a21ab5c89dcbd57f78d0ae16630b090eec626360085a4148693def5452d8a6b"},
{file = "pandas-2.0.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9e4da0d45e7f34c069fe4d522359df7d23badf83abc1d1cef398895822d11061"},
{file = "pandas-2.0.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:32fca2ee1b0d93dd71d979726b12b61faa06aeb93cf77468776287f41ff8fdc5"},
{file = "pandas-2.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:258d3624b3ae734490e4d63c430256e716f488c4fcb7c8e9bde2d3aa46c29089"},
{file = "pandas-2.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eae3dc34fa1aa7772dd3fc60270d13ced7346fcbcfee017d3132ec625e23bb0"},
{file = "pandas-2.0.3-cp38-cp38-win32.whl", hash = "sha256:f3421a7afb1a43f7e38e82e844e2bca9a6d793d66c1a7f9f0ff39a795bbc5e02"},
{file = "pandas-2.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:69d7f3884c95da3a31ef82b7618af5710dba95bb885ffab339aad925c3e8ce78"},
{file = "pandas-2.0.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5247fb1ba347c1261cbbf0fcfba4a3121fbb4029d95d9ef4dc45406620b25c8b"},
{file = "pandas-2.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:81af086f4543c9d8bb128328b5d32e9986e0c84d3ee673a2ac6fb57fd14f755e"},
{file = "pandas-2.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1994c789bf12a7c5098277fb43836ce090f1073858c10f9220998ac74f37c69b"},
{file = "pandas-2.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ec591c48e29226bcbb316e0c1e9423622bc7a4eaf1ef7c3c9fa1a3981f89641"},
{file = "pandas-2.0.3-cp39-cp39-win32.whl", hash = "sha256:04dbdbaf2e4d46ca8da896e1805bc04eb85caa9a82e259e8eed00254d5e0c682"},
{file = "pandas-2.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:1168574b036cd8b93abc746171c9b4f1b83467438a5e45909fed645cf8692dbc"},
{file = "pandas-2.0.3.tar.gz", hash = "sha256:c02f372a88e0d17f36d3093a644c73cfc1788e876a7c4bcb4020a77512e2043c"},
]
[package.dependencies]
@ -388,7 +388,7 @@ pytz = ">=2020.1"
tzdata = ">=2022.1"
[package.extras]
all = ["PyQt5 (>=5.15.1)", "SQLAlchemy (>=1.4.16)", "beautifulsoup4 (>=4.9.3)", "bottleneck (>=1.3.2)", "brotlipy (>=0.7.0)", "fastparquet (>=0.6.3)", "fsspec (>=2021.07.0)", "gcsfs (>=2021.07.0)", "html5lib (>=1.1)", "hypothesis (>=6.34.2)", "jinja2 (>=3.0.0)", "lxml (>=4.6.3)", "matplotlib (>=3.6.1)", "numba (>=0.53.1)", "numexpr (>=2.7.3)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pandas-gbq (>=0.15.0)", "psycopg2 (>=2.8.6)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "pyreadstat (>=1.1.2)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)", "python-snappy (>=0.6.0)", "pyxlsb (>=1.0.8)", "qtpy (>=2.2.0)", "s3fs (>=2021.08.0)", "scipy (>=1.7.1)", "tables (>=3.6.1)", "tabulate (>=0.8.9)", "xarray (>=0.21.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)", "zstandard (>=0.15.2)"]
all = ["PyQt5 (>=5.15.1)", "SQLAlchemy (>=1.4.16)", "beautifulsoup4 (>=4.9.3)", "bottleneck (>=1.3.2)", "brotlipy (>=0.7.0)", "fastparquet (>=0.6.3)", "fsspec (>=2021.07.0)", "gcsfs (>=2021.07.0)", "html5lib (>=1.1)", "hypothesis (>=6.34.2)", "jinja2 (>=3.0.0)", "lxml (>=4.6.3)", "matplotlib (>=3.6.1)", "numba (>=0.53.1)", "numexpr (>=2.7.3)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pandas-gbq (>=0.15.0)", "psycopg2 (>=2.8.6)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "pyreadstat (>=1.1.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)", "python-snappy (>=0.6.0)", "pyxlsb (>=1.0.8)", "qtpy (>=2.2.0)", "s3fs (>=2021.08.0)", "scipy (>=1.7.1)", "tables (>=3.6.1)", "tabulate (>=0.8.9)", "xarray (>=0.21.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)", "zstandard (>=0.15.2)"]
aws = ["s3fs (>=2021.08.0)"]
clipboard = ["PyQt5 (>=5.15.1)", "qtpy (>=2.2.0)"]
compression = ["brotlipy (>=0.7.0)", "python-snappy (>=0.6.0)", "zstandard (>=0.15.2)"]
@ -407,14 +407,13 @@ plot = ["matplotlib (>=3.6.1)"]
postgresql = ["SQLAlchemy (>=1.4.16)", "psycopg2 (>=2.8.6)"]
spss = ["pyreadstat (>=1.1.2)"]
sql-other = ["SQLAlchemy (>=1.4.16)"]
test = ["hypothesis (>=6.34.2)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"]
test = ["hypothesis (>=6.34.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"]
xml = ["lxml (>=4.6.3)"]
[[package]]
name = "pillow"
version = "9.5.0"
description = "Python Imaging Library (Fork)"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -490,11 +489,25 @@ files = [
docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-removed-in", "sphinxext-opengraph"]
tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"]
[[package]]
name = "pluggy"
version = "1.2.0"
description = "plugin and hook calling mechanisms for python"
optional = false
python-versions = ">=3.7"
files = [
{file = "pluggy-1.2.0-py3-none-any.whl", hash = "sha256:c2fd55a7d7a3863cba1a013e4e2414658b1d07b6bc57b3919e0c63c9abb99849"},
{file = "pluggy-1.2.0.tar.gz", hash = "sha256:d12f0c4b579b15f5e054301bb226ee85eeeba08ffec228092f8defbaa3a4c4b3"},
]
[package.extras]
dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "pyparsing"
version = "3.0.9"
description = "pyparsing module - Classes and methods to define and execute parsing grammars"
category = "main"
optional = false
python-versions = ">=3.6.8"
files = [
@ -505,11 +518,30 @@ files = [
[package.extras]
diagrams = ["jinja2", "railroad-diagrams"]
[[package]]
name = "pytest"
version = "7.4.0"
description = "pytest: simple powerful testing with Python"
optional = false
python-versions = ">=3.7"
files = [
{file = "pytest-7.4.0-py3-none-any.whl", hash = "sha256:78bf16451a2eb8c7a2ea98e32dc119fd2aa758f1d5d66dbf0a59d69a3969df32"},
{file = "pytest-7.4.0.tar.gz", hash = "sha256:b4bf8c45bd59934ed84001ad51e11b4ee40d40a1229d2c79f9c592b0a3f6bd8a"},
]
[package.dependencies]
colorama = {version = "*", markers = "sys_platform == \"win32\""}
iniconfig = "*"
packaging = "*"
pluggy = ">=0.12,<2.0"
[package.extras]
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]]
name = "python-dateutil"
version = "2.8.2"
description = "Extensions to the standard Python datetime module"
category = "main"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
files = [
@ -524,7 +556,6 @@ six = ">=1.5"
name = "pytz"
version = "2023.3"
description = "World timezone definitions, modern and historical"
category = "main"
optional = false
python-versions = "*"
files = [
@ -536,7 +567,6 @@ files = [
name = "quapy"
version = "0.1.7"
description = "QuaPy: a framework for Quantification in Python"
category = "main"
optional = false
python-versions = ">=3.6, <4"
files = [
@ -557,7 +587,6 @@ xlrd = "*"
name = "scikit-learn"
version = "1.2.2"
description = "A set of python modules for machine learning and data mining"
category = "main"
optional = false
python-versions = ">=3.8"
files = [
@ -600,7 +629,6 @@ tests = ["black (>=22.3.0)", "flake8 (>=3.8.2)", "matplotlib (>=3.1.3)", "mypy (
name = "scipy"
version = "1.9.3"
description = "Fundamental algorithms for scientific computing in Python"
category = "main"
optional = false
python-versions = ">=3.8"
files = [
@ -639,7 +667,6 @@ test = ["asv", "gmpy2", "mpmath", "pytest", "pytest-cov", "pytest-xdist", "sciki
name = "six"
version = "1.16.0"
description = "Python 2 and 3 compatibility utilities"
category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
files = [
@ -651,7 +678,6 @@ files = [
name = "threadpoolctl"
version = "3.1.0"
description = "threadpoolctl"
category = "main"
optional = false
python-versions = ">=3.6"
files = [
@ -663,7 +689,6 @@ files = [
name = "tqdm"
version = "4.65.0"
description = "Fast, Extensible Progress Meter"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
@ -684,7 +709,6 @@ telegram = ["requests"]
name = "tzdata"
version = "2023.3"
description = "Provider of IANA time zone data"
category = "main"
optional = false
python-versions = ">=2"
files = [
@ -696,7 +720,6 @@ files = [
name = "xlrd"
version = "2.0.1"
description = "Library for developers to extract data from Microsoft Excel (tm) .xls spreadsheet files"
category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
files = [
@ -712,4 +735,4 @@ test = ["pytest", "pytest-cov"]
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "811aa60aea2ab4cf9b4bc4ad3e546e8ee1a81d78f15acd35f3f736b5f97512b4"
content-hash = "834ffb619893a1fb006e1b5a3213cc772117c9000e719b95a4478f74fd5d0066"

View File

@ -8,11 +8,15 @@ readme = "README.md"
[tool.poetry.dependencies]
python = "^3.11"
quapy = "^0.1.7"
pandas = "^2.0.3"
[tool.poetry.scripts]
main = "quacc.main:main"
[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

View File

@ -1,11 +1,40 @@
from typing import List, Optional
from typing import Any, List, Optional
import numpy as np
import math
import quapy as qp
import scipy.sparse as sp
from quapy.data import LabelledCollection
# Extended classes
#
# 0 ~ True 0
# 1 ~ False 1
# 2 ~ False 0
# 3 ~ True 1
# _____________________
# | | |
# | True 0 | False 1 |
# |__________|__________|
# | | |
# | False 0 | True 1 |
# |__________|__________|
#
class ExClassManager:
@staticmethod
def get_ex(n_classes: int, true_class: int, pred_class: int) -> int:
return true_class * n_classes + pred_class
@staticmethod
def get_pred(n_classes: int, ex_class: int) -> int:
return ex_class % n_classes
@staticmethod
def get_true(n_classes: int, ex_class: int) -> int:
return ex_class // n_classes
class ExtendedCollection(LabelledCollection):
def __init__(
self,
@ -14,7 +43,68 @@ class ExtendedCollection(LabelledCollection):
classes: Optional[List] = None,
):
super().__init__(instances, labels, classes=classes)
def split_by_pred(self):
_ncl = int(math.sqrt(self.n_classes))
_indexes = ExtendedCollection.split_index_by_pred(_ncl, self.instances)
return [
ExtendedCollection(
self.instances[ind] if len(ind) > 0 else np.asarray([], dtype=int),
np.asarray(
[
ExClassManager.get_true(_ncl, lbl)
for lbl in (self.labels[ind] if len(ind) > 0 else [])
],
dtype=int,
),
classes=range(0, _ncl),
)
for ind in _indexes
]
@classmethod
def split_index_by_pred(
cls, n_classes: int, instances: np.ndarray
) -> List[np.ndarray]:
_pred_label = [np.argmax(inst[-n_classes:], axis=0) for inst in instances]
return [
np.asarray([j for (j, x) in enumerate(_pred_label) if x == i])
for i in range(0, n_classes)
]
@classmethod
def extend_instances(
cls, instances: np.ndarray, pred_proba: np.ndarray
) -> np.ndarray:
if isinstance(instances, sp.csr_matrix):
_pred_proba = sp.csr_matrix(pred_proba)
n_x = sp.hstack([instances, _pred_proba])
elif isinstance(instances, np.ndarray):
n_x = np.concatenate((instances, pred_proba), axis=1)
else:
raise ValueError("Unsupported matrix format")
return n_x
@classmethod
def extend_collection(cls, base: LabelledCollection, pred_proba: np.ndarray) -> Any:
n_classes = base.n_classes
# n_X = [ X | predicted probs. ]
n_x = cls.extend_instances(base.X, pred_proba)
# n_y = (exptected y, predicted y)
pred = np.asarray([prob.argmax(axis=0) for prob in pred_proba])
n_y = np.asarray(
[
ExClassManager.get_ex(n_classes, true_class, pred_class)
for (true_class, pred_class) in zip(base.y, pred)
]
)
return ExtendedCollection(n_x, n_y, classes=[*range(0, n_classes * n_classes)])
def get_dataset(name):
datasets = {
"spambase": lambda: qp.datasets.fetch_UCIDataset(

View File

@ -1,11 +1,14 @@
from abc import abstractmethod
import math
import numpy as np
import scipy.sparse as sp
from quapy.data import LabelledCollection
from quapy.method.base import BaseQuantifier
from quapy.method.aggregative import SLD
from sklearn.base import BaseEstimator
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
from .data import ExtendedCollection
from quacc.data import ExtendedCollection as EC
def _check_prevalence_classes(true_classes, estim_classes, estim_prev):
@ -15,60 +18,36 @@ def _check_prevalence_classes(true_classes, estim_classes, estim_prev):
return estim_prev
def _get_ex_class(classes, true_class, pred_class):
return true_class * classes + pred_class
def _extend_instances(instances, pred_proba):
if isinstance(instances, sp.csr_matrix):
_pred_proba = sp.csr_matrix(pred_proba)
n_x = sp.hstack([instances, _pred_proba])
elif isinstance(instances, np.ndarray):
n_x = np.concatenate((instances, pred_proba), axis=1)
else:
raise ValueError("Unsupported matrix format")
return n_x
def _extend_collection(base: LabelledCollection, pred_proba) -> ExtendedCollection:
n_classes = base.n_classes
# n_X = [ X | predicted probs. ]
n_x = _extend_instances(base.X, pred_proba)
# n_y = (exptected y, predicted y)
pred = np.asarray([prob.argmax(axis=0) for prob in pred_proba])
n_y = np.asarray(
[
_get_ex_class(n_classes, true_class, pred_class)
for (true_class, pred_class) in zip(base.y, pred)
]
)
return ExtendedCollection(n_x, n_y, classes=[*range(0, n_classes * n_classes)])
class AccuracyEstimator:
def __init__(self, model: BaseEstimator, q_model: BaseQuantifier):
self.model = model
self.q_model = q_model
self.e_train = None
def extend(self, base: LabelledCollection, pred_proba=None) -> ExtendedCollection:
def extend(self, base: LabelledCollection, pred_proba=None) -> EC:
if not pred_proba:
pred_proba = self.model.predict_proba(base.X)
return _extend_collection(base, pred_proba)
return EC.extend_collection(base, pred_proba)
def fit(self, train: LabelledCollection | ExtendedCollection):
@abstractmethod
def fit(self, train: LabelledCollection | EC):
...
@abstractmethod
def estimate(self, instances, ext=False):
...
class MulticlassAccuracyEstimator(AccuracyEstimator):
def __init__(self, c_model: BaseEstimator):
self.c_model = c_model
self.q_model = SLD(LogisticRegression())
self.e_train = None
def fit(self, train: LabelledCollection | EC):
# check if model is fit
# self.model.fit(*train.Xy)
if isinstance(train, LabelledCollection):
pred_prob_train = cross_val_predict(
self.model, *train.Xy, method="predict_proba"
self.c_model, *train.Xy, method="predict_proba"
)
self.e_train = _extend_collection(train, pred_prob_train)
self.e_train = EC.extend_collection(train, pred_prob_train)
else:
self.e_train = train
@ -76,8 +55,8 @@ class AccuracyEstimator:
def estimate(self, instances, ext=False):
if not ext:
pred_prob = self.model.predict_proba(instances)
e_inst = _extend_instances(instances, pred_prob)
pred_prob = self.c_model.predict_proba(instances)
e_inst = EC.extend_instances(instances, pred_prob)
else:
e_inst = instances
@ -86,3 +65,51 @@ class AccuracyEstimator:
return _check_prevalence_classes(
self.e_train.classes_, self.q_model.classes_, estim_prev
)
class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
def __init__(self, c_model: BaseEstimator):
self.c_model = c_model
self.q_model_0 = SLD(LogisticRegression())
self.q_model_1 = SLD(LogisticRegression())
self.e_train: EC = None
def fit(self, train: LabelledCollection | EC):
# check if model is fit
# self.model.fit(*train.Xy)
if isinstance(train, LabelledCollection):
pred_prob_train = cross_val_predict(
self.c_model, *train.Xy, method="predict_proba"
)
self.e_train = EC.extend_collection(train, pred_prob_train)
else:
self.e_train = train
[e_train_0, e_train_1] = self.e_train.split_by_pred()
self.q_model_0.fit(self.e_train_0)
self.q_model_1.fit(self.e_train_1)
def estimate(self, instances, ext=False):
# TODO: test
if not ext:
pred_prob = self.c_model.predict_proba(instances)
e_inst = EC.extend_instances(instances, pred_prob)
else:
e_inst = instances
_ncl = int(math.sqrt(self.e_train.n_classes))
[e_inst_0, e_inst_1] = [
e_inst[ind] for ind in EC.split_index_by_pred(_ncl, e_inst)
]
estim_prev_0 = self.q_model_0.quantify(e_inst_0)
estim_prev_1 = self.q_model_1.quantify(e_inst_1)
estim_prev = []
for prev_row in zip(estim_prev_0, estim_prev_1):
for prev in prev_row:
estim_prev.append(prev)
return estim_prev

View File

@ -1,13 +1,12 @@
import pandas as pd
import quapy as qp
from quapy.method.aggregative import SLD
from quapy.protocol import APP
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
import quacc.evaluation as eval
from quacc.estimator import AccuracyEstimator
from quacc.estimator import MulticlassAccuracyEstimator
from .data import get_dataset
from quacc.data import get_dataset
qp.environ["SAMPLE_SIZE"] = 100
@ -17,16 +16,17 @@ pd.set_option("display.float_format", "{:.4f}".format)
def test_2(dataset_name):
train, test = get_dataset(dataset_name)
model = SVC(probability=True)
model = LogisticRegression()
print(f"fitting model {model.__class__.__name__}...", end=" ", flush=True)
model.fit(*train.Xy)
print("fit")
qmodel = SLD(SVC(probability=True))
estimator = AccuracyEstimator(model, qmodel)
estimator = MulticlassAccuracyEstimator(model)
print(f"fitting qmodel {qmodel.__class__.__name__}...", end=" ", flush=True)
print(
f"fitting qmodel {estimator.q_model.__class__.__name__}...", end=" ", flush=True
)
estimator.fit(train)
print("fit")

0
tests/__init__.py Normal file
View File

94
tests/test_data.py Normal file
View File

@ -0,0 +1,94 @@
import pytest
from quacc.data import ExClassManager as ECM, ExtendedCollection
import numpy as np
class TestExClassManager:
@pytest.mark.parametrize(
"true_class,pred_class,result",
[
(0, 0, 0),
(0, 1, 1),
(1, 0, 2),
(1, 1, 3),
],
)
def test_get_ex(self, true_class, pred_class, result):
ncl = 2
assert ECM.get_ex(ncl, true_class, pred_class) == result
@pytest.mark.parametrize(
"ex_class,result",
[
(0, 0),
(1, 1),
(2, 0),
(3, 1),
],
)
def test_get_pred(self, ex_class, result):
ncl = 2
assert ECM.get_pred(ncl, ex_class) == result
@pytest.mark.parametrize(
"ex_class,result",
[
(0, 0),
(1, 0),
(2, 1),
(3, 1),
],
)
def test_get_true(self, ex_class, result):
ncl = 2
assert ECM.get_true(ncl, ex_class) == result
class TestExtendedCollection:
@pytest.mark.parametrize(
"instances,labels,inst0,lbl0,inst1,lbl1",
[
(
[[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]],
[3, 0, 1, 2],
[[1, 0.54, 0.46], [3, 0.6, 0.4]],
[0, 1],
[[0, 0.3, 0.7], [2, 0.28, 0.72]],
[1, 0],
),
(
[[0, 0.3, 0.7], [2, 0.28, 0.72]],
[3, 1],
[],
[],
[[0, 0.3, 0.7], [2, 0.28, 0.72]],
[1, 0],
),
(
[[1, 0.54, 0.46], [3, 0.6, 0.4]],
[0, 2],
[[1, 0.54, 0.46], [3, 0.6, 0.4]],
[0, 1],
[],
[],
),
],
)
def test_split_by_pred(self, instances, labels, inst0, lbl0, inst1, lbl1):
ec = ExtendedCollection(
np.asarray(instances), np.asarray(labels), classes=range(0, 4)
)
[ec0, ec1] = ec.split_by_pred()
print(ec0.X, np.asarray(inst0))
assert( np.array_equal(ec0.X, np.asarray(inst0)) )
print(ec0.y, np.asarray(lbl0))
assert( np.array_equal(ec0.y, np.asarray(lbl0)) )
print(ec1.X, np.asarray(inst1))
assert( np.array_equal(ec1.X, np.asarray(inst1)) )
print(ec1.y, np.asarray(lbl1))
assert( np.array_equal(ec1.y, np.asarray(lbl1)) )

4
tests/test_estimator.py Normal file
View File

@ -0,0 +1,4 @@
class TestBinaryQuantifierAccuracyEstimator:
def test_estimate(self):
pass