everything working, last results
This commit is contained in:
parent
96758834f3
commit
f72a612011
|
@ -6,7 +6,7 @@ import sys
|
|||
import pandas as pd
|
||||
|
||||
import quapy as qp
|
||||
from quapy.method.aggregative import EMQ, DistributionMatching, PACC, HDy, OneVsAllAggregative
|
||||
from quapy.method.aggregative import EMQ, DistributionMatching, PACC, HDy, OneVsAllAggregative, ACC
|
||||
from method_kdey import KDEy
|
||||
from method_dirichlety import DIRy
|
||||
from quapy.model_selection import GridSearchQ
|
||||
|
@ -17,8 +17,8 @@ if __name__ == '__main__':
|
|||
|
||||
qp.environ['SAMPLE_SIZE'] = qp.datasets.LEQUA2022_SAMPLE_SIZE['T1B']
|
||||
qp.environ['N_JOBS'] = -1
|
||||
result_dir = f'results_lequa'
|
||||
optim = 'mae'
|
||||
optim = 'mrae'
|
||||
result_dir = f'results_lequa_{optim}'
|
||||
|
||||
os.makedirs(result_dir, exist_ok=True)
|
||||
|
||||
|
@ -27,56 +27,52 @@ if __name__ == '__main__':
|
|||
'classifier__class_weight': ['balanced', None]
|
||||
}
|
||||
|
||||
for method in ['KDE', 'PACC', 'SLD', 'DM', 'HDy-OvA', 'DIR']:
|
||||
for method in ['DIR']:#'HDy-OvA', 'SLD', 'ACC-tv', 'PACC-tv']: #['DM', 'DIR']: #'KDEy-MLE', 'KDE-DM', 'DM', 'DIR']:
|
||||
|
||||
#if os.path.exists(result_path):
|
||||
# print('Result already exit. Nothing to do')
|
||||
# sys.exit(0)
|
||||
print('Init method', method)
|
||||
|
||||
result_path = f'{result_dir}/{method}'
|
||||
if os.path.exists(result_path+'.dataframe'):
|
||||
print(f'result file {result_path} already exist; skipping')
|
||||
continue
|
||||
|
||||
if os.path.exists(result_path+'.csv'):
|
||||
print(f'file {result_path}.csv already exist; skipping')
|
||||
continue
|
||||
|
||||
with open(result_path+'.csv', 'at') as csv:
|
||||
csv.write(f'Method\tDataset\tMAE\tMRAE\tKLD\n')
|
||||
with open(result_path+'.csv', 'wt') as csv:
|
||||
csv.write(f'Method\tDataset\tMAE\tMRAE\tKLD\n')
|
||||
|
||||
dataset = 'T1B'
|
||||
train, val_gen, test_gen = qp.datasets.fetch_lequa2022(dataset)
|
||||
print(f'init {dataset} #instances: {len(train)}')
|
||||
if method == 'KDE':
|
||||
param_grid = {
|
||||
'bandwidth': np.linspace(0.001, 0.2, 21),
|
||||
'classifier__C': np.logspace(-4,4,9),
|
||||
'classifier__class_weight': ['balanced', None]
|
||||
}
|
||||
quantifier = KDEy(LogisticRegression(), target='max_likelihood')
|
||||
elif method == 'KDE-debug':
|
||||
param_grid = None
|
||||
qp.environ['N_JOBS'] = 1
|
||||
quantifier = KDEy(LogisticRegression(), target='max_likelihood', bandwidth=0.02)
|
||||
#train = train.sampling(280, *[1./train.n_classes]*(train.n_classes-1))
|
||||
if method == 'KDEy-MLE':
|
||||
method_params = {'bandwidth': np.linspace(0.01, 0.2, 20)}
|
||||
param_grid = {**method_params, **hyper_LR}
|
||||
quantifier = KDEy(LogisticRegression(), target='max_likelihood', val_split=10)
|
||||
elif method in ['KDE-DM']:
|
||||
method_params = {'bandwidth': np.linspace(0.01, 0.2, 20)}
|
||||
param_grid = {**method_params, **hyper_LR}
|
||||
quantifier = KDEy(LogisticRegression(), target='min_divergence', divergence='l2', montecarlo_trials=5000, val_split=10)
|
||||
elif method == 'DIR':
|
||||
param_grid = hyper_LR
|
||||
quantifier = DIRy(LogisticRegression())
|
||||
elif method == 'SLD':
|
||||
param_grid = hyper_LR
|
||||
quantifier = EMQ(LogisticRegression())
|
||||
elif method == 'PACC':
|
||||
elif method == 'PACC-tv':
|
||||
param_grid = hyper_LR
|
||||
quantifier = PACC(LogisticRegression())
|
||||
elif method == 'ACC-tv':
|
||||
param_grid = hyper_LR
|
||||
quantifier = ACC(LogisticRegression())
|
||||
elif method == 'HDy-OvA':
|
||||
param_grid = {
|
||||
'binary_quantifier__classifier__C': np.logspace(-3,3,9),
|
||||
'binary_quantifier__classifier__class_weight': ['balanced', None]
|
||||
}
|
||||
param_grid = {'binary_quantifier__' + key: val for key, val in hyper_LR.items()}
|
||||
quantifier = OneVsAllAggregative(HDy(LogisticRegression()))
|
||||
elif method == 'DM':
|
||||
param_grid = {
|
||||
'nbins': [5,10,15],
|
||||
'classifier__C': np.logspace(-4,4,9),
|
||||
'classifier__class_weight': ['balanced', None]
|
||||
method_params = {
|
||||
'nbins': [4,8,16,32],
|
||||
'val_split': [10, 0.4],
|
||||
'divergence': ['HD', 'topsoe', 'l2']
|
||||
}
|
||||
param_grid = {**method_params, **hyper_LR}
|
||||
quantifier = DistributionMatching(LogisticRegression())
|
||||
else:
|
||||
raise NotImplementedError('unknown method', method)
|
||||
|
@ -86,7 +82,10 @@ if __name__ == '__main__':
|
|||
|
||||
modsel.fit(train)
|
||||
print(f'best params {modsel.best_params_}')
|
||||
pickle.dump(modsel.best_params_, open(f'{result_dir}/{method}_{dataset}.hyper.pkl', 'wb'), pickle.HIGHEST_PROTOCOL)
|
||||
print(f'best score {modsel.best_score_}')
|
||||
pickle.dump(
|
||||
(modsel.best_params_, modsel.best_score_,),
|
||||
open(f'{result_path}.hyper.pkl', 'wb'), pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
quantifier = modsel.best_model()
|
||||
else:
|
|
@ -21,6 +21,8 @@ import dirichlet
|
|||
|
||||
class DIRy(AggregativeProbabilisticQuantifier):
|
||||
|
||||
MAXITER = 10000
|
||||
|
||||
def __init__(self, classifier: BaseEstimator, val_split=0.4, n_jobs=None, target='max_likelihood'):
|
||||
self.classifier = classifier
|
||||
self.val_split = val_split
|
||||
|
@ -36,7 +38,7 @@ class DIRy(AggregativeProbabilisticQuantifier):
|
|||
data, self.classifier, val_split, probabilistic=True, fit_classifier=fit_classifier, n_jobs=self.n_jobs
|
||||
)
|
||||
|
||||
self.val_parameters = [dirichlet.mle(posteriors[y == cat]) for cat in range(data.n_classes)]
|
||||
self.val_parameters = [dirichlet.mle(posteriors[y == cat], maxiter=DIRy.MAXITER) for cat in range(data.n_classes)]
|
||||
|
||||
return self
|
||||
|
||||
|
@ -68,9 +70,6 @@ class DIRy(AggregativeProbabilisticQuantifier):
|
|||
def match(prev):
|
||||
val_pdf = self.val_pdf(prev)
|
||||
val_likelihood = val_pdf(posteriors)
|
||||
|
||||
#for i,prev_i in enumerate(prev):
|
||||
|
||||
return divergence(val_likelihood, test_likelihood)
|
||||
|
||||
# the initial point is set as the uniform distribution
|
|
@ -22,15 +22,14 @@ from statsmodels.nonparametric.kernel_density import KDEMultivariateConditional
|
|||
# TODO: think of a MMD-y variant, i.e., a MMD variant that uses the points in the simplex and possibly any non-linear kernel
|
||||
|
||||
|
||||
|
||||
class KDEy(AggregativeProbabilisticQuantifier):
|
||||
|
||||
BANDWIDTH_METHOD = ['auto', 'scott', 'silverman']
|
||||
ENGINE = ['scipy', 'sklearn', 'statsmodels']
|
||||
TARGET = ['min_divergence', 'max_likelihood']
|
||||
|
||||
def __init__(self, classifier: BaseEstimator, val_split=0.4, divergence: Union[str, Callable]='HD',
|
||||
bandwidth='scott', engine='sklearn', target='min_divergence', n_jobs=None, random_state=0):
|
||||
def __init__(self, classifier: BaseEstimator, val_split=0.4, divergence: Union[str, Callable]='L2',
|
||||
bandwidth='scott', engine='sklearn', target='min_divergence', n_jobs=None, random_state=0, montecarlo_trials=1000):
|
||||
assert bandwidth in KDEy.BANDWIDTH_METHOD or isinstance(bandwidth, float), \
|
||||
f'unknown bandwidth_method, valid ones are {KDEy.BANDWIDTH_METHOD}'
|
||||
assert engine in KDEy.ENGINE, f'unknown engine, valid ones are {KDEy.ENGINE}'
|
||||
|
@ -43,6 +42,7 @@ class KDEy(AggregativeProbabilisticQuantifier):
|
|||
self.target = target
|
||||
self.n_jobs = n_jobs
|
||||
self.random_state=random_state
|
||||
self.montecarlo_trials = montecarlo_trials
|
||||
|
||||
def search_bandwidth_maxlikelihood(self, posteriors, labels):
|
||||
grid = {'bandwidth': np.linspace(0.001, 0.2, 100)}
|
||||
|
@ -123,6 +123,10 @@ class KDEy(AggregativeProbabilisticQuantifier):
|
|||
self.val_densities = [self.get_kde(posteriors[y == cat]) for cat in range(data.n_classes)]
|
||||
self.val_posteriors = posteriors
|
||||
|
||||
if self.target == 'min_divergence':
|
||||
self.samples = qp.functional.uniform_prevalence_sampling(n_classes=data.n_classes, size=self.montecarlo_trials)
|
||||
self.sample_densities = [self.pdf(kde_i, self.samples) for kde_i in self.val_densities]
|
||||
|
||||
return self
|
||||
|
||||
#def val_pdf(self, prev):
|
||||
|
@ -166,20 +170,17 @@ class KDEy(AggregativeProbabilisticQuantifier):
|
|||
r = optimize.minimize(match, x0=uniform_distribution, method='SLSQP', bounds=bounds, constraints=constraints)
|
||||
return r.x
|
||||
|
||||
def _target_divergence(self, posteriors, montecarlo_samples=5000):
|
||||
def _target_divergence(self, posteriors):
|
||||
# in this variant we evaluate the divergence using a Montecarlo approach
|
||||
n_classes = len(self.val_densities)
|
||||
samples = qp.functional.uniform_prevalence_sampling(n_classes, size=montecarlo_samples)
|
||||
|
||||
test_kde = self.get_kde(posteriors)
|
||||
test_likelihood = self.pdf(test_kde, samples)
|
||||
test_likelihood = self.pdf(test_kde, self.samples)
|
||||
|
||||
divergence = _get_divergence(self.divergence)
|
||||
|
||||
sample_densities = [self.pdf(kde_i, samples) for kde_i in self.val_densities]
|
||||
|
||||
def match(prev):
|
||||
val_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip (prev, sample_densities))
|
||||
val_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip (prev, self.sample_densities))
|
||||
return divergence(val_likelihood, test_likelihood)
|
||||
|
||||
# the initial point is set as the uniform distribution
|
|
@ -2,8 +2,8 @@ import sys
|
|||
from pathlib import Path
|
||||
import pandas as pd
|
||||
|
||||
result_dir = 'results_tweet_1000'
|
||||
#result_dir = 'results_lequa'
|
||||
result_dir = 'results_tweet_1000_mrae'
|
||||
#result_dir = 'results_lequa_mrae'
|
||||
|
||||
dfs = []
|
||||
|
|
@ -11,7 +11,9 @@ a maximizar.
|
|||
- evaluar un APP sobre el simplexo?
|
||||
- evaluar un UPP sobre el simplexo? (=Montecarlo)
|
||||
- qué divergencias? HD, topsoe, L1?
|
||||
- tampoco estoy evaluando en modo kfcv creo...
|
||||
|
||||
1) sacar lequa-kfcv y todos los kfcv que puedan tener sentido en tweets
|
||||
2) implementar el auto
|
||||
- optimización interna para likelihood [ninguno parece funcionar bien]
|
||||
- de todo (e.g., todo el training)?
|
||||
|
@ -29,3 +31,19 @@ a maximizar.
|
|||
11) KDEx?
|
||||
12) Dirichlet (el método DIR) habría que arreglarlo y mostrar resultados...
|
||||
13) Test estadisticos.
|
||||
|
||||
Notas:
|
||||
estoy probando a reemplazar el target max_likelihood con un min_divergence:
|
||||
- como la divergencia entre dos KDEs ahora es en el espacio continuo, no es facil como obtener. Estoy probando
|
||||
con una evaluación en test, pero el problema es que es overconfident con respecto a la que ha sido obtenida en test.
|
||||
Otra opción es un MonteCarlo que es lo que estoy probando ahora. Para este experimento he quitado la model selection
|
||||
del clasificador, y estoy dejando solo la que hace con el bandwidth por agilizarlo. Los resultados KDE-nomonte son un
|
||||
max_likelihood en igualdad de condiciones (solo bandwidth), KDE-monte1 es un montecarlo con HD a 1000 puntos, y KDE-monte2
|
||||
es lo mismo pero con 5000 puntos; ambos funcionan mal. KDE-monte1 y KDE-monte2 los voy a borrar.
|
||||
Ahora estoy probando con KDE-monte3, lo mismo pero con una L2 como
|
||||
divergencia. Parece mucho más parecido a KDE-nomonte (pero sigue siendo algo peor)
|
||||
- probar con más puntos (KDE-monte4 es a 5000 puntos)
|
||||
- habría que probar con topsoe (KDE-monte5)
|
||||
- probar con optimización del LR (KDE-monte6 y con kfcv)
|
||||
- probar con L1 en vez de L2 (KDE-monte7 con 5000 puntos y sin LR)
|
||||
- tal vez habría que probar con la L2, que funciona bien, en el min_divergence que evaluaba en test, o test+train
|
|
@ -20,8 +20,8 @@ if __name__ == '__main__':
|
|||
qp.environ['N_JOBS'] = -1
|
||||
n_bags_val = 250
|
||||
n_bags_test = 1000
|
||||
result_dir = f'results_tweet_{n_bags_test}'
|
||||
optim = 'mae'
|
||||
optim = 'mrae'
|
||||
result_dir = f'results_tweet_{optim}'
|
||||
|
||||
os.makedirs(result_dir, exist_ok=True)
|
||||
|
||||
|
@ -30,26 +30,28 @@ if __name__ == '__main__':
|
|||
'classifier__class_weight': ['balanced', None]
|
||||
}
|
||||
|
||||
for method in ['KDE-nomonte', 'KDE-monte2', 'SLD', 'KDE-kfcv']:# , 'DIR', 'DM', 'HDy-OvA', 'CC', 'ACC', 'PCC']:
|
||||
for method in ['CC', 'SLD', 'PCC', 'PACC-tv', 'ACC-tv', 'DM', 'HDy-OvA', 'KDEy-MLE', 'KDE-DM', 'DIR']:
|
||||
|
||||
#if os.path.exists(result_path):
|
||||
# print('Result already exit. Nothing to do')
|
||||
# sys.exit(0)
|
||||
print('Init method', method)
|
||||
|
||||
result_path = f'{result_dir}/{method}'
|
||||
if os.path.exists(result_path+'.dataframe'):
|
||||
print(f'result file {result_path} already exist; skipping')
|
||||
continue
|
||||
|
||||
with open(result_path+'.csv', 'at') as csv:
|
||||
csv.write(f'Method\tDataset\tMAE\tMRAE\tKLD\n')
|
||||
global_result_path = f'{result_dir}/{method}'
|
||||
|
||||
if not os.path.exists(global_result_path+'.csv'):
|
||||
with open(global_result_path+'.csv', 'wt') as csv:
|
||||
csv.write(f'Method\tDataset\tMAE\tMRAE\tKLD\n')
|
||||
|
||||
with open(global_result_path+'.csv', 'at') as csv:
|
||||
# four semeval dataset share the training, so it is useless to optimize hyperparameters four times;
|
||||
# this variable controls that the mod sel has already been done, and skip this otherwise
|
||||
semeval_trained = False
|
||||
|
||||
for dataset in qp.datasets.TWITTER_SENTIMENT_DATASETS_TEST:
|
||||
print('init', dataset)
|
||||
|
||||
local_result_path = global_result_path + '_' + dataset
|
||||
if os.path.exists(local_result_path+'.dataframe'):
|
||||
print(f'result file {local_result_path}.dataframe already exist; skipping')
|
||||
continue
|
||||
|
||||
with qp.util.temp_seed(SEED):
|
||||
|
||||
|
@ -57,63 +59,64 @@ if __name__ == '__main__':
|
|||
|
||||
if not is_semeval or not semeval_trained:
|
||||
|
||||
if method == 'KDE':
|
||||
param_grid = {
|
||||
'bandwidth': np.linspace(0.001, 0.2, 21),
|
||||
'classifier__C': np.logspace(-4,4,9),
|
||||
'classifier__class_weight': ['balanced', None]
|
||||
}
|
||||
if method == 'KDE': # not used
|
||||
method_params = {'bandwidth': np.linspace(0.01, 0.2, 20)}
|
||||
param_grid = {**method_params, **hyper_LR}
|
||||
quantifier = KDEy(LogisticRegression(), target='max_likelihood')
|
||||
elif method == 'KDE-kfcv':
|
||||
param_grid = {
|
||||
'bandwidth': np.linspace(0.001, 0.2, 21),
|
||||
'classifier__C': np.logspace(-4,4,9),
|
||||
'classifier__class_weight': ['balanced', None]
|
||||
}
|
||||
elif method == 'KDEy-MLE':
|
||||
method_params = {'bandwidth': np.linspace(0.01, 0.2, 20)}
|
||||
param_grid = {**method_params, **hyper_LR}
|
||||
quantifier = KDEy(LogisticRegression(), target='max_likelihood', val_split=10)
|
||||
elif method in ['KDE-monte2']:
|
||||
param_grid = {
|
||||
'bandwidth': np.linspace(0.001, 0.2, 21),
|
||||
}
|
||||
quantifier = KDEy(LogisticRegression(), target='min_divergence')
|
||||
elif method in ['KDE-nomonte']:
|
||||
param_grid = {
|
||||
'bandwidth': np.linspace(0.001, 0.2, 21),
|
||||
}
|
||||
quantifier = KDEy(LogisticRegression(), target='max_likelihood')
|
||||
elif method in ['KDE-DM']:
|
||||
method_params = {'bandwidth': np.linspace(0.01, 0.2, 20)}
|
||||
param_grid = {**method_params, **hyper_LR}
|
||||
quantifier = KDEy(LogisticRegression(), target='min_divergence', divergence='l2', montecarlo_trials=5000, val_split=10)
|
||||
elif method == 'DIR':
|
||||
param_grid = hyper_LR
|
||||
quantifier = DIRy(LogisticRegression())
|
||||
elif method == 'SLD':
|
||||
param_grid = hyper_LR
|
||||
quantifier = EMQ(LogisticRegression())
|
||||
elif method == 'PACC-tv':
|
||||
param_grid = hyper_LR
|
||||
quantifier = PACC(LogisticRegression())
|
||||
#elif method == 'PACC-kfcv':
|
||||
# param_grid = hyper_LR
|
||||
# quantifier = PACC(LogisticRegression(), val_split=10)
|
||||
elif method == 'PACC':
|
||||
param_grid = hyper_LR
|
||||
method_params = {'val_split': [10, 0.4]}
|
||||
param_grid = {**method_params, **hyper_LR}
|
||||
quantifier = PACC(LogisticRegression())
|
||||
elif method == 'PACC-kfcv':
|
||||
param_grid = hyper_LR
|
||||
quantifier = PACC(LogisticRegression(), val_split=10)
|
||||
elif method == 'ACC':
|
||||
method_params = {'val_split': [10, 0.4]}
|
||||
param_grid = {**method_params, **hyper_LR}
|
||||
quantifier = ACC(LogisticRegression())
|
||||
elif method == 'PCC':
|
||||
param_grid = hyper_LR
|
||||
quantifier = PCC(LogisticRegression())
|
||||
elif method == 'ACC':
|
||||
elif method == 'ACC-tv':
|
||||
param_grid = hyper_LR
|
||||
quantifier = ACC(LogisticRegression())
|
||||
elif method == 'CC':
|
||||
param_grid = hyper_LR
|
||||
quantifier = CC(LogisticRegression())
|
||||
elif method == 'HDy-OvA':
|
||||
param_grid = {
|
||||
'binary_quantifier__classifier__C': np.logspace(-4,4,9),
|
||||
'binary_quantifier__classifier__class_weight': ['balanced', None]
|
||||
}
|
||||
param_grid = {'binary_quantifier__'+key:val for key,val in hyper_LR.items()}
|
||||
quantifier = OneVsAllAggregative(HDy(LogisticRegression()))
|
||||
#elif method == 'DM':
|
||||
# param_grid = {
|
||||
# 'nbins': [5,10,15],
|
||||
# 'classifier__C': np.logspace(-4,4,9),
|
||||
# 'classifier__class_weight': ['balanced', None]
|
||||
# }
|
||||
# quantifier = DistributionMatching(LogisticRegression())
|
||||
elif method == 'DM':
|
||||
param_grid = {
|
||||
'nbins': [5,10,15],
|
||||
'classifier__C': np.logspace(-4,4,9),
|
||||
'classifier__class_weight': ['balanced', None]
|
||||
method_params = {
|
||||
'nbins': [4,8,16,32],
|
||||
'val_split': [10, 0.4],
|
||||
'divergence': ['HD', 'topsoe', 'l2']
|
||||
}
|
||||
param_grid = {**method_params, **hyper_LR}
|
||||
quantifier = DistributionMatching(LogisticRegression())
|
||||
else:
|
||||
raise NotImplementedError('unknown method', method)
|
||||
|
@ -126,7 +129,10 @@ if __name__ == '__main__':
|
|||
|
||||
modsel.fit(data.training)
|
||||
print(f'best params {modsel.best_params_}')
|
||||
pickle.dump(modsel.best_params_, open(f'{result_dir}/{method}_{dataset}.hyper.pkl', 'wb'), pickle.HIGHEST_PROTOCOL)
|
||||
print(f'best score {modsel.best_score_}')
|
||||
pickle.dump(
|
||||
(modsel.best_params_, modsel.best_score_,),
|
||||
open(f'{local_result_path}.hyper.pkl', 'wb'), pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
quantifier = modsel.best_model()
|
||||
|
||||
|
@ -140,12 +146,12 @@ if __name__ == '__main__':
|
|||
quantifier.fit(data.training)
|
||||
protocol = UPP(data.test, repeats=n_bags_test)
|
||||
report = qp.evaluation.evaluation_report(quantifier, protocol, error_metrics=['mae', 'mrae', 'kld'], verbose=True)
|
||||
report.to_csv(result_path+'.dataframe')
|
||||
report.to_csv(f'{local_result_path}.dataframe')
|
||||
means = report.mean()
|
||||
csv.write(f'{method}\t{data.name}\t{means["mae"]:.5f}\t{means["mrae"]:.5f}\t{means["kld"]:.5f}\n')
|
||||
csv.flush()
|
||||
|
||||
df = pd.read_csv(result_path+'.csv', sep='\t')
|
||||
df = pd.read_csv(global_result_path+'.csv', sep='\t')
|
||||
|
||||
pd.set_option('display.max_columns', None)
|
||||
pd.set_option('display.max_rows', None)
|
|
@ -5,7 +5,7 @@ import pandas as pd
|
|||
|
||||
import quapy as qp
|
||||
from method.aggregative import DistributionMatching
|
||||
from method_kdey import KDEy
|
||||
from distribution_matching.method_kdey import KDEy
|
||||
from protocol import UPP
|
||||
|
||||
|
||||
|
|
|
@ -608,6 +608,10 @@ def _get_divergence(divergence: Union[str, Callable]):
|
|||
return F.HellingerDistance
|
||||
elif divergence=='topsoe':
|
||||
return F.TopsoeDistance
|
||||
elif divergence.lower()=='l2':
|
||||
return lambda a,b: np.linalg.norm(a-b)
|
||||
elif divergence.lower()=='l1':
|
||||
return lambda a,b: np.linalg.norm(a-b, ord=1)
|
||||
else:
|
||||
raise ValueError(f'unknown divergence {divergence}')
|
||||
elif callable(divergence):
|
||||
|
|
Loading…
Reference in New Issue