Typing fixed, gitignore updated

This commit is contained in:
Lorenzo Volpi 2023-07-28 01:47:44 +02:00
parent 469dcb5898
commit dfd8d11e8f
3 changed files with 21 additions and 18 deletions

1
.gitignore vendored
View File

@ -3,3 +3,4 @@ quavenv/*
*.pdf *.pdf
quacc/__pycache__/* quacc/__pycache__/*
tests/__pycache__/* tests/__pycache__/*
.coverage

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional from typing import List, Optional, Self
import numpy as np import numpy as np
import math import math
@ -44,7 +44,7 @@ class ExtendedCollection(LabelledCollection):
): ):
super().__init__(instances, labels, classes=classes) super().__init__(instances, labels, classes=classes)
def split_by_pred(self): def split_by_pred(self) -> List[Self]:
_ncl = int(math.sqrt(self.n_classes)) _ncl = int(math.sqrt(self.n_classes))
_indexes = ExtendedCollection._split_index_by_pred(_ncl, self.instances) _indexes = ExtendedCollection._split_index_by_pred(_ncl, self.instances)
if isinstance(self.instances, np.ndarray): if isinstance(self.instances, np.ndarray):
@ -128,7 +128,9 @@ class ExtendedCollection(LabelledCollection):
return n_x return n_x
@classmethod @classmethod
def extend_collection(cls, base: LabelledCollection, pred_proba: np.ndarray) -> Any: def extend_collection(
cls, base: LabelledCollection, pred_proba: np.ndarray
) -> Self:
n_classes = base.n_classes n_classes = base.n_classes
# n_X = [ X | predicted probs. ] # n_X = [ X | predicted probs. ]

View File

@ -8,17 +8,17 @@ from sklearn.base import BaseEstimator
from sklearn.linear_model import LogisticRegression from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict from sklearn.model_selection import cross_val_predict
from quacc.data import ExtendedCollection as EC from quacc.data import ExtendedCollection
class AccuracyEstimator: class AccuracyEstimator:
def extend(self, base: LabelledCollection, pred_proba=None) -> EC: def extend(self, base: LabelledCollection, pred_proba=None) -> ExtendedCollection:
if not pred_proba: if not pred_proba:
pred_proba = self.c_model.predict_proba(base.X) pred_proba = self.c_model.predict_proba(base.X)
return EC.extend_collection(base, pred_proba) return ExtendedCollection.extend_collection(base, pred_proba)
@abstractmethod @abstractmethod
def fit(self, train: LabelledCollection | EC): def fit(self, train: LabelledCollection | ExtendedCollection):
... ...
@abstractmethod @abstractmethod
@ -32,7 +32,7 @@ class MulticlassAccuracyEstimator(AccuracyEstimator):
self.q_model = SLD(LogisticRegression()) self.q_model = SLD(LogisticRegression())
self.e_train = None self.e_train = None
def fit(self, train: LabelledCollection | EC): def fit(self, train: LabelledCollection | ExtendedCollection):
# check if model is fit # check if model is fit
# self.model.fit(*train.Xy) # self.model.fit(*train.Xy)
if isinstance(train, LabelledCollection): if isinstance(train, LabelledCollection):
@ -40,7 +40,7 @@ class MulticlassAccuracyEstimator(AccuracyEstimator):
self.c_model, *train.Xy, method="predict_proba" self.c_model, *train.Xy, method="predict_proba"
) )
self.e_train = EC.extend_collection(train, pred_prob_train) self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train)
else: else:
self.e_train = train self.e_train = train
@ -49,7 +49,7 @@ class MulticlassAccuracyEstimator(AccuracyEstimator):
def estimate(self, instances, ext=False): def estimate(self, instances, ext=False):
if not ext: if not ext:
pred_prob = self.c_model.predict_proba(instances) pred_prob = self.c_model.predict_proba(instances)
e_inst = EC.extend_instances(instances, pred_prob) e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
else: else:
e_inst = instances e_inst = instances
@ -71,9 +71,9 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
self.c_model = c_model self.c_model = c_model
self.q_model_0 = SLD(LogisticRegression()) self.q_model_0 = SLD(LogisticRegression())
self.q_model_1 = SLD(LogisticRegression()) self.q_model_1 = SLD(LogisticRegression())
self.e_train: EC = None self.e_train = None
def fit(self, train: LabelledCollection | EC): def fit(self, train: LabelledCollection | ExtendedCollection):
# check if model is fit # check if model is fit
# self.model.fit(*train.Xy) # self.model.fit(*train.Xy)
if isinstance(train, LabelledCollection): if isinstance(train, LabelledCollection):
@ -81,8 +81,8 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
self.c_model, *train.Xy, method="predict_proba" self.c_model, *train.Xy, method="predict_proba"
) )
self.e_train = EC.extend_collection(train, pred_prob_train) self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train)
else: elif isinstance(train, ExtendedCollection):
self.e_train = train self.e_train = train
self.n_classes = self.e_train.n_classes self.n_classes = self.e_train.n_classes
@ -95,12 +95,12 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
# TODO: test # TODO: test
if not ext: if not ext:
pred_prob = self.c_model.predict_proba(instances) pred_prob = self.c_model.predict_proba(instances)
e_inst = EC.extend_instances(instances, pred_prob) e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
else: else:
e_inst = instances e_inst = instances
_ncl = int(math.sqrt(self.n_classes)) _ncl = int(math.sqrt(self.n_classes))
s_inst, norms = EC.split_inst_by_pred(_ncl, e_inst) s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst)
[estim_prev_0, estim_prev_1] = [ [estim_prev_0, estim_prev_1] = [
self._quantify_helper(inst, norm, q_model) self._quantify_helper(inst, norm, q_model)
for (inst, norm, q_model) in zip( for (inst, norm, q_model) in zip(