forked from moreo/QuaPy
svm-perf leaks model, fixing...
parent 9fd9d096f6
commit e8c3e29911
@@ -32,7 +32,7 @@ class PCALR(BaseEstimator):
         self.pca = TruncatedSVD(self.n_components).fit(X, y)
         self.classes_ = self.learner.classes_
         return self


     def predict(self, X):
         # X = self.transform(X)
@@ -58,6 +58,8 @@ class SVMperf(BaseEstimator, ClassifierMixin):
         if self.verbose:
             print('[Running]', cmd)
         p = subprocess.run(cmd.split(), stdout=PIPE, stderr=STDOUT)
+        if not exists(self.model):
+            print(p.stderr.decode('utf-8'))
         remove(traindat)

         if self.verbose:
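Note on the two lines added above: the trainer is launched with stderr=STDOUT, so everything the external command prints is captured in p.stdout and p.stderr is None, which means the new print(p.stderr.decode('utf-8')) would itself fail in exactly the case it is meant to report. Below is a minimal sketch of the intended pattern, checking that the external learner actually produced its model file and surfacing the captured output if it did not; the function name and paths are placeholders, not part of this commit.

from os.path import exists
from subprocess import run, PIPE, STDOUT

def run_trainer(cmd, model_path):
    # launch the external learner, capturing stdout and stderr together
    p = run(cmd.split(), stdout=PIPE, stderr=STDOUT)
    if not exists(model_path):
        # stderr was merged into stdout, so the diagnostics live in p.stdout
        print(p.stdout.decode('utf-8'))
        raise RuntimeError(f'model file {model_path} was not created')
    return p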
@@ -102,5 +104,5 @@ class SVMperf(BaseEstimator, ClassifierMixin):

     def __del__(self):
         if hasattr(self, 'tmpdir'):
-            shutil.rmtree(self.tmpdir)
+            pass # shutil.rmtree(self.tmpdir, ignore_errors=True)
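The cleanup in __del__ is stubbed out here, so the temporary directory (and the model file inside it) now outlives the object; presumably the rmtree was firing at an unwanted moment, which is hard to control from __del__. One possible alternative, an assumption rather than anything this commit does, is to keep the shutil.rmtree but make deletion an explicit step the caller controls, for example:

import shutil
import tempfile

class TmpModelDir:
    # hypothetical helper: owns a temporary directory for model/prediction files
    # and deletes it only when cleanup() is called explicitly, never from __del__
    def __init__(self):
        self.tmpdir = tempfile.mkdtemp(prefix='svmperf_')

    def cleanup(self):
        shutil.rmtree(self.tmpdir, ignore_errors=True)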
@@ -172,6 +172,7 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, e
     # join all data, and keep the order in which the methods appeared for the first time
     data = defaultdict(lambda:{'x':np.empty(shape=(0)), 'y':np.empty(shape=(0))})
     method_order = []

     for method, test_prevs_i, estim_prevs_i, tr_prev_i in zip(method_names, true_prevs, estim_prevs, tr_prevs):
         tr_prev_i = np.repeat(tr_prev_i.reshape(1,-1), repeats=test_prevs_i.shape[0], axis=0)

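For context on the hunk above: the np.repeat line tiles the single training prevalence vector so that it lines up row by row with the matrix of test prevalences produced for that method. A toy illustration with invented shapes:

import numpy as np

tr_prev = np.array([0.3, 0.7])                    # one training prevalence vector
test_prevs = np.random.dirichlet([1, 1], size=4)  # four test prevalence vectors (invented)
tiled = np.repeat(tr_prev.reshape(1, -1), repeats=test_prevs.shape[0], axis=0)
print(tiled.shape)                                # (4, 2): one copy of tr_prev per test sample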
@@ -185,6 +186,7 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, e
         method_order.append(method)

     bins = np.linspace(0, 1, n_bins+1)
+    inds_histogram_global = np.zeros(n_bins, dtype=np.float) # we use this to keep track of how many datapoints contribute to each bin
     binwidth = 1 / n_bins
     min_x, max_x = None, None
     for method in method_order:
@@ -194,6 +196,8 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, e
             method_drifts=np.log(1+method_drifts)

         inds = np.digitize(tr_test_drifts, bins, right=True)
+        inds_histogram_global += np.histogram(tr_test_drifts, density=True, bins=bins)[0]

         xs, ys, ystds = [], [], []
         for ind in range(len(bins)):
             selected = inds==ind
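The histogram line added above accumulates, per drift bin, the density of test samples falling in that bin, apparently feeding the commented-out density overlay added in the next hunk. A toy example (numbers invented) of how np.digitize and np.histogram behave with these bin edges:

import numpy as np

n_bins = 5
bins = np.linspace(0, 1, n_bins + 1)             # 6 edges -> 5 bins
drifts = np.array([0.05, 0.12, 0.33, 0.34, 0.80])

inds = np.digitize(drifts, bins, right=True)     # bin index of each drift value
density = np.histogram(drifts, density=True, bins=bins)[0]
print(inds)                                      # [1 1 2 2 4]
print(density * (1 / n_bins))                    # per-bin fraction of samples, sums to 1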
@@ -214,6 +218,11 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, e
         if show_std:
             ax.fill_between(xs, ys-ystds, ys+ystds, alpha=0.25)

+    # xs = bins[:-1]
+    # ys = inds_histogram_global
+    # print(xs.shape, ys.shape)
+    # ax.errorbar(xs, ys, label='density')

     ax.set(xlabel=f'Distribution shift between training set and test sample',
            ylabel=f'{error_name.upper()} (true distribution, predicted distribution)',
            title=title)