1
0
Fork 0
QuaPy/LeQua2022/baselines_T1Amodsel.py

82 lines
2.3 KiB
Python

import pickle
from sklearn.linear_model import LogisticRegression
from quapy.method.aggregative import *
import quapy.functional as F
from data import *
import os
import constants
# LeQua official baselines for task T1A (Binary/Vector)
# =====================================================
predictions_path = os.path.join('predictions', 'T1A')
os.makedirs(predictions_path, exist_ok=True)
models_path = os.path.join('models', 'T1A')
os.makedirs(models_path, exist_ok=True)
pathT1A = './data/T1A/public'
T1A_devvectors_path = os.path.join(pathT1A, 'dev_vectors')
T1A_devprevalence_path = os.path.join(pathT1A, 'dev_prevalences.csv')
T1A_trainpath = os.path.join(pathT1A, 'training_vectors.txt')
train = LabelledCollection.load(T1A_trainpath, load_binary_vectors)
nF = train.instances.shape[1]
qp.environ['SAMPLE_SIZE'] = constants.T1A_SAMPLE_SIZE
print(f'number of classes: {len(train.classes_)}')
print(f'number of training documents: {len(train)}')
print(f'training prevalence: {F.strprev(train.prevalence())}')
print(f'training matrix shape: {train.instances.shape}')
true_prevalence = ResultSubmission.load(T1A_devprevalence_path)
param_grid = {
'C': np.logspace(-3,3,7),
'class_weight': ['balanced', None]
}
def gen_samples():
return gen_load_samples_T1(T1A_devvectors_path, nF, ground_truth_path=T1A_devprevalence_path, return_id=False)
for quantifier in [EMQ]: # [CC, ACC, PCC, PACC, EMQ, HDy]:
if quantifier == EMQ:
classifier = CalibratedClassifierCV(LogisticRegression(), n_jobs=-1)
else:
classifier = LogisticRegression()
model = quantifier(classifier)
print(f'{model.__class__.__name__}: Model selection')
model = qp.model_selection.GridSearchQ(
model,
param_grid,
sample_size=None,
protocol='gen',
error=qp.error.mae,
refit=False,
verbose=True
).fit(train, gen_samples)
quantifier_name = model.best_model().__class__.__name__
print(f'{quantifier_name} mae={model.best_score_:.3f} (params: {model.best_params_})')
pickle.dump(model.best_model(),
open(os.path.join(models_path, quantifier_name+'.pkl'), 'wb'),
protocol=pickle.HIGHEST_PROTOCOL)
"""
validation
CC 0.1862 1.9587
ACC 0.0394 0.2669
PCC 0.1789 2.1383
PACC 0.0354 0.1587
EMQ 0.0224 0.0960
HDy 0.0467 0.2121
"""