1
0
Fork 0
QuaPy/LeQua2022/baselines_T1B.py

56 lines
1.9 KiB
Python
Raw Normal View History

2021-10-13 20:36:53 +02:00
import pickle
import numpy as np
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm
import pandas as pd
2021-10-13 20:36:53 +02:00
import quapy as qp
from quapy.data import LabelledCollection
from quapy.method.aggregative import *
import quapy.functional as F
from data import *
2021-10-13 20:36:53 +02:00
import os
import constants
2021-10-13 20:36:53 +02:00
# Where the T1B predictions (one CSV per quantifier) are written.
predictions_path = os.path.join('predictions', 'T1B')  # multiclass - vector
os.makedirs(predictions_path, exist_ok=True)

# Locations of the public T1B data: dev sample vectors, the dev samples' true
# prevalences, the training vectors, and the training label map.
pathT1B = './data/T1B/public'
(
    T1B_devvectors_path,
    T1B_devprevalence_path,
    T1B_trainpath,
    T1B_catmap,
) = (
    os.path.join(pathT1B, fname)
    for fname in (
        'dev_vectors',
        'dev_prevalences.csv',
        'training_vectors.txt',
        'training_vectors_label_map.txt',
    )
)

# Training collection and its dimensionality (number of feature columns).
train = LabelledCollection.load(T1B_trainpath, load_binary_vectors)
nF = train.instances.shape[1]

# Sample size used by quapy for this task.
qp.environ['SAMPLE_SIZE'] = constants.T1B_SAMPLE_SIZE

# Quick summary of the training data.
print(f'number of classes: {len(train.classes_)}')
print(f'number of training documents: {len(train)}')
print(f'training prevalence: {F.strprev(train.prevalence())}')
print(f'training matrix shape: {train.instances.shape}')

# Ground-truth prevalences of the dev samples, used to score the predictions.
true_prevalence = ResultSubmission.load(T1B_devprevalence_path)
# Category map of the training labels; `categories` is passed to every
# ResultSubmission built below.
cat2code, categories = load_category_map(T1B_catmap)

# CalibratedClassifierCV is not imported explicitly elsewhere in this file (it
# may only be reachable through one of the star imports); import it here so
# the script does not depend on another module's internal imports.
from sklearn.calibration import CalibratedClassifierCV

# Quantifiers to evaluate. Other aggregative methods (e.g. CC, ACC, PCC, EMQ)
# can be appended to this list.
quantifiers = [PACC]

for quantifier in quantifiers:
    # Build a fresh classifier for each quantifier so no fitted state is shared.
    classifier = CalibratedClassifierCV(LogisticRegression())
    model = quantifier(classifier).fit(train)
    quantifier_name = model.__class__.__name__

    # Estimate the prevalence vector of every dev sample. `total` only drives
    # the progress bar and is taken from the number of ground-truth entries.
    predictions = ResultSubmission(categories=categories)
    dev_samples = gen_load_samples_T1(T1B_devvectors_path, nF)
    for samplename, sample in tqdm(dev_samples, desc=quantifier_name, total=len(true_prevalence)):
        predictions.add(samplename, model.quantify(sample))

    # Persist the predictions and score them against the true dev prevalences.
    predictions.dump(os.path.join(predictions_path, quantifier_name + '.csv'))
    mae, mrae = evaluate_submission(true_prevalence, predictions)
    print(f'{quantifier_name} mae={mae:.3f} mrae={mrae:.3f}')