Use env._R_SEED for random_state instead of a hard-coded 0

This commit is contained in:
Lorenzo Volpi 2023-11-26 16:42:35 +01:00
parent 1806243d53
commit 8f903d96e2
3 changed files with 10 additions and 7 deletions

View File

@@ -17,6 +17,7 @@ import baselines.impweight as iw
import baselines.mandoline as mandolib import baselines.mandoline as mandolib
import baselines.rca as rcalib import baselines.rca as rcalib
from baselines.utils import clone_fit from baselines.utils import clone_fit
from quacc.environment import env
from .report import EvaluationReport from .report import EvaluationReport
@@ -169,7 +170,7 @@ def doc(
predict_method="predict_proba", predict_method="predict_proba",
): ):
c_model_predict = getattr(c_model, predict_method) c_model_predict = getattr(c_model, predict_method)
val1, val2 = validation.split_stratified(train_prop=0.5, random_state=0) val1, val2 = validation.split_stratified(train_prop=0.5, random_state=env._R_SEED)
val1_probs = c_model_predict(val1.X) val1_probs = c_model_predict(val1.X)
val1_mc = np.max(val1_probs, axis=-1) val1_mc = np.max(val1_probs, axis=-1)
val1_preds = np.argmax(val1_probs, axis=-1) val1_preds = np.argmax(val1_probs, axis=-1)
@@ -281,7 +282,7 @@ def rca_star(
"""elsahar19""" """elsahar19"""
c_model_predict = getattr(c_model, predict_method) c_model_predict = getattr(c_model, predict_method)
validation1, validation2 = validation.split_stratified( validation1, validation2 = validation.split_stratified(
train_prop=0.5, random_state=0 train_prop=0.5, random_state=env._R_SEED
) )
val1_pred = c_model_predict(validation1.X) val1_pred = c_model_predict(validation1.X)
c_model1 = clone_fit(c_model, validation1.X, val1_pred) c_model1 = clone_fit(c_model, validation1.X, val1_pred)
@@ -318,7 +319,7 @@ def gde(
predict_method="predict", predict_method="predict",
) -> EvaluationReport: ) -> EvaluationReport:
c_model_predict = getattr(c_model, predict_method) c_model_predict = getattr(c_model, predict_method)
val1, val2 = validation.split_stratified(train_prop=0.5, random_state=0) val1, val2 = validation.split_stratified(train_prop=0.5, random_state=env._R_SEED)
c_model1 = clone_fit(c_model, val1.X, val1.y) c_model1 = clone_fit(c_model, val1.X, val1.y)
c_model1_predict = getattr(c_model1, predict_method) c_model1_predict = getattr(c_model1, predict_method)
c_model2 = clone_fit(c_model, val2.X, val2.y) c_model2 = clone_fit(c_model, val2.X, val2.y)

View File

@@ -8,6 +8,7 @@ from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC from sklearn.svm import LinearSVC
import quacc as qc import quacc as qc
from quacc.environment import env
from quacc.evaluation.report import EvaluationReport from quacc.evaluation.report import EvaluationReport
from quacc.method.base import BQAE, MCAE, BaseAccuracyEstimator from quacc.method.base import BQAE, MCAE, BaseAccuracyEstimator
from quacc.method.model_selection import GridSearchAE from quacc.method.model_selection import GridSearchAE
@@ -97,7 +98,7 @@ class EvaluationMethodGridSearch(EvaluationMethod):
pg: str = "sld" pg: str = "sld"
def __call__(self, c_model, validation, protocol) -> EvaluationReport: def __call__(self, c_model, validation, protocol) -> EvaluationReport:
v_train, v_val = validation.split_stratified(0.6, random_state=0) v_train, v_val = validation.split_stratified(0.6, random_state=env._R_SEED)
__grid = _param_grid.get(self.pg, {}) __grid = _param_grid.get(self.pg, {})
est = GridSearchAE( est = GridSearchAE(
model=self.get_est(c_model), model=self.get_est(c_model),
@@ -122,7 +123,7 @@ def __sld_lr():
def __kde_lr(): def __kde_lr():
return KDEy(LogisticRegression()) return KDEy(LogisticRegression(), random_state=env._R_SEED)
def __sld_lsvc(): def __sld_lsvc():

View File

@@ -13,6 +13,7 @@ from sklearn.base import BaseEstimator
import quacc as qc import quacc as qc
import quacc.error import quacc.error
from quacc.data import ExtendedCollection, ExtendedData from quacc.data import ExtendedCollection, ExtendedData
from quacc.environment import env
from quacc.evaluation import evaluate from quacc.evaluation import evaluate
from quacc.logger import SubLogger from quacc.logger import SubLogger
from quacc.method.base import ( from quacc.method.base import (
@@ -251,7 +252,7 @@ class MCAEgsq(MultiClassAccuracyEstimator):
def fit(self, train: LabelledCollection): def fit(self, train: LabelledCollection):
self.e_train = self.extend(train) self.e_train = self.extend(train)
t_train, t_val = self.e_train.split_stratified(0.6, random_state=0) t_train, t_val = self.e_train.split_stratified(0.6, random_state=env._R_SEED)
self.quantifier = GridSearchQ( self.quantifier = GridSearchQ(
deepcopy(self.quantifier), deepcopy(self.quantifier),
param_grid=self.param_grid, param_grid=self.param_grid,
@@ -304,7 +305,7 @@ class BQAEgsq(BinaryQuantifierAccuracyEstimator):
self.quantifiers = [] self.quantifiers = []
for e_train in self.e_trains: for e_train in self.e_trains:
t_train, t_val = e_train.split_stratified(0.6, random_state=0) t_train, t_val = e_train.split_stratified(0.6, random_state=env._R_SEED)
quantifier = GridSearchQ( quantifier = GridSearchQ(
model=deepcopy(self.quantifier), model=deepcopy(self.quantifier),
param_grid=self.param_grid, param_grid=self.param_grid,