improved ReadMe method

Alejandro Moreo Fernandez 2025-10-22 18:51:35 +02:00
parent 854b3ba3f9
commit c11b99e08a
4 changed files with 123 additions and 31 deletions

View File

@@ -2,6 +2,7 @@ Change Log 0.2.1
-----------------
- Improved documentation of confidence regions.
- Added ReadMe method by Daniel Hopkins and Gary King.
Change Log 0.2.0
-----------------

View File

@@ -1,18 +1,55 @@
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_selection import SelectKBest, chi2
import quapy as qp
from quapy.method.non_aggregative import ReadMe
import quapy.functional as F
from sklearn.pipeline import Pipeline
"""
This example showcases how to use the non-aggregative method ReadMe proposed by Hopkins and King.
This method is for text analysis, so let us first instantiate a dataset for sentiment quantification (we
use IMDb for this example). The method is quite computationally expensive, so we will restrict the training
set to 1000 documents only.
"""
reviews = qp.datasets.fetch_reviews('imdb').reduce(n_train=1000, random_state=0)
"""
We need to convert the text into bag-of-words representations. ReadMe requires these representations to be
binary (i.e., storing a 1 whenever a document contains a certain word, and a 0 otherwise), so we will not use
TF-IDF weighting. We will also retain the 1000 most informative features according to the chi2 statistic.
"""
encode_0_1 = Pipeline([
('0_1_terms', CountVectorizer(min_df=5, binary=True)),
('feat_sel', SelectKBest(chi2, k=1000))
])
train, test = qp.data.preprocessing.instance_transformation(reviews, encode_0_1, inplace=True).train_test
"""
We now instantiate ReadMe with prob_model='full' (the default behaviour, implementing Hopkins and King's
original idea). The method estimates Q(Y) from the identity
Q(X) = \sum_i Q(X|Y=i) Q(Y=i)
solved as a linear least-squares problem, without resorting to estimating the posteriors Q(Y=i|X).
However, since Q(X) and Q(X|Y=i) are matrices of shape (2^K, 1) and (2^K, n), with K the number of features
and n the number of classes, their computation quickly becomes intractable. ReadMe instead performs bagging
(i.e., it samples small sets of features and averages the results), thus reducing K to a few terms. In our
example we set K (bagging_range) to 20 and the number of bagging_trials to 100.
ReadMe also computes confidence intervals via bootstrap; we set the number of bootstrap trials to 100.
"""
readme = ReadMe(prob_model='full', bootstrap_trials=100, bagging_trials=100, bagging_range=20, random_state=0, verbose=True)
readme.fit(*train.Xy)  # <- very little happens here (only bootstrap resampling); the method is "lazy"
# and postpones most of the calculations to the test phase.
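# For intuition, the snippet below is a minimal, self-contained sketch (illustrative only, not part of
# QuaPy) of the least-squares step described above: given class-conditional distributions Q(X|Y=i) and a
# test distribution Q(X), the prevalence Q(Y) is recovered by solving Q(X) = \sum_i Q(X|Y=i) Q(Y=i).
import numpy as np
from scipy.optimize import lsq_linear
toy_PX_given_Y = np.asarray([[0.7, 0.2, 0.1],   # Q(X|Y=0) over three toy feature patterns
                             [0.1, 0.3, 0.6]])  # Q(X|Y=1)
toy_true_prev = np.asarray([0.25, 0.75])
toy_PX = toy_true_prev @ toy_PX_given_Y         # the Q(X) induced by the true prevalence
toy_sol = lsq_linear(A=toy_PX_given_Y.T, b=toy_PX, bounds=(0, 1))
print(toy_sol.x / toy_sol.x.sum())              # recovers approximately [0.25, 0.75]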
# since the method is slow, we will only test 3 cases with different imbalances
few_negatives = [0.25, 0.75]
balanced = [0.5, 0.5]
few_positives = [0.75, 0.25]
for test_prev in [few_negatives, balanced, few_positives]:
sample = reviews.test.sampling(500, *test_prev, random_state=0) # draw sets of 500 documents with desired prevs
prev_estim, conf = readme.predict_conf(sample.X)
err = qp.error.mae(sample.prevalence(), prev_estim)
print(f'true-prevalence={F.strprev(sample.prevalence())},\n'

View File

@@ -22,8 +22,8 @@ def instance_transformation(dataset:Dataset, transformer, inplace=False):
:return: a new :class:`quapy.data.base.Dataset` with transformed instances (if inplace=False) or a reference to the
current Dataset (if inplace=True) where the instances have been transformed
"""
training_transformed = transformer.fit_transform(dataset.training.instances)
test_transformed = transformer.transform(dataset.test.instances)
training_transformed = transformer.fit_transform(*dataset.training.Xy)
test_transformed = transformer.transform(dataset.test.X)
if inplace:
dataset.training = LabelledCollection(training_transformed, dataset.training.labels, dataset.classes_)

View File

@@ -1,4 +1,6 @@
from typing import Union, Callable
from itertools import product
from tqdm import tqdm
from typing import Union, Callable, Counter
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.utils import resample
@@ -9,6 +11,7 @@ from quapy.functional import get_divergence
from quapy.method.base import BaseQuantifier, BinaryQuantifier
import quapy.functional as F
from scipy.optimize import lsq_linear
from scipy import sparse
class MaximumLikelihoodPrevalenceEstimation(BaseQuantifier):
@@ -152,6 +155,8 @@ class DMx(BaseQuantifier):
return F.argmin_prevalence(loss, n_classes, method=self.search)
class ReadMe(BaseQuantifier, WithConfidenceABC):
"""
ReadMe is a non-aggregative quantification system proposed by
@@ -168,9 +173,21 @@ class ReadMe(BaseQuantifier, WithConfidenceABC):
the feature space. ReadMe also combines bagging with bootstrap in order to derive confidence intervals
around point estimations.
We use the same default parameters as in the official
`R implementation <https://github.com/iqss-research/ReadMeV1/blob/master/R/prototype.R>`_.
:param prob_model: str ('naive' or 'full'), selects the way in which the probabilities `Q(X)` and
`Q(X|Y)` are modelled. The option "full" corresponds to the original formulation of
ReadMe, in which X is constrained to be a binary matrix (e.g., of term presence/absence) and in which
`Q(X)` and `Q(X|Y)` are modelled, respectively, as matrices of `(2^K, 1)` and `(2^K, n)` values, where
`K` is the number of columns in the data matrix (i.e., `bagging_range`) and `n` is the number of classes.
This approach is computationally prohibitive for large `K`, so the computation is restricted to data
matrices with `K<=25` (although we recommend even smaller values of `K`). A much faster model is "naive",
which models `Q(X)` and `Q(X|Y)` as multinomial distributions under the bag-of-words perspective. In this
case, `bagging_range` can be set to much larger values. Default is "full" (i.e., the original ReadMe behaviour).
:param bootstrap_trials: int, number of bootstrap trials (default 300)
:param bagging_trials: int, number of bagging trials (default 300)
:param bagging_range: int, number of features to keep for each bagging trial (default 15)
:param confidence_level: float, a value in (0,1) reflecting the desired confidence level (default 0.95)
:param region: str in 'intervals', 'ellipse', 'ellipse-clr'; indicates the preferred method for
defining the confidence region (see :class:`WithConfidenceABC`)
@@ -178,14 +195,21 @@ class ReadMe(BaseQuantifier, WithConfidenceABC):
:param verbose: bool, whether to display information during the process (default False)
"""
MAX_FEATURES_FOR_EMPIRICAL_ESTIMATION = 25
PROBABILISTIC_MODELS = ["naive", "full"]
def __init__(self,
bootstrap_trials=100,
bagging_trials=100,
bagging_range=250,
prob_model="full",
bootstrap_trials=300,
bagging_trials=300,
bagging_range=15,
confidence_level=0.95,
region='intervals',
random_state=None,
verbose=False):
assert prob_model in ReadMe.PROBABILISTIC_MODELS, \
f'unknown {prob_model=}, valid ones are {ReadMe.PROBABILISTIC_MODELS=}'
self.prob_model = prob_model
self.bootstrap_trials = bootstrap_trials
self.bagging_trials = bagging_trials
self.bagging_range = bagging_range
@@ -195,12 +219,11 @@ class ReadMe(BaseQuantifier, WithConfidenceABC):
self.verbose = verbose
def fit(self, X, y):
self._check_matrix(X)
self.rng = np.random.default_rng(self.random_state)
self.classes_ = np.unique(y)
n_features = X.shape[1]
if self.bagging_range is None:
self.bagging_range = int(np.sqrt(n_features))
Xsize = X.shape[0]
@@ -214,11 +237,10 @@ class ReadMe(BaseQuantifier, WithConfidenceABC):
return self
def predict_conf(self, X, confidence_level=0.95) -> (np.ndarray, ConfidenceRegionABC):
from tqdm import tqdm
self._check_matrix(X)
n_features = X.shape[1]
boots_prevalences = []
for Xboots, yboots in tqdm(
zip(self.Xboots, self.yboots),
desc='bootstrap predictions', total=self.bootstrap_trials, disable=not self.verbose
@@ -238,27 +260,59 @@ class ReadMe(BaseQuantifier, WithConfidenceABC):
return prev_estim, conf
def predict(self, X):
prev_estim, _ = self.predict_conf(X)
return prev_estim
def _quantify_iteration(self, Xtr, ytr, Xte):
"""Single ReadMe estimate."""
n_classes = len(self.classes_)
# class-conditional distributions Q(X|Y=c), estimated on the training bag
PX_given_Y = np.asarray([self._compute_P(Xtr[ytr == c]) for c in self.classes_])
# unconditional distribution Q(X), estimated on the test bag
PX = self._compute_P(Xte)
# solve Q(X) = \sum_c Q(X|Y=c) Q(Y=c) for Q(Y) as a bounded least-squares problem
res = lsq_linear(A=PX_given_Y.T, b=PX, bounds=(0, 1))
pY = np.maximum(res.x, 0)
return pY / pY.sum()
def _check_matrix(self, X):
"""the "full" model requires estimating empirical distributions; due to the high computational cost,
this function is only made available for binary matrices"""
if self.prob_model == 'full' and not self._is_binary_matrix(X):
raise ValueError('the empirical distribution can only be computed efficiently on binary matrices')
def _is_binary_matrix(self, X):
data = X.data if sparse.issparse(X) else X
return np.all((data == 0) | (data == 1))
def _compute_P(self, X):
if self.prob_model == 'naive':
return self._multinomial_distribution(X)
elif self.prob_model == 'full':
return self._empirical_distribution(X)
else:
raise ValueError(f'unknown {self.prob_model}; valid ones are {ReadMe.PROBABILISTIC_MODELS=}')
def _empirical_distribution(self, X):
if X.shape[1] > self.MAX_FEATURES_FOR_EMPIRICAL_ESTIMATION:
raise ValueError(f'the empirical distribution can only be computed efficiently for dimensions '
f'less than or equal to {self.MAX_FEATURES_FOR_EMPIRICAL_ESTIMATION}')
# we convert every binary row (e.g., 0 0 1 0 1) into the equivalent number (e.g., 5)
K = X.shape[1]
binary_powers = 1 << np.arange(K-1, -1, -1) # (2^(K-1), ..., 8, 4, 2, 1)
X_as_binary_numbers = X @ binary_powers
# count occurrences and compute probs
counts = np.bincount(X_as_binary_numbers, minlength=2 ** K).astype(float)
probs = counts / counts.sum()
return probs
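# Worked micro-example (illustrative, not part of the method's code): with K=3 binary features,
# binary_powers = [4, 2, 1], so the row [0, 1, 1] maps to the integer 3 and [1, 0, 1] maps to 5. For
# X = [[0,1,1], [1,0,1], [0,1,1], [0,0,0]] the bincount over the 2**3=8 patterns is
# [1, 0, 0, 2, 0, 1, 0, 0], and the returned distribution is [0.25, 0, 0, 0.5, 0, 0.25, 0, 0].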
def _multinomial_distribution(self, X):
PX = np.asarray(X.sum(axis=0))
PX = normalize(PX, norm='l1', axis=1)
return PX.ravel()
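# Worked micro-example (illustrative): for X = [[0,1,1], [1,0,1]] the column sums are [1, 1, 2], which,
# after L1 normalization, yield the multinomial estimate [0.25, 0.25, 0.5].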
def _get_features_range(X):