1
0
Fork 0

trying ordinal classification

This commit is contained in:
Alejandro Moreo Fernandez 2022-03-08 16:27:41 +01:00
parent f285e936ad
commit b982a51103
3 changed files with 63 additions and 14 deletions

View File

@@ -1,35 +1,74 @@
# Ordinal quantification experiment: fits a quantifier on the Books-tfidf
# training set and evaluates it on drift-selected test samples.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
import quapy as qp
import numpy as np
from Ordinal.model import OrderedLogisticRegression, StackedClassifier, RegressionQuantification, RegressorClassifier
from quapy.method.aggregative import PACC, CC, EMQ, PCC, ACC
from quapy.data import LabelledCollection
from os.path import join
from utils import load_samples, load_samples_pkl
from evaluation import nmd, mnmd
from time import time
import pickle
from tqdm import tqdm

# experiment configuration
domain = 'Books-tfidf'   # pre-vectorized (tfidf) version of the Books domain
datapath = './data'
protocol = 'app'         # artificial-prevalence protocol samples
drift = 'low'            # which drift split of sample ids to use

# training data is stored as an already-vectorized pickled LabelledCollection;
# use a context manager so the file handle is closed (the original leaked it)
with open(join(datapath, domain, 'training_data.pkl'), 'rb') as fin:
    train = pickle.load(fin)
def load_test_samples():
    """Yield ``(instances, prevalence)`` pairs for the selected test samples.

    Only the samples whose id appears in the ``{drift}drift.test.id.npy``
    index are loaded (the diff dump interleaved the old tfidf-based body
    with this pkl-based one; this is the post-change implementation).
    """
    ids = np.load(join(datapath, domain, protocol, f'{drift}drift.test.id.npy'))
    ids = set(ids)  # O(1) membership tests inside load_samples_pkl
    for sample in tqdm(load_samples_pkl(join(datapath, domain, protocol, 'test_samples'), filter=ids), total=len(ids)):
        yield sample.instances, sample.prevalence()
def load_dev_samples():
    """Yield ``(instances, prevalence)`` pairs for the selected dev samples.

    Mirrors :func:`load_test_samples` but reads the dev id index and the
    ``dev_samples`` directory; used as the validation generator for model
    selection.  (The stray superseded ``q = EMQ(...)`` diff residue that
    preceded this function has been dropped.)
    """
    ids = np.load(join(datapath, domain, protocol, f'{drift}drift.dev.id.npy'))
    ids = set(ids)  # O(1) membership tests inside load_samples_pkl
    for sample in tqdm(load_samples_pkl(join(datapath, domain, protocol, 'dev_samples'), filter=ids), total=len(ids)):
        yield sample.instances, sample.prevalence()
print('fitting the quantifier')

# alternative quantifiers explored in this experiment (kept for reference):
# q = PACC(LogisticRegression(class_weight='balanced'))
# q = PACC(OrderedLogisticRegression())
# q = PACC(StackedClassifier(LogisticRegression(class_weight='balanced')))
# q = RegressionQuantification(PCC(LogisticRegression(class_weight='balanced')), val_samples_generator=load_dev_samples)
q = PACC(RegressorClassifier())

# model selection over the classifier's regularization strength, scored with
# mean normalized match distance (mnmd) on the dev-sample generator
q = qp.model_selection.GridSearchQ(
    q,
    # {'C': np.logspace(-3,3,7), 'class_weight': [None, 'balanced']},
    {'C': np.logspace(-3,3,14)},
    1000,
    'gen',
    error=mnmd,
    val_split=load_dev_samples,
    n_jobs=-1,
    refit=False,
    verbose=True)

q.fit(train)
print('[done]')

# evaluate the fitted quantifier on every selected test sample
report = qp.evaluation.gen_prevalence_report(q, gen_fn=load_test_samples, error_metrics=[nmd])

mean_nmd = report['nmd'].mean()
std_nmd = report['nmd'].std()
print(f'{mean_nmd:.4f} +-{std_nmd:.4f}')

# repeat for the high-drift split (disabled):
# drift='high'
# report = qp.evaluation.gen_prevalence_report(q, gen_fn=load_test_samples, error_metrics=[nmd])
# mean_nmd = report['nmd'].mean()
# std_nmd = report['nmd'].std()
# print(f'{mean_nmd:.4f} +-{std_nmd:.4f}')

View File

@@ -3,6 +3,7 @@ from quapy.data import LabelledCollection
from glob import glob
import os
from os.path import join
import pickle
def load_samples(path_dir, classes):
@@ -11,3 +12,11 @@ def load_samples(path_dir, classes):
yield LabelledCollection.load(join(path_dir, f'{id}.txt'), loader_func=qp.data.reader.from_text, classes=classes)
def load_samples_pkl(path_dir, filter=None):
    """Yield the pickled samples stored as ``<id>.pkl`` files in `path_dir`, in id order.

    Samples are assumed to be numbered consecutively from 0.

    :param path_dir: directory containing the ``<id>.pkl`` files
    :param filter: optional collection of sample ids to keep; when given,
        any id not contained in it is skipped (name kept for backward
        compatibility although it shadows the builtin)
    """
    nsamples = len(glob(join(path_dir, '*.pkl')))
    # `sample_id` instead of `id` to avoid shadowing the builtin
    for sample_id in range(nsamples):
        if filter is not None and sample_id not in filter:
            continue
        # context manager closes the file handle (the original leaked it)
        with open(join(path_dir, f'{sample_id}.pkl'), 'rb') as fin:
            yield pickle.load(fin)

View File

@@ -1,9 +1,10 @@
import numpy as np
from scipy.sparse import dok_matrix
from tqdm import tqdm
from time import time
def from_text(path, encoding='utf-8', verbose=1, class2int=True):
def from_text(path, encoding='utf-8', verbose=0, class2int=True):
"""
Reads a labelled colletion of documents.
File fomart <0 or 1>\t<document>\n