2021-01-22 09:58:12 +01:00
|
|
|
from sklearn.base import BaseEstimator
|
2021-01-06 14:58:29 +01:00
|
|
|
from sklearn.decomposition import TruncatedSVD
|
|
|
|
from sklearn.linear_model import LogisticRegression
|
|
|
|
|
|
|
|
|
2021-01-22 09:58:12 +01:00
|
|
|
class PCALR(BaseEstimator):
    """
    An example of a classification method that also generates embedded inputs, as those required for QuaNet.

    This example simply combines a dimensionality reduction step (a PCA-like projection, implemented with
    :class:`sklearn.decomposition.TruncatedSVD`) with Logistic Regression (LR).

    The LR is trained and queried on the raw input features; the reduced projection is only exposed through
    :meth:`transform`, which provides the embedded inputs.

    :param n_components: dimensionality of the embedding produced by :meth:`transform` (default 100). If the
        training data has no more than `n_components` features, no projection is fitted and :meth:`transform`
        returns its input unchanged.
    :param kwargs: any further keyword arguments are forwarded to :class:`sklearn.linear_model.LogisticRegression`.
    """

    def __init__(self, n_components=100, **kwargs):
        self.n_components = n_components
        self.learner = LogisticRegression(**kwargs)

    def get_params(self, deep=True):
        """
        Return the hyper-parameters of this estimator as a single flat dictionary.

        :param deep: accepted for compatibility with the scikit-learn estimator API (`clone` and `repr` call
            `get_params(deep=False)`); it is forwarded to the wrapped LogisticRegression.
        :return: a dict containing `n_components` plus all parameters of the wrapped LogisticRegression.
        """
        params = {'n_components': self.n_components}
        params.update(self.learner.get_params(deep=deep))
        return params

    def set_params(self, **params):
        """
        Set hyper-parameters. `n_components` is stored on this object; every other parameter is forwarded to
        the wrapped LogisticRegression.

        :param params: parameters to set.
        :return: self, as required by the scikit-learn estimator API (model selection utilities reassign the
            value returned by `set_params`).
        """
        # `params` is a fresh dict built from the keyword arguments, so popping from it does not affect the caller.
        if 'n_components' in params:
            self.n_components = params.pop('n_components')
        self.learner.set_params(**params)
        return self

    def fit(self, X, y):
        """
        Fit the logistic regression on `X`, and fit the TruncatedSVD projection only when `X` has more than
        `n_components` columns.

        :param X: training instances; must expose `shape` (e.g., a numpy array or a scipy sparse matrix).
        :param y: labels of the training instances.
        :return: self
        """
        self.learner.fit(X, y)

        n_features = X.shape[1]
        self.pca = None
        # Only project when it actually reduces dimensionality; otherwise the raw features already are the
        # embedding and `transform` passes them through.
        if n_features > self.n_components:
            self.pca = TruncatedSVD(self.n_components).fit(X)

        self.classes_ = self.learner.classes_
        return self

    def predict(self, X):
        """
        Predict class labels with the logistic regression.

        :param X: instances in the same (raw) feature space used in :meth:`fit`.
        :return: the predicted labels.
        """
        # NOTE: the classifier operates on raw features, not on the projection returned by `transform`.
        return self.learner.predict(X)

    def predict_proba(self, X):
        """
        Predict class-membership probabilities with the logistic regression.

        :param X: instances in the same (raw) feature space used in :meth:`fit`.
        :return: the posterior probabilities returned by LogisticRegression.predict_proba.
        """
        # NOTE: the classifier operates on raw features, not on the projection returned by `transform`.
        return self.learner.predict_proba(X)

    def transform(self, X):
        """
        Return the embedded representation of `X`.

        Must be called after :meth:`fit` (it reads `self.pca`, which `fit` sets).

        :param X: instances in the same (raw) feature space used in :meth:`fit`.
        :return: the TruncatedSVD projection of `X` if one was fitted, otherwise `X` itself, unchanged.
        """
        if self.pca is None:
            return X
        return self.pca.transform(X)
|