dataset build prevs error fixed, conf updated

This commit is contained in:
Lorenzo Volpi 2024-03-27 16:18:02 +01:00
parent d237649be9
commit 6ff9e411c4
3 changed files with 9 additions and 2 deletions

View File

@ -427,6 +427,8 @@ multiclass_conf: &multiclass_conf
OUT_DIR_NAME: output/multiclass
DATASET_N_PREVS: 5
COMP_ESTIMATORS:
- bin_sld_lr_a
- mul_sld_lr_a
- bin_sld_lr_gs
- mul_sld_lr_gs
- bin_kde_lr_gs

View File

@ -5,7 +5,9 @@ DIRS=()
# DIRS+=("cc_lr")
# DIRS+=("baselines")
# DIRS+=("d_sld_rbf")
DIRS+=("d_sld_lr")
# DIRS+=("d_sld_lr")
# DIRS+=("debug")
DIRS+=("multiclass")
for dir in ${DIRS[@]}; do
scp -r andreaesuli@edge-nd1.isti.cnr.it:/home/andreaesuli/raid/lorenzo/output/${dir} ./output/

View File

@ -254,7 +254,10 @@ class Dataset(DatasetProvider):
dim = self.all_train.n_classes
lspace = np.linspace(0.0, 1.0, num=self._n_prevs + 1, endpoint=False)[1:]
mesh = np.array(np.meshgrid(*[lspace for _ in range(dim)])).T.reshape(-1, dim)
mesh = mesh[np.where(mesh.sum(axis=1) == 1.0)]
mesh = np.around(mesh, decimals=4)
mesh[np.where(np.around(mesh.sum(axis=1), decimals=4) == 0.9999), -1] += 0.0001
mesh[np.where(np.around(mesh.sum(axis=1), decimals=4) == 1.0001), -1] -= 0.0001
mesh = mesh[np.where(np.around(mesh.sum(axis=1), decimals=4) == 1.0)]
return np.around(np.unique(mesh, axis=0), decimals=4)
def __build_sample(