1
0
Fork 0
QuaPy/quapy/classification/svmperf.py

107 lines
4.1 KiB
Python
Raw Normal View History

2020-12-03 16:59:13 +01:00
import random
import subprocess
import tempfile
from os import remove, makedirs
2021-01-15 18:32:32 +01:00
from os.path import join, exists
2020-12-03 16:59:13 +01:00
from subprocess import PIPE, STDOUT
import shutil
2021-01-15 18:32:32 +01:00
2020-12-03 16:59:13 +01:00
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.datasets import dump_svmlight_file
class SVMperf(BaseEstimator, ClassifierMixin):
    """Scikit-learn-style wrapper around the SVM-perf command-line executables.

    Training and classification are delegated to the ``svm_perf_learn`` and
    ``svm_perf_classify`` executables located in ``svmperf_base``. Data are
    exchanged with them through svmlight-format files written to a private
    working directory (``self.tmpdir``, created in the current working directory
    by :meth:`fit`) that is removed when the object is garbage-collected or re-fitted.

    :param svmperf_base: path to the directory containing the SVM-perf executables
    :param C: value passed to ``svm_perf_learn`` through the ``-c`` option
    :param verbose: if True, print the commands being run and their output
    :param loss: key of :attr:`valid_losses`; its code is passed through the ``-l`` option
    """

    # losses with their respective codes in svm_perf implementation
    valid_losses = {'01':0, 'f1':1, 'kld':12, 'nkld':13, 'q':22, 'qacc':23, 'qf1':24, 'qgm':25, 'mae':26, 'mrae':27}

    def __init__(self, svmperf_base, C=0.01, verbose=False, loss='01'):
        assert exists(svmperf_base), f'path {svmperf_base} does not seem to point to a valid path'
        self.svmperf_base = svmperf_base
        self.C = C
        self.verbose = verbose
        self.loss = loss

    def set_params(self, **parameters):
        """Set the hyperparameters of the estimator (only ``C`` is currently supported).

        :return: self, as required by the scikit-learn estimator API
        :raises AssertionError: if any parameter other than ``C`` is given
        """
        assert list(parameters.keys()) == ['C'], 'currently, only the C parameter is supported'
        self.C = parameters['C']
        return self

    def fit(self, X, y):
        """Train the model by running ``svm_perf_learn`` on an svmlight dump of ``(X, y)``.

        A fresh working directory is created for every call; the one created by a
        previous call (if any) is removed first so that re-fitting does not leak it.

        :param X: training instances, in any format accepted by ``dump_svmlight_file``
        :param y: training labels
        :return: self
        :raises AssertionError: if ``self.loss`` is not a key of :attr:`valid_losses`
        :raises RuntimeError: if ``svm_perf_learn`` did not produce a model file
        """
        assert self.loss in SVMperf.valid_losses, \
            f'unsupported loss {self.loss}, valid ones are {list(SVMperf.valid_losses.keys())}'

        self.svmperf_learn = join(self.svmperf_base, 'svm_perf_learn')
        self.svmperf_classify = join(self.svmperf_base, 'svm_perf_classify')
        self.loss_cmd = '-w 3 -l ' + str(self.valid_losses[self.loss])
        self.c_cmd = '-c ' + str(self.C)

        self.classes_ = sorted(np.unique(y))
        self.n_classes_ = len(self.classes_)

        # discard the working directory of a previous fit, otherwise it is never cleaned up
        self._remove_tmpdir()

        # this would allow to run parallel instances of predict
        random_code = self._random_code()
        # tempfile.TemporaryDirectory is not used: tmp dirs were removed after the fit terminated
        # in multiprocessing, so a regular directory is used instead and removed in __del__
        self.tmpdir = '.svmperf-' + random_code
        makedirs(self.tmpdir, exist_ok=True)

        self.model = join(self.tmpdir, 'model-' + random_code)
        traindat = join(self.tmpdir, f'train-{random_code}.dat')

        try:
            dump_svmlight_file(X, y, traindat, zero_based=False)
            # the argument list is built explicitly (instead of splitting a command string)
            # so that paths containing whitespace are passed to the executable intact
            args = [self.svmperf_learn, *self.c_cmd.split(), *self.loss_cmd.split(), traindat, self.model]
            output = self._run(args)
        finally:
            if exists(traindat):
                remove(traindat)

        if not exists(self.model):
            raise RuntimeError(f'svm_perf_learn did not produce a model file; output was:\n{output}')
        return self

    def predict(self, X):
        """Return binary predictions: 1 where the decision score is positive, 0 otherwise.

        :param X: instances to classify
        :return: array of 0/1 predictions, one per instance
        """
        confidence_scores = self.decision_function(X)
        predictions = (confidence_scores > 0) * 1
        return predictions

    def decision_function(self, X, y=None):
        """Return the scores computed by ``svm_perf_classify`` for the instances in ``X``.

        :param X: instances to score; must expose ``shape[0]`` when ``y`` is None
        :param y: optional labels written to the svmlight file (zeros are used if None)
        :return: 1-d array with one score per instance
        :raises AssertionError: if the estimator has not been fitted or the model file is missing
        """
        assert hasattr(self, 'tmpdir'), 'predict called before fit'
        assert self.tmpdir is not None, 'model directory corrupted'
        assert exists(self.model), 'model not found'
        if y is None:
            y = np.zeros(X.shape[0])

        # in order to allow for parallel runs of predict, a random code is assigned
        random_code = self._random_code()
        predictions_path = join(self.tmpdir, 'predictions' + random_code + '.dat')
        testdat = join(self.tmpdir, 'test' + random_code + '.dat')

        try:
            dump_svmlight_file(X, y, testdat, zero_based=False)
            self._run([self.svmperf_classify, testdat, self.model, predictions_path])
            # ndmin=1 keeps the result one-dimensional even when X holds a single instance
            scores = np.loadtxt(predictions_path, ndmin=1)
        finally:
            # remove the temporary files even if classification or parsing failed
            for path in (testdat, predictions_path):
                if exists(path):
                    remove(path)
        return scores

    @staticmethod
    def _random_code():
        """Return five dash-separated random integers used to make file names unique."""
        # a local, unseeded Random instance is used instead of the module-level generator
        local_random = random.Random()
        return '-'.join(str(local_random.randint(0, 1000000)) for _ in range(5))

    def _run(self, args):
        """Run an external command, echoing it and its output when verbose; return its decoded output."""
        if self.verbose:
            print('[Running]', ' '.join(args))
        p = subprocess.run(args, stdout=PIPE, stderr=STDOUT)
        output = p.stdout.decode('utf-8', errors='replace')
        if self.verbose:
            print(output)
        return output

    def _remove_tmpdir(self):
        """Delete the working directory created by fit, if any; a missing directory is ignored."""
        tmpdir = getattr(self, 'tmpdir', None)
        if tmpdir is not None:
            # ignore_errors: the directory may already be gone, e.g. removed through a copy
            # of this object that shares the same tmpdir path
            shutil.rmtree(tmpdir, ignore_errors=True)

    def __del__(self):
        self._remove_tmpdir()
2020-12-03 16:59:13 +01:00