From 8a6579428bc04b0f147930c75d4d50b96b747dc1 Mon Sep 17 00:00:00 2001 From: Alejandro Moreo Date: Wed, 8 Nov 2023 11:31:33 +0100 Subject: [PATCH] implementing the 'total' function of IFCB protocols --- examples/ifcb_experiments.py | 18 +++++++++++------- quapy/data/_ifcb.py | 19 ++++++++++++++++++- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/examples/ifcb_experiments.py b/examples/ifcb_experiments.py index fff13ef..913fdb8 100644 --- a/examples/ifcb_experiments.py +++ b/examples/ifcb_experiments.py @@ -6,14 +6,18 @@ from quapy.evaluation import evaluation_report def newLR(): return LogisticRegression(n_jobs=-1) -quantifiers = {'CC':qp.method.aggregative.CC(newLR()), - 'ACC':qp.method.aggregative.ACC(newLR()), - 'PCC':qp.method.aggregative.PCC(newLR()), - 'PACC':qp.method.aggregative.PACC(newLR()), - 'HDy':qp.method.aggregative.DistributionMatching(newLR()), - 'EMQ':qp.method.aggregative.EMQ(newLR())} -for quant_name, quantifier in quantifiers.items(): +quantifiers = [ + ('CC', qp.method.aggregative.CC(newLR())), + ('ACC', qp.method.aggregative.ACC(newLR())), + ('PCC', qp.method.aggregative.PCC(newLR())), + ('PACC', qp.method.aggregative.PACC(newLR())), + ('HDy', qp.method.aggregative.DistributionMatching(newLR())), + ('EMQ', qp.method.aggregative.EMQ(newLR())) +] + + +for quant_name, quantifier in quantifiers: print("Experiment with "+quant_name) train, test_gen = qp.datasets.fetch_IFCB() diff --git a/quapy/data/_ifcb.py b/quapy/data/_ifcb.py index 87bb030..4eb780d 100644 --- a/quapy/data/_ifcb.py +++ b/quapy/data/_ifcb.py @@ -20,6 +20,15 @@ class IFCBTrainSamplesFromDir(AbstractProtocol): y = s.iloc[:, 0].to_numpy() yield X, y + def total(self): + """ + Returns the total number of samples that the protocol generates. + + :return: The number of training samples to generate. + """ + return len(self.samples) + + class IFCBTestSamples(AbstractProtocol): def __init__(self, path_dir:str, test_prevalences_path: str): @@ -31,4 +40,12 @@ class IFCBTestSamples(AbstractProtocol): #Load the sample from disk X = pd.read_csv(os.path.join(self.path_dir,test_sample['sample']+'.csv')).to_numpy() prevalences = test_sample.iloc[1:].to_numpy().astype(float) - yield X, prevalences \ No newline at end of file + yield X, prevalences + + def total(self): + """ + Returns the total number of samples that the protocol generates. + + :return: The number of test samples to generate. + """ + return len(self.test_prevalences.index) \ No newline at end of file