diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..e69de29 diff --git a/.gitignore b/.gitignore index cd571b7..5766485 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ quavenv/* *.pdf quacc/__pycache__/* +tests/__pycache__/* \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..bbca6bd --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "main", + "type": "python", + "request": "launch", + "program": "C:\\Users\\Lorenzo Volpi\\source\\tesi\\quacc\\main.py", + "console": "integratedTerminal", + "justMyCode": true + } + ] +} \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 381c1c8..7988285 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,10 +1,9 @@ -# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. [[package]] name = "abstention" version = "0.1.3.1" description = "Functions for abstention, calibration and label shift domain adaptation" -category = "main" optional = false python-versions = "*" files = [ @@ -20,7 +19,6 @@ scipy = ">=1.1.0" name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." -category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -32,7 +30,6 @@ files = [ name = "contourpy" version = "1.0.7" description = "Python library for calculating contours of 2D quadrilateral grids" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -107,7 +104,6 @@ test-no-images = ["pytest"] name = "cycler" version = "0.11.0" description = "Composable style cycles" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -119,7 +115,6 @@ files = [ name = "fonttools" version = "4.39.4" description = "Tools to manipulate font files" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -141,11 +136,21 @@ ufo = ["fs (>=2.2.0,<3)"] unicode = ["unicodedata2 (>=15.0.0)"] woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + [[package]] name = "joblib" version = "1.2.0" description = "Lightweight pipelining with Python functions" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -157,7 +162,6 @@ files = [ name = "kiwisolver" version = "1.4.4" description = "A fast implementation of the Cassowary constraint solver" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -235,7 +239,6 @@ files = [ name = "matplotlib" version = "3.7.1" description = "Python plotting package" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -297,7 +300,6 @@ python-dateutil = ">=2.7" name = "numpy" version = "1.24.3" description = "Fundamental package for array computing in Python" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -335,7 +337,6 @@ files = [ name = "packaging" version = "23.1" description = "Core utilities for Python packages" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -345,37 +346,36 @@ files = [ [[package]] name = "pandas" -version = "2.0.1" +version = "2.0.3" description = "Powerful data structures for data analysis, time series, and statistics" -category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "pandas-2.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:70a996a1d2432dadedbb638fe7d921c88b0cc4dd90374eab51bb33dc6c0c2a12"}, - {file = "pandas-2.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:909a72b52175590debbf1d0c9e3e6bce2f1833c80c76d80bd1aa09188be768e5"}, - {file = "pandas-2.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe7914d8ddb2d54b900cec264c090b88d141a1eed605c9539a187dbc2547f022"}, - {file = "pandas-2.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a514ae436b23a92366fbad8365807fc0eed15ca219690b3445dcfa33597a5cc"}, - {file = "pandas-2.0.1-cp310-cp310-win32.whl", hash = "sha256:12bd6618e3cc737c5200ecabbbb5eaba8ab645a4b0db508ceeb4004bb10b060e"}, - {file = "pandas-2.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:2b6fe5f7ce1cba0e74188c8473c9091ead9b293ef0a6794939f8cc7947057abd"}, - {file = "pandas-2.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:00959a04a1d7bbc63d75a768540fb20ecc9e65fd80744c930e23768345a362a7"}, - {file = "pandas-2.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:af2449e9e984dfad39276b885271ba31c5e0204ffd9f21f287a245980b0e4091"}, - {file = "pandas-2.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:910df06feaf9935d05247db6de452f6d59820e432c18a2919a92ffcd98f8f79b"}, - {file = "pandas-2.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fa0067f2419f933101bdc6001bcea1d50812afbd367b30943417d67fbb99678"}, - {file = "pandas-2.0.1-cp311-cp311-win32.whl", hash = "sha256:7b8395d335b08bc8b050590da264f94a439b4770ff16bb51798527f1dd840388"}, - {file = "pandas-2.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:8db5a644d184a38e6ed40feeb12d410d7fcc36648443defe4707022da127fc35"}, - {file = "pandas-2.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7bbf173d364130334e0159a9a034f573e8b44a05320995127cf676b85fd8ce86"}, - {file = "pandas-2.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6c0853d487b6c868bf107a4b270a823746175b1932093b537b9b76c639fc6f7e"}, - {file = "pandas-2.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f25e23a03f7ad7211ffa30cb181c3e5f6d96a8e4cb22898af462a7333f8a74eb"}, - {file = "pandas-2.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e09a53a4fe8d6ae2149959a2d02e1ef2f4d2ceb285ac48f74b79798507e468b4"}, - {file = "pandas-2.0.1-cp38-cp38-win32.whl", hash = "sha256:a2564629b3a47b6aa303e024e3d84e850d36746f7e804347f64229f8c87416ea"}, - {file = "pandas-2.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:03e677c6bc9cfb7f93a8b617d44f6091613a5671ef2944818469be7b42114a00"}, - {file = "pandas-2.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3d099ecaa5b9e977b55cd43cf842ec13b14afa1cfa51b7e1179d90b38c53ce6a"}, - {file = "pandas-2.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a37ee35a3eb6ce523b2c064af6286c45ea1c7ff882d46e10d0945dbda7572753"}, - {file = "pandas-2.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:320b180d125c3842c5da5889183b9a43da4ebba375ab2ef938f57bf267a3c684"}, - {file = "pandas-2.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18d22cb9043b6c6804529810f492ab09d638ddf625c5dea8529239607295cb59"}, - {file = "pandas-2.0.1-cp39-cp39-win32.whl", hash = "sha256:90d1d365d77d287063c5e339f49b27bd99ef06d10a8843cf00b1a49326d492c1"}, - {file = "pandas-2.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:99f7192d8b0e6daf8e0d0fd93baa40056684e4b4aaaef9ea78dff34168e1f2f0"}, - {file = "pandas-2.0.1.tar.gz", hash = "sha256:19b8e5270da32b41ebf12f0e7165efa7024492e9513fb46fb631c5022ae5709d"}, + {file = "pandas-2.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e4c7c9f27a4185304c7caf96dc7d91bc60bc162221152de697c98eb0b2648dd8"}, + {file = "pandas-2.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f167beed68918d62bffb6ec64f2e1d8a7d297a038f86d4aed056b9493fca407f"}, + {file = "pandas-2.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce0c6f76a0f1ba361551f3e6dceaff06bde7514a374aa43e33b588ec10420183"}, + {file = "pandas-2.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba619e410a21d8c387a1ea6e8a0e49bb42216474436245718d7f2e88a2f8d7c0"}, + {file = "pandas-2.0.3-cp310-cp310-win32.whl", hash = "sha256:3ef285093b4fe5058eefd756100a367f27029913760773c8bf1d2d8bebe5d210"}, + {file = "pandas-2.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:9ee1a69328d5c36c98d8e74db06f4ad518a1840e8ccb94a4ba86920986bb617e"}, + {file = "pandas-2.0.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b084b91d8d66ab19f5bb3256cbd5ea661848338301940e17f4492b2ce0801fe8"}, + {file = "pandas-2.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:37673e3bdf1551b95bf5d4ce372b37770f9529743d2498032439371fc7b7eb26"}, + {file = "pandas-2.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9cb1e14fdb546396b7e1b923ffaeeac24e4cedd14266c3497216dd4448e4f2d"}, + {file = "pandas-2.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d9cd88488cceb7635aebb84809d087468eb33551097d600c6dad13602029c2df"}, + {file = "pandas-2.0.3-cp311-cp311-win32.whl", hash = "sha256:694888a81198786f0e164ee3a581df7d505024fbb1f15202fc7db88a71d84ebd"}, + {file = "pandas-2.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:6a21ab5c89dcbd57f78d0ae16630b090eec626360085a4148693def5452d8a6b"}, + {file = "pandas-2.0.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9e4da0d45e7f34c069fe4d522359df7d23badf83abc1d1cef398895822d11061"}, + {file = "pandas-2.0.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:32fca2ee1b0d93dd71d979726b12b61faa06aeb93cf77468776287f41ff8fdc5"}, + {file = "pandas-2.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:258d3624b3ae734490e4d63c430256e716f488c4fcb7c8e9bde2d3aa46c29089"}, + {file = "pandas-2.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eae3dc34fa1aa7772dd3fc60270d13ced7346fcbcfee017d3132ec625e23bb0"}, + {file = "pandas-2.0.3-cp38-cp38-win32.whl", hash = "sha256:f3421a7afb1a43f7e38e82e844e2bca9a6d793d66c1a7f9f0ff39a795bbc5e02"}, + {file = "pandas-2.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:69d7f3884c95da3a31ef82b7618af5710dba95bb885ffab339aad925c3e8ce78"}, + {file = "pandas-2.0.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5247fb1ba347c1261cbbf0fcfba4a3121fbb4029d95d9ef4dc45406620b25c8b"}, + {file = "pandas-2.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:81af086f4543c9d8bb128328b5d32e9986e0c84d3ee673a2ac6fb57fd14f755e"}, + {file = "pandas-2.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1994c789bf12a7c5098277fb43836ce090f1073858c10f9220998ac74f37c69b"}, + {file = "pandas-2.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ec591c48e29226bcbb316e0c1e9423622bc7a4eaf1ef7c3c9fa1a3981f89641"}, + {file = "pandas-2.0.3-cp39-cp39-win32.whl", hash = "sha256:04dbdbaf2e4d46ca8da896e1805bc04eb85caa9a82e259e8eed00254d5e0c682"}, + {file = "pandas-2.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:1168574b036cd8b93abc746171c9b4f1b83467438a5e45909fed645cf8692dbc"}, + {file = "pandas-2.0.3.tar.gz", hash = "sha256:c02f372a88e0d17f36d3093a644c73cfc1788e876a7c4bcb4020a77512e2043c"}, ] [package.dependencies] @@ -388,7 +388,7 @@ pytz = ">=2020.1" tzdata = ">=2022.1" [package.extras] -all = ["PyQt5 (>=5.15.1)", "SQLAlchemy (>=1.4.16)", "beautifulsoup4 (>=4.9.3)", "bottleneck (>=1.3.2)", "brotlipy (>=0.7.0)", "fastparquet (>=0.6.3)", "fsspec (>=2021.07.0)", "gcsfs (>=2021.07.0)", "html5lib (>=1.1)", "hypothesis (>=6.34.2)", "jinja2 (>=3.0.0)", "lxml (>=4.6.3)", "matplotlib (>=3.6.1)", "numba (>=0.53.1)", "numexpr (>=2.7.3)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pandas-gbq (>=0.15.0)", "psycopg2 (>=2.8.6)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "pyreadstat (>=1.1.2)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)", "python-snappy (>=0.6.0)", "pyxlsb (>=1.0.8)", "qtpy (>=2.2.0)", "s3fs (>=2021.08.0)", "scipy (>=1.7.1)", "tables (>=3.6.1)", "tabulate (>=0.8.9)", "xarray (>=0.21.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)", "zstandard (>=0.15.2)"] +all = ["PyQt5 (>=5.15.1)", "SQLAlchemy (>=1.4.16)", "beautifulsoup4 (>=4.9.3)", "bottleneck (>=1.3.2)", "brotlipy (>=0.7.0)", "fastparquet (>=0.6.3)", "fsspec (>=2021.07.0)", "gcsfs (>=2021.07.0)", "html5lib (>=1.1)", "hypothesis (>=6.34.2)", "jinja2 (>=3.0.0)", "lxml (>=4.6.3)", "matplotlib (>=3.6.1)", "numba (>=0.53.1)", "numexpr (>=2.7.3)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pandas-gbq (>=0.15.0)", "psycopg2 (>=2.8.6)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "pyreadstat (>=1.1.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)", "python-snappy (>=0.6.0)", "pyxlsb (>=1.0.8)", "qtpy (>=2.2.0)", "s3fs (>=2021.08.0)", "scipy (>=1.7.1)", "tables (>=3.6.1)", "tabulate (>=0.8.9)", "xarray (>=0.21.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)", "zstandard (>=0.15.2)"] aws = ["s3fs (>=2021.08.0)"] clipboard = ["PyQt5 (>=5.15.1)", "qtpy (>=2.2.0)"] compression = ["brotlipy (>=0.7.0)", "python-snappy (>=0.6.0)", "zstandard (>=0.15.2)"] @@ -407,14 +407,13 @@ plot = ["matplotlib (>=3.6.1)"] postgresql = ["SQLAlchemy (>=1.4.16)", "psycopg2 (>=2.8.6)"] spss = ["pyreadstat (>=1.1.2)"] sql-other = ["SQLAlchemy (>=1.4.16)"] -test = ["hypothesis (>=6.34.2)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"] +test = ["hypothesis (>=6.34.2)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.6.3)"] [[package]] name = "pillow" version = "9.5.0" description = "Python Imaging Library (Fork)" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -490,11 +489,25 @@ files = [ docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-removed-in", "sphinxext-opengraph"] tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] +[[package]] +name = "pluggy" +version = "1.2.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pluggy-1.2.0-py3-none-any.whl", hash = "sha256:c2fd55a7d7a3863cba1a013e4e2414658b1d07b6bc57b3919e0c63c9abb99849"}, + {file = "pluggy-1.2.0.tar.gz", hash = "sha256:d12f0c4b579b15f5e054301bb226ee85eeeba08ffec228092f8defbaa3a4c4b3"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + [[package]] name = "pyparsing" version = "3.0.9" description = "pyparsing module - Classes and methods to define and execute parsing grammars" -category = "main" optional = false python-versions = ">=3.6.8" files = [ @@ -505,11 +518,30 @@ files = [ [package.extras] diagrams = ["jinja2", "railroad-diagrams"] +[[package]] +name = "pytest" +version = "7.4.0" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-7.4.0-py3-none-any.whl", hash = "sha256:78bf16451a2eb8c7a2ea98e32dc119fd2aa758f1d5d66dbf0a59d69a3969df32"}, + {file = "pytest-7.4.0.tar.gz", hash = "sha256:b4bf8c45bd59934ed84001ad51e11b4ee40d40a1229d2c79f9c592b0a3f6bd8a"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<2.0" + +[package.extras] +testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + [[package]] name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" -category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ @@ -524,7 +556,6 @@ six = ">=1.5" name = "pytz" version = "2023.3" description = "World timezone definitions, modern and historical" -category = "main" optional = false python-versions = "*" files = [ @@ -536,7 +567,6 @@ files = [ name = "quapy" version = "0.1.7" description = "QuaPy: a framework for Quantification in Python" -category = "main" optional = false python-versions = ">=3.6, <4" files = [ @@ -557,7 +587,6 @@ xlrd = "*" name = "scikit-learn" version = "1.2.2" description = "A set of python modules for machine learning and data mining" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -600,7 +629,6 @@ tests = ["black (>=22.3.0)", "flake8 (>=3.8.2)", "matplotlib (>=3.1.3)", "mypy ( name = "scipy" version = "1.9.3" description = "Fundamental algorithms for scientific computing in Python" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -639,7 +667,6 @@ test = ["asv", "gmpy2", "mpmath", "pytest", "pytest-cov", "pytest-xdist", "sciki name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" -category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -651,7 +678,6 @@ files = [ name = "threadpoolctl" version = "3.1.0" description = "threadpoolctl" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -663,7 +689,6 @@ files = [ name = "tqdm" version = "4.65.0" description = "Fast, Extensible Progress Meter" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -684,7 +709,6 @@ telegram = ["requests"] name = "tzdata" version = "2023.3" description = "Provider of IANA time zone data" -category = "main" optional = false python-versions = ">=2" files = [ @@ -696,7 +720,6 @@ files = [ name = "xlrd" version = "2.0.1" description = "Library for developers to extract data from Microsoft Excel (tm) .xls spreadsheet files" -category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ @@ -712,4 +735,4 @@ test = ["pytest", "pytest-cov"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "811aa60aea2ab4cf9b4bc4ad3e546e8ee1a81d78f15acd35f3f736b5f97512b4" +content-hash = "834ffb619893a1fb006e1b5a3213cc772117c9000e719b95a4478f74fd5d0066" diff --git a/pyproject.toml b/pyproject.toml index f488b16..a6297fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,11 +8,15 @@ readme = "README.md" [tool.poetry.dependencies] python = "^3.11" quapy = "^0.1.7" +pandas = "^2.0.3" [tool.poetry.scripts] main = "quacc.main:main" +[tool.poetry.group.dev.dependencies] +pytest = "^7.4.0" + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/quacc/data.py b/quacc/data.py index 715802b..8f9b53c 100644 --- a/quacc/data.py +++ b/quacc/data.py @@ -1,11 +1,40 @@ -from typing import List, Optional +from typing import Any, List, Optional import numpy as np +import math import quapy as qp import scipy.sparse as sp from quapy.data import LabelledCollection +# Extended classes +# +# 0 ~ True 0 +# 1 ~ False 1 +# 2 ~ False 0 +# 3 ~ True 1 +# _____________________ +# | | | +# | True 0 | False 1 | +# |__________|__________| +# | | | +# | False 0 | True 1 | +# |__________|__________| +# +class ExClassManager: + @staticmethod + def get_ex(n_classes: int, true_class: int, pred_class: int) -> int: + return true_class * n_classes + pred_class + + @staticmethod + def get_pred(n_classes: int, ex_class: int) -> int: + return ex_class % n_classes + + @staticmethod + def get_true(n_classes: int, ex_class: int) -> int: + return ex_class // n_classes + + class ExtendedCollection(LabelledCollection): def __init__( self, @@ -14,7 +43,68 @@ class ExtendedCollection(LabelledCollection): classes: Optional[List] = None, ): super().__init__(instances, labels, classes=classes) - + + def split_by_pred(self): + _ncl = int(math.sqrt(self.n_classes)) + _indexes = ExtendedCollection.split_index_by_pred(_ncl, self.instances) + return [ + ExtendedCollection( + self.instances[ind] if len(ind) > 0 else np.asarray([], dtype=int), + np.asarray( + [ + ExClassManager.get_true(_ncl, lbl) + for lbl in (self.labels[ind] if len(ind) > 0 else []) + ], + dtype=int, + ), + classes=range(0, _ncl), + ) + for ind in _indexes + ] + + @classmethod + def split_index_by_pred( + cls, n_classes: int, instances: np.ndarray + ) -> List[np.ndarray]: + _pred_label = [np.argmax(inst[-n_classes:], axis=0) for inst in instances] + return [ + np.asarray([j for (j, x) in enumerate(_pred_label) if x == i]) + for i in range(0, n_classes) + ] + + @classmethod + def extend_instances( + cls, instances: np.ndarray, pred_proba: np.ndarray + ) -> np.ndarray: + if isinstance(instances, sp.csr_matrix): + _pred_proba = sp.csr_matrix(pred_proba) + n_x = sp.hstack([instances, _pred_proba]) + elif isinstance(instances, np.ndarray): + n_x = np.concatenate((instances, pred_proba), axis=1) + else: + raise ValueError("Unsupported matrix format") + + return n_x + + @classmethod + def extend_collection(cls, base: LabelledCollection, pred_proba: np.ndarray) -> Any: + n_classes = base.n_classes + + # n_X = [ X | predicted probs. ] + n_x = cls.extend_instances(base.X, pred_proba) + + # n_y = (exptected y, predicted y) + pred = np.asarray([prob.argmax(axis=0) for prob in pred_proba]) + n_y = np.asarray( + [ + ExClassManager.get_ex(n_classes, true_class, pred_class) + for (true_class, pred_class) in zip(base.y, pred) + ] + ) + + return ExtendedCollection(n_x, n_y, classes=[*range(0, n_classes * n_classes)]) + + def get_dataset(name): datasets = { "spambase": lambda: qp.datasets.fetch_UCIDataset( diff --git a/quacc/estimator.py b/quacc/estimator.py index 4afd490..5152b36 100644 --- a/quacc/estimator.py +++ b/quacc/estimator.py @@ -1,11 +1,14 @@ +from abc import abstractmethod +import math + import numpy as np -import scipy.sparse as sp from quapy.data import LabelledCollection -from quapy.method.base import BaseQuantifier +from quapy.method.aggregative import SLD from sklearn.base import BaseEstimator +from sklearn.linear_model import LogisticRegression from sklearn.model_selection import cross_val_predict -from .data import ExtendedCollection +from quacc.data import ExtendedCollection as EC def _check_prevalence_classes(true_classes, estim_classes, estim_prev): @@ -15,60 +18,36 @@ def _check_prevalence_classes(true_classes, estim_classes, estim_prev): return estim_prev -def _get_ex_class(classes, true_class, pred_class): - return true_class * classes + pred_class - - -def _extend_instances(instances, pred_proba): - if isinstance(instances, sp.csr_matrix): - _pred_proba = sp.csr_matrix(pred_proba) - n_x = sp.hstack([instances, _pred_proba]) - elif isinstance(instances, np.ndarray): - n_x = np.concatenate((instances, pred_proba), axis=1) - else: - raise ValueError("Unsupported matrix format") - - return n_x - - -def _extend_collection(base: LabelledCollection, pred_proba) -> ExtendedCollection: - n_classes = base.n_classes - - # n_X = [ X | predicted probs. ] - n_x = _extend_instances(base.X, pred_proba) - - # n_y = (exptected y, predicted y) - pred = np.asarray([prob.argmax(axis=0) for prob in pred_proba]) - n_y = np.asarray( - [ - _get_ex_class(n_classes, true_class, pred_class) - for (true_class, pred_class) in zip(base.y, pred) - ] - ) - - return ExtendedCollection(n_x, n_y, classes=[*range(0, n_classes * n_classes)]) - - class AccuracyEstimator: - def __init__(self, model: BaseEstimator, q_model: BaseQuantifier): - self.model = model - self.q_model = q_model - self.e_train = None - - def extend(self, base: LabelledCollection, pred_proba=None) -> ExtendedCollection: + def extend(self, base: LabelledCollection, pred_proba=None) -> EC: if not pred_proba: pred_proba = self.model.predict_proba(base.X) - return _extend_collection(base, pred_proba) + return EC.extend_collection(base, pred_proba) - def fit(self, train: LabelledCollection | ExtendedCollection): + @abstractmethod + def fit(self, train: LabelledCollection | EC): + ... + + @abstractmethod + def estimate(self, instances, ext=False): + ... + + +class MulticlassAccuracyEstimator(AccuracyEstimator): + def __init__(self, c_model: BaseEstimator): + self.c_model = c_model + self.q_model = SLD(LogisticRegression()) + self.e_train = None + + def fit(self, train: LabelledCollection | EC): # check if model is fit # self.model.fit(*train.Xy) if isinstance(train, LabelledCollection): pred_prob_train = cross_val_predict( - self.model, *train.Xy, method="predict_proba" + self.c_model, *train.Xy, method="predict_proba" ) - self.e_train = _extend_collection(train, pred_prob_train) + self.e_train = EC.extend_collection(train, pred_prob_train) else: self.e_train = train @@ -76,8 +55,8 @@ class AccuracyEstimator: def estimate(self, instances, ext=False): if not ext: - pred_prob = self.model.predict_proba(instances) - e_inst = _extend_instances(instances, pred_prob) + pred_prob = self.c_model.predict_proba(instances) + e_inst = EC.extend_instances(instances, pred_prob) else: e_inst = instances @@ -86,3 +65,51 @@ class AccuracyEstimator: return _check_prevalence_classes( self.e_train.classes_, self.q_model.classes_, estim_prev ) + + +class BinaryQuantifierAccuracyEstimator(AccuracyEstimator): + def __init__(self, c_model: BaseEstimator): + self.c_model = c_model + self.q_model_0 = SLD(LogisticRegression()) + self.q_model_1 = SLD(LogisticRegression()) + self.e_train: EC = None + + def fit(self, train: LabelledCollection | EC): + # check if model is fit + # self.model.fit(*train.Xy) + if isinstance(train, LabelledCollection): + pred_prob_train = cross_val_predict( + self.c_model, *train.Xy, method="predict_proba" + ) + + self.e_train = EC.extend_collection(train, pred_prob_train) + else: + self.e_train = train + + [e_train_0, e_train_1] = self.e_train.split_by_pred() + + self.q_model_0.fit(self.e_train_0) + self.q_model_1.fit(self.e_train_1) + + def estimate(self, instances, ext=False): + # TODO: test + if not ext: + pred_prob = self.c_model.predict_proba(instances) + e_inst = EC.extend_instances(instances, pred_prob) + else: + e_inst = instances + + _ncl = int(math.sqrt(self.e_train.n_classes)) + [e_inst_0, e_inst_1] = [ + e_inst[ind] for ind in EC.split_index_by_pred(_ncl, e_inst) + ] + estim_prev_0 = self.q_model_0.quantify(e_inst_0) + estim_prev_1 = self.q_model_1.quantify(e_inst_1) + + estim_prev = [] + for prev_row in zip(estim_prev_0, estim_prev_1): + for prev in prev_row: + estim_prev.append(prev) + + return estim_prev + diff --git a/quacc/main.py b/quacc/main.py index 93af1d8..b251549 100644 --- a/quacc/main.py +++ b/quacc/main.py @@ -1,13 +1,12 @@ import pandas as pd import quapy as qp -from quapy.method.aggregative import SLD from quapy.protocol import APP -from sklearn.svm import SVC +from sklearn.linear_model import LogisticRegression import quacc.evaluation as eval -from quacc.estimator import AccuracyEstimator +from quacc.estimator import MulticlassAccuracyEstimator -from .data import get_dataset +from quacc.data import get_dataset qp.environ["SAMPLE_SIZE"] = 100 @@ -17,16 +16,17 @@ pd.set_option("display.float_format", "{:.4f}".format) def test_2(dataset_name): train, test = get_dataset(dataset_name) - model = SVC(probability=True) + model = LogisticRegression() print(f"fitting model {model.__class__.__name__}...", end=" ", flush=True) model.fit(*train.Xy) print("fit") - qmodel = SLD(SVC(probability=True)) - estimator = AccuracyEstimator(model, qmodel) + estimator = MulticlassAccuracyEstimator(model) - print(f"fitting qmodel {qmodel.__class__.__name__}...", end=" ", flush=True) + print( + f"fitting qmodel {estimator.q_model.__class__.__name__}...", end=" ", flush=True + ) estimator.fit(train) print("fit") diff --git a/quacc/test_1.py b/quacc/old_main.py similarity index 100% rename from quacc/test_1.py rename to quacc/old_main.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_data.py b/tests/test_data.py new file mode 100644 index 0000000..8bc8c0f --- /dev/null +++ b/tests/test_data.py @@ -0,0 +1,94 @@ +import pytest +from quacc.data import ExClassManager as ECM, ExtendedCollection +import numpy as np + + +class TestExClassManager: + @pytest.mark.parametrize( + "true_class,pred_class,result", + [ + (0, 0, 0), + (0, 1, 1), + (1, 0, 2), + (1, 1, 3), + ], + ) + def test_get_ex(self, true_class, pred_class, result): + ncl = 2 + assert ECM.get_ex(ncl, true_class, pred_class) == result + + @pytest.mark.parametrize( + "ex_class,result", + [ + (0, 0), + (1, 1), + (2, 0), + (3, 1), + ], + ) + def test_get_pred(self, ex_class, result): + ncl = 2 + assert ECM.get_pred(ncl, ex_class) == result + + @pytest.mark.parametrize( + "ex_class,result", + [ + (0, 0), + (1, 0), + (2, 1), + (3, 1), + ], + ) + def test_get_true(self, ex_class, result): + ncl = 2 + assert ECM.get_true(ncl, ex_class) == result + + +class TestExtendedCollection: + @pytest.mark.parametrize( + "instances,labels,inst0,lbl0,inst1,lbl1", + [ + ( + [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]], + [3, 0, 1, 2], + [[1, 0.54, 0.46], [3, 0.6, 0.4]], + [0, 1], + [[0, 0.3, 0.7], [2, 0.28, 0.72]], + [1, 0], + ), + ( + [[0, 0.3, 0.7], [2, 0.28, 0.72]], + [3, 1], + [], + [], + [[0, 0.3, 0.7], [2, 0.28, 0.72]], + [1, 0], + ), + ( + [[1, 0.54, 0.46], [3, 0.6, 0.4]], + [0, 2], + [[1, 0.54, 0.46], [3, 0.6, 0.4]], + [0, 1], + [], + [], + ), + + ], + ) + def test_split_by_pred(self, instances, labels, inst0, lbl0, inst1, lbl1): + ec = ExtendedCollection( + np.asarray(instances), np.asarray(labels), classes=range(0, 4) + ) + [ec0, ec1] = ec.split_by_pred() + print(ec0.X, np.asarray(inst0)) + assert( np.array_equal(ec0.X, np.asarray(inst0)) ) + print(ec0.y, np.asarray(lbl0)) + assert( np.array_equal(ec0.y, np.asarray(lbl0)) ) + print(ec1.X, np.asarray(inst1)) + assert( np.array_equal(ec1.X, np.asarray(inst1)) ) + print(ec1.y, np.asarray(lbl1)) + assert( np.array_equal(ec1.y, np.asarray(lbl1)) ) + + + + diff --git a/tests/test_estimator.py b/tests/test_estimator.py new file mode 100644 index 0000000..190b0ba --- /dev/null +++ b/tests/test_estimator.py @@ -0,0 +1,4 @@ +class TestBinaryQuantifierAccuracyEstimator: + + def test_estimate(self): + pass \ No newline at end of file