2022-05-25 19:14:33 +02:00
|
|
|
import unittest
|
|
|
|
import quapy as qp
|
|
|
|
from sklearn.linear_model import LogisticRegression
|
|
|
|
from time import time
|
2022-06-01 18:28:59 +02:00
|
|
|
from quapy.method.aggregative import EMQ
|
|
|
|
from quapy.method.base import BaseQuantifier
|
2022-05-25 19:14:33 +02:00
|
|
|
|
|
|
|
|
|
|
|
class EvalTestCase(unittest.TestCase):
    """Tests for ``qp.evaluation.evaluate``."""

    def test_eval_speedup(self):
        """Check that ``aggr_speedup='force'`` makes evaluation substantially faster.

        An aggregative EMQ is evaluated with ``aggr_speedup='force'``. An
        equivalent EMQ hidden behind a plain ``BaseQuantifier`` wrapper is
        evaluated without the speedup. The same APP protocol is used for both.

        The classifier's ``predict_proba`` sleeps for 1 second per call, so
        wall-clock time is dominated by the number of classifier invocations.
        The speedup presumably classifies the test collection once and reuses
        those predictions for every sample (QuaPy internals, not visible here).
        The unoptimized path re-classifies each sample.
        """
        # NOTE(review): requires downloading/unpickling the 'hp' reviews dataset.
        data = qp.datasets.fetch_reviews('hp', tfidf=True, min_df=10, pickle=True)
        train, test = data.training, data.test

        # random_state is fixed so that, presumably, both evaluations below
        # draw the same samples and their timings/scores are comparable.
        protocol = qp.protocol.APP(test, sample_size=1000, n_prevalences=11, repeats=1, random_state=1)

        class SlowLR(LogisticRegression):
            """LogisticRegression whose ``predict_proba`` sleeps 1s per call.

            This makes elapsed time a proxy for the number of classifier calls.
            """

            def predict_proba(self, X):
                # Local import under a distinct name; avoids confusion with the
                # module-level ``time()`` function used for timing below.
                from time import sleep
                sleep(1)
                return super().predict_proba(X)

        emq = EMQ(SlowLR()).fit(train)

        tinit = time()
        score = qp.evaluation.evaluate(emq, protocol, error_metric='mae', verbose=True, aggr_speedup='force')
        tend_optim = time() - tinit
        print(f'evaluation (with optimization) took {tend_optim}s [MAE={score:.4f}]')

        class NonAggregativeEMQ(BaseQuantifier):
            """Expose an EMQ only through the generic ``BaseQuantifier`` interface.

            Because this class does not derive from an aggregative quantifier,
            ``evaluate`` presumably cannot apply the aggregative speedup to it,
            and must call ``quantify`` (and hence the classifier) per sample.
            """

            def __init__(self, cls):
                self.emq = EMQ(cls)

            def quantify(self, instances):
                return self.emq.quantify(instances)

            def fit(self, data):
                self.emq.fit(data)
                return self

        emq = NonAggregativeEMQ(SlowLR()).fit(train)

        tinit = time()
        score = qp.evaluation.evaluate(emq, protocol, error_metric='mae', verbose=True)
        tend_no_optim = time() - tinit
        print(f'evaluation (w/o optimization) took {tend_no_optim}s [MAE={score:.4f}]')

        # The optimized run must take well under half the time of the
        # unoptimized one. The previous check (tend_no_optim > tend_optim / 2)
        # would also pass if the "optimization" made evaluation up to 2x slower.
        self.assertLess(tend_optim, tend_no_optim / 2)
|
2022-05-25 19:14:33 +02:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Executed as a script (not imported): discover and run this module's tests.
    unittest.main()
|