63 lines
2.1 KiB
Python
63 lines
2.1 KiB
Python
from sklearn.linear_model import LogisticRegression
|
|
import os
|
|
import sys
|
|
import pandas as pd
|
|
|
|
import quapy as qp
|
|
from method.aggregative import DistributionMatching
|
|
from distribution_matching.method_kdey import KDEy
|
|
from protocol import UPP
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
qp.environ['SAMPLE_SIZE'] = 100
|
|
qp.environ['N_JOBS'] = -1
|
|
method = 'KDE'
|
|
param = 0.1
|
|
div = 'topsoe'
|
|
method_identifier = f'{method}_{param}_{div}'
|
|
|
|
# generates tuples (dataset, method, method_name)
|
|
# (the dataset is needed for methods that process the dataset differently)
|
|
def gen_methods():
|
|
|
|
for dataset in qp.datasets.TWITTER_SENTIMENT_DATASETS_TEST:
|
|
|
|
data = qp.datasets.fetch_twitter(dataset, min_df=3, pickle=True)
|
|
|
|
if method == 'KDE':
|
|
kdey = KDEy(LogisticRegression(), divergence=div, bandwidth=param, engine='sklearn')
|
|
yield data, kdey, method_identifier
|
|
|
|
elif method == 'DM':
|
|
dm = DistributionMatching(LogisticRegression(), divergence=div, nbins=param)
|
|
yield data, dm, method_identifier
|
|
|
|
else:
|
|
raise NotImplementedError('unknown method')
|
|
|
|
os.makedirs('results', exist_ok=True)
|
|
result_path = f'results/{method_identifier}.csv'
|
|
|
|
if os.path.exists(result_path):
|
|
print('Result already exit. Nothing to do')
|
|
sys.exit(0)
|
|
|
|
with open(result_path, 'wt') as csv:
|
|
csv.write(f'Method\tDataset\tMAE\tMRAE\n')
|
|
for data, quantifier, quant_name in gen_methods():
|
|
quantifier.fit(data.training)
|
|
protocol = UPP(data.test, repeats=100)
|
|
report = qp.evaluation.evaluation_report(quantifier, protocol, error_metrics=['mae', 'mrae'], verbose=True)
|
|
means = report.mean()
|
|
csv.write(f'{quant_name}\t{data.name}\t{means["mae"]:.5f}\t{means["mrae"]:.5f}\n')
|
|
csv.flush()
|
|
|
|
df = pd.read_csv(result_path, sep='\t')
|
|
|
|
pd.set_option('display.max_columns', None)
|
|
pd.set_option('display.max_rows', None)
|
|
pv = df.pivot_table(index='Dataset', columns="Method", values=["MAE", "MRAE"])
|
|
print(pv)
|