QuaPy/quapy/data/_ifcb.py

51 lines
1.7 KiB
Python

import os
import pandas as pd
from quapy.protocol import AbstractProtocol
class IFCBTrainSamplesFromDir(AbstractProtocol):
def __init__(self, path_dir:str, classes: list):
self.path_dir = path_dir
self.classes = classes
self.samples = []
for filename in os.listdir(path_dir):
if filename.endswith('.csv'):
self.samples.append(filename)
def __call__(self):
for sample in self.samples:
s = pd.read_csv(os.path.join(self.path_dir,sample))
# all columns but the first where we get the class
X = s.iloc[:, 1:].to_numpy()
y = s.iloc[:, 0].to_numpy()
yield X, y
def total(self):
"""
Returns the total number of samples that the protocol generates.
:return: The number of training samples to generate.
"""
return len(self.samples)
class IFCBTestSamples(AbstractProtocol):
def __init__(self, path_dir:str, test_prevalences_path: str):
self.path_dir = path_dir
self.test_prevalences = pd.read_csv(os.path.join(path_dir, test_prevalences_path))
def __call__(self):
for _, test_sample in self.test_prevalences.iterrows():
#Load the sample from disk
X = pd.read_csv(os.path.join(self.path_dir,test_sample['sample']+'.csv')).to_numpy()
prevalences = test_sample.iloc[1:].to_numpy().astype(float)
yield X, prevalences
def total(self):
"""
Returns the total number of samples that the protocol generates.
:return: The number of test samples to generate.
"""
return len(self.test_prevalences.index)