adding QuaNet to experiments of Twitter; trying new stuff in 'NewMethods'
This commit is contained in:
parent
b30c40b7a0
commit
1399125fb8
|
@ -0,0 +1,37 @@
|
|||
from typing import Union
|
||||
import quapy as qp
|
||||
from quapy.method.aggregative import PACC, EMQ, HDy
|
||||
import quapy.functional as F
|
||||
|
||||
|
||||
class PACCSLD(PACC):
|
||||
"""
|
||||
This method combines the EMQ improved posterior probabilities with PACC.
|
||||
Note: the posterior probabilities are re-calibrated with EMQ only during prediction, and not also during fit since,
|
||||
for PACC, the validation split is known to have the same prevalence as the training set (this is because the split
|
||||
is stratified) and thus the posterior probabilities should not be re-calibrated for a different prior (it actually
|
||||
happens to degrades performance).
|
||||
"""
|
||||
|
||||
def fit(self, data: qp.data.LabelledCollection, fit_learner=True, val_split:Union[float, int, qp.data.LabelledCollection]=0.4):
|
||||
self.train_prevalence = F.prevalence_from_labels(data.labels, data.n_classes)
|
||||
return super(PACCSLD, self).fit(data, fit_learner, val_split)
|
||||
|
||||
def aggregate(self, classif_posteriors):
|
||||
priors, posteriors = EMQ.EM(self.train_prevalence, classif_posteriors, epsilon=1e-4)
|
||||
return super(PACCSLD, self).aggregate(posteriors)
|
||||
|
||||
|
||||
class HDySLD(HDy):
|
||||
"""
|
||||
This method combines the EMQ improved posterior probabilities with HDy.
|
||||
Note: [same as PACCSLD]
|
||||
"""
|
||||
def fit(self, data: qp.data.LabelledCollection, fit_learner=True,
|
||||
val_split: Union[float, int, qp.data.LabelledCollection] = 0.4):
|
||||
self.train_prevalence = F.prevalence_from_labels(data.labels, data.n_classes)
|
||||
return super(HDySLD, self).fit(data, fit_learner, val_split)
|
||||
|
||||
def aggregate(self, classif_posteriors):
|
||||
priors, posteriors = EMQ.EM(self.train_prevalence, classif_posteriors, epsilon=1e-4)
|
||||
return super(HDySLD, self).aggregate(posteriors)
|
|
@ -0,0 +1,48 @@
|
|||
from sklearn.linear_model import LogisticRegression
|
||||
import quapy as qp
|
||||
from classification.methods import PCALR
|
||||
from method.meta import QuaNet
|
||||
from quapy.method.aggregative import *
|
||||
from NewMethods.methods import *
|
||||
from experiments import run, SAMPLE_SIZE
|
||||
import numpy as np
|
||||
import itertools
|
||||
from joblib import Parallel, delayed
|
||||
import settings
|
||||
import argparse
|
||||
import torch
|
||||
|
||||
parser = argparse.ArgumentParser(description='Run experiments for Tweeter Sentiment Quantification')
|
||||
parser.add_argument('results', metavar='RESULT_PATH', type=str, help='path to the directory where to store the results')
|
||||
parser.add_argument('svmperfpath', metavar='SVMPERF_PATH', type=str, help='path to the directory with svmperf')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def quantification_models():
|
||||
def newLR():
|
||||
return LogisticRegression(max_iter=1000, solver='lbfgs', n_jobs=-1)
|
||||
__C_range = np.logspace(-4, 5, 10)
|
||||
lr_params = {'C': __C_range, 'class_weight': [None, 'balanced']}
|
||||
svmperf_params = {'C': __C_range}
|
||||
#yield 'paccsld', PACCSLD(newLR()), lr_params
|
||||
#yield 'hdysld', OneVsAll(HDySLD(newLR())), lr_params # <-- promising!
|
||||
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
print(f'Running QuaNet in {device}')
|
||||
yield 'quanet', QuaNet(PCALR(**newLR().get_params()), SAMPLE_SIZE, device=device), lr_params
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
print(f'Result folder: {args.results}')
|
||||
np.random.seed(0)
|
||||
|
||||
optim_losses = ['mae']
|
||||
datasets = ['hcr'] # qp.datasets.TWITTER_SENTIMENT_DATASETS_TRAIN
|
||||
models = quantification_models()
|
||||
|
||||
results = Parallel(n_jobs=settings.N_JOBS)(
|
||||
delayed(run)(experiment) for experiment in itertools.product(optim_losses, datasets, models)
|
||||
)
|
||||
|
||||
|
|
@ -0,0 +1,196 @@
|
|||
import quapy as qp
|
||||
import numpy as np
|
||||
from os import makedirs
|
||||
import sys, os
|
||||
import pickle
|
||||
from experiments import result_path
|
||||
from tabular import Table
|
||||
|
||||
tables_path = './tables'
|
||||
MAXTONE = 50 # sets the intensity of the maximum color reached by the worst (red) and best (green) results
|
||||
|
||||
makedirs(tables_path, exist_ok=True)
|
||||
|
||||
sample_size = 100
|
||||
qp.environ['SAMPLE_SIZE'] = sample_size
|
||||
|
||||
|
||||
nice = {
|
||||
'mae':'AE',
|
||||
'mrae':'RAE',
|
||||
'ae':'AE',
|
||||
'rae':'RAE',
|
||||
'svmkld': 'SVM(KLD)',
|
||||
'svmnkld': 'SVM(NKLD)',
|
||||
'svmq': 'SVM(Q)',
|
||||
'svmae': 'SVM(AE)',
|
||||
'svmnae': 'SVM(NAE)',
|
||||
'svmmae': 'SVM(AE)',
|
||||
'svmmrae': 'SVM(RAE)',
|
||||
'quanet': 'QuaNet',
|
||||
'hdy': 'HDy',
|
||||
'dys': 'DyS',
|
||||
'svmperf':'',
|
||||
'sanders': 'Sanders',
|
||||
'semeval13': 'SemEval13',
|
||||
'semeval14': 'SemEval14',
|
||||
'semeval15': 'SemEval15',
|
||||
'semeval16': 'SemEval16',
|
||||
'Average': 'Average'
|
||||
}
|
||||
|
||||
|
||||
def nicerm(key):
|
||||
return '\mathrm{'+nice[key]+'}'
|
||||
|
||||
|
||||
def load_Gao_Sebastiani_previous_results():
|
||||
def rename(method):
|
||||
old2new = {
|
||||
'kld': 'svmkld',
|
||||
'nkld': 'svmnkld',
|
||||
'qbeta2': 'svmq',
|
||||
'em': 'sld'
|
||||
}
|
||||
return old2new.get(method, method)
|
||||
|
||||
gao_seb_results = {}
|
||||
with open('./Gao_Sebastiani_results.txt', 'rt') as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines[1:]:
|
||||
line = line.strip()
|
||||
parts = line.lower().split()
|
||||
if len(parts) == 4:
|
||||
dataset, method, ae, rae = parts
|
||||
else:
|
||||
method, ae, rae = parts
|
||||
learner, method = method.split('-')
|
||||
method = rename(method)
|
||||
gao_seb_results[f'{dataset}-{method}-ae'] = float(ae)
|
||||
gao_seb_results[f'{dataset}-{method}-rae'] = float(rae)
|
||||
return gao_seb_results
|
||||
|
||||
|
||||
def get_ranks_from_Gao_Sebastiani():
|
||||
gao_seb_results = load_Gao_Sebastiani_previous_results()
|
||||
datasets = set([key.split('-')[0] for key in gao_seb_results.keys()])
|
||||
methods = np.sort(np.unique([key.split('-')[1] for key in gao_seb_results.keys()]))
|
||||
ranks = {}
|
||||
for metric in ['ae', 'rae']:
|
||||
for dataset in datasets:
|
||||
scores = [gao_seb_results[f'{dataset}-{method}-{metric}'] for method in methods]
|
||||
order = np.argsort(scores)
|
||||
sorted_methods = methods[order]
|
||||
for i, method in enumerate(sorted_methods):
|
||||
ranks[f'{dataset}-{method}-{metric}'] = i+1
|
||||
for method in methods:
|
||||
rankave = np.mean([ranks[f'{dataset}-{method}-{metric}'] for dataset in datasets])
|
||||
ranks[f'Average-{method}-{metric}'] = rankave
|
||||
return ranks, gao_seb_results
|
||||
|
||||
|
||||
def save_table(path, table):
|
||||
print(f'saving results in {path}')
|
||||
with open(path, 'wt') as foo:
|
||||
foo.write(table)
|
||||
|
||||
|
||||
datasets = qp.datasets.TWITTER_SENTIMENT_DATASETS_TEST
|
||||
evaluation_measures = [qp.error.ae, qp.error.rae]
|
||||
gao_seb_methods = ['cc', 'acc', 'pcc', 'pacc', 'sld', 'svmq', 'svmkld', 'svmnkld']
|
||||
new_methods = []
|
||||
|
||||
|
||||
def experiment_errors(dataset, method, loss):
|
||||
path = result_path(dataset, method, 'm'+loss if not loss.startswith('m') else loss)
|
||||
if os.path.exists(path):
|
||||
true_prevs, estim_prevs, _, _, _, _ = pickle.load(open(path, 'rb'))
|
||||
err_fn = getattr(qp.error, loss)
|
||||
errors = err_fn(true_prevs, estim_prevs)
|
||||
return errors
|
||||
return None
|
||||
|
||||
|
||||
gao_seb_ranks, gao_seb_results = get_ranks_from_Gao_Sebastiani()
|
||||
|
||||
for i, eval_func in enumerate(evaluation_measures):
|
||||
|
||||
# Tables evaluation scores for AE and RAE (two tables)
|
||||
# ----------------------------------------------------
|
||||
|
||||
eval_name = eval_func.__name__
|
||||
added_methods = ['svmm' + eval_name] + new_methods
|
||||
methods = gao_seb_methods + added_methods
|
||||
nold_methods = len(gao_seb_methods)
|
||||
nnew_methods = len(added_methods)
|
||||
|
||||
# fill data table
|
||||
table = Table(rows=datasets, cols=methods)
|
||||
for dataset in datasets:
|
||||
for method in methods:
|
||||
table.add(dataset, method, experiment_errors(dataset, method, eval_name))
|
||||
|
||||
# write the latex table
|
||||
tabular = """
|
||||
\\begin{tabularx}{\\textwidth}{|c||""" + ('Y|'*nold_methods)+ '|' + ('Y|'*nnew_methods) + """} \hline
|
||||
& \multicolumn{"""+str(nold_methods)+"""}{c||}{Methods tested in~\cite{Gao:2016uq}} &
|
||||
\multicolumn{"""+str(nnew_methods)+"""}{c|}{} \\\\ \hline
|
||||
"""
|
||||
rowreplace={dataset: nice.get(dataset, dataset.upper()) for dataset in datasets}
|
||||
colreplace={method:'\side{' + nice.get(method, method.upper()) +'$^{' + nicerm(eval_name) + '}$} ' for method in methods}
|
||||
|
||||
tabular += table.latexTabular(rowreplace=rowreplace, colreplace=colreplace)
|
||||
tabular += "\n\end{tabularx}"
|
||||
|
||||
save_table(f'./tables/tab_results_{eval_name}.new.tex', tabular)
|
||||
|
||||
# Tables ranks for AE and RAE (two tables)
|
||||
# ----------------------------------------------------
|
||||
methods = gao_seb_methods
|
||||
|
||||
# fill the data table
|
||||
ranktable = Table(rows=datasets, cols=methods, missing='--')
|
||||
for dataset in datasets:
|
||||
for method in methods:
|
||||
ranktable.add(dataset, method, values=table.get(dataset, method, 'rank'))
|
||||
|
||||
# write the latex table
|
||||
tabular = """
|
||||
\\begin{tabularx}{\\textwidth}{|c||""" + ('Y|' * len(gao_seb_methods)) + """} \hline
|
||||
& \multicolumn{""" + str(nold_methods) + """}{c|}{Methods tested in~\cite{Gao:2016uq}} \\\\ \hline
|
||||
"""
|
||||
for method in methods:
|
||||
tabular += ' & \side{' + nice.get(method, method.upper()) +'$^{' + nicerm(eval_name) + '}$} '
|
||||
tabular += '\\\\\hline\n'
|
||||
|
||||
for dataset in datasets:
|
||||
tabular += nice.get(dataset, dataset.upper()) + ' '
|
||||
for method in methods:
|
||||
newrank = ranktable.get(dataset, method)
|
||||
oldrank = gao_seb_ranks[f'{dataset}-{method}-{eval_name}']
|
||||
if newrank != '--':
|
||||
newrank = f'{int(newrank)}'
|
||||
color = ranktable.get_color(dataset, method)
|
||||
if color == '--':
|
||||
color = ''
|
||||
tabular += ' & ' + f'{newrank}' + f' ({oldrank}) ' + color
|
||||
tabular += '\\\\\hline\n'
|
||||
tabular += '\hline\n'
|
||||
|
||||
tabular += 'Average '
|
||||
for method in methods:
|
||||
newrank = ranktable.get_average(method)
|
||||
oldrank = gao_seb_ranks[f'Average-{method}-{eval_name}']
|
||||
if newrank != '--':
|
||||
newrank = f'{newrank:.1f}'
|
||||
oldrank = f'{oldrank:.1f}'
|
||||
color = ranktable.get_average(method, 'color')
|
||||
if color == '--':
|
||||
color = ''
|
||||
tabular += ' & ' + f'{newrank}' + f' ({oldrank}) ' + color
|
||||
tabular += '\\\\\hline\n'
|
||||
tabular += "\end{tabularx}"
|
||||
|
||||
save_table(f'./tables/tab_rank_{eval_name}.new.tex', tabular)
|
||||
|
||||
print("[Done]")
|
|
@ -0,0 +1,3 @@
|
|||
import multiprocessing
|
||||
|
||||
N_JOBS = -2 #multiprocessing.cpu_count()
|
|
@ -1,5 +1,7 @@
|
|||
from sklearn.linear_model import LogisticRegression
|
||||
import quapy as qp
|
||||
from classification.methods import PCALR
|
||||
from method.meta import QuaNet
|
||||
from quapy.method.aggregative import CC, ACC, PCC, PACC, EMQ, OneVsAll, SVMQ, SVMKLD, SVMNKLD, SVMAE, SVMRAE, HDy
|
||||
import quapy.functional as F
|
||||
import numpy as np
|
||||
|
@ -9,12 +11,19 @@ import itertools
|
|||
from joblib import Parallel, delayed
|
||||
import settings
|
||||
import argparse
|
||||
import torch
|
||||
import shutil
|
||||
|
||||
parser = argparse.ArgumentParser(description='Run experiments for Tweeter Sentiment Quantification')
|
||||
parser.add_argument('results', metavar='RESULT_PATH', type=str, help='path to the directory where to store the results')
|
||||
parser.add_argument('svmperfpath', metavar='SVMPERF_PATH', type=str, help='path to the directory with svmperf')
|
||||
parser.add_argument('--svmperfpath', metavar='SVMPERF_PATH', type=str,default='./svm_perf_quantification',
|
||||
help='path to the directory with svmperf')
|
||||
parser.add_argument('--checkpointdir', metavar='PATH', type=str,default='./checkpoint',
|
||||
help='path to the directory where to dump QuaNet checkpoints')
|
||||
args = parser.parse_args()
|
||||
|
||||
SAMPLE_SIZE = 100
|
||||
|
||||
|
||||
def quantification_models():
|
||||
def newLR():
|
||||
|
@ -38,12 +47,15 @@ def quantification_models():
|
|||
yield 'svmmrae', OneVsAll(SVMRAE(args.svmperfpath)), svmperf_params
|
||||
yield 'hdy', OneVsAll(HDy(newLR())), lr_params
|
||||
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
print(f'Running QuaNet in {device}')
|
||||
yield 'quanet', QuaNet(PCALR(**newLR().get_params()), SAMPLE_SIZE, checkpointdir=args.checkpointdir, device=device), lr_params
|
||||
|
||||
# to add:
|
||||
# quapy
|
||||
# ensembles
|
||||
#
|
||||
|
||||
# 'mlpe': lambda learner: MaximumLikelihoodPrevalenceEstimation(),
|
||||
# 'mlpe': lambda learner: MaximumLikelihoodPrevalenceEstimation(),
|
||||
|
||||
|
||||
def evaluate_experiment(true_prevalences, estim_prevalences):
|
||||
|
@ -83,8 +95,7 @@ def save_results(dataset_name, model_name, optim_loss, *results):
|
|||
|
||||
def run(experiment):
|
||||
|
||||
sample_size = 100
|
||||
qp.environ['SAMPLE_SIZE'] = sample_size
|
||||
qp.environ['SAMPLE_SIZE'] = SAMPLE_SIZE
|
||||
|
||||
optim_loss, dataset_name, (model_name, model, hyperparams) = experiment
|
||||
|
||||
|
@ -104,7 +115,7 @@ def run(experiment):
|
|||
model_selection = qp.model_selection.GridSearchQ(
|
||||
model,
|
||||
param_grid=hyperparams,
|
||||
sample_size=sample_size,
|
||||
sample_size=SAMPLE_SIZE,
|
||||
n_prevpoints=21,
|
||||
n_repetitions=5,
|
||||
error=optim_loss,
|
||||
|
@ -126,7 +137,7 @@ def run(experiment):
|
|||
true_prevalences, estim_prevalences = qp.evaluation.artificial_sampling_prediction(
|
||||
model,
|
||||
test=benchmark_eval.test,
|
||||
sample_size=sample_size,
|
||||
sample_size=SAMPLE_SIZE,
|
||||
n_prevpoints=21,
|
||||
n_repetitions=25
|
||||
)
|
||||
|
@ -154,4 +165,6 @@ if __name__ == '__main__':
|
|||
delayed(run)(experiment) for experiment in itertools.product(optim_losses, datasets, models)
|
||||
)
|
||||
|
||||
shutil.rmtree(args.checkpointdir, ignore_errors=True)
|
||||
|
||||
|
||||
|
|
|
@ -98,7 +98,7 @@ def save_table(path, table):
|
|||
datasets = qp.datasets.TWITTER_SENTIMENT_DATASETS_TEST
|
||||
evaluation_measures = [qp.error.ae, qp.error.rae]
|
||||
gao_seb_methods = ['cc', 'acc', 'pcc', 'pacc', 'sld', 'svmq', 'svmkld', 'svmnkld']
|
||||
new_methods = []
|
||||
new_methods = ['hdy']
|
||||
|
||||
|
||||
def experiment_errors(dataset, method, loss):
|
||||
|
|
|
@ -7,6 +7,7 @@ from . import evaluation
|
|||
from . import plot
|
||||
from . import util
|
||||
from . import model_selection
|
||||
from . import classification
|
||||
from quapy.method.base import isprobabilistic, isaggregative
|
||||
|
||||
|
||||
|
|
|
@ -3,10 +3,14 @@ from sklearn.linear_model import LogisticRegression
|
|||
|
||||
|
||||
class PCALR:
|
||||
"""
|
||||
An example of a classification method that also generates embedded inputs, as those required for QuaNet.
|
||||
This example simply combines a Principal Component Analysis (PCA) with Logistic Regression (LR).
|
||||
"""
|
||||
|
||||
def __init__(self, n_components=300, C=10, class_weight=None):
|
||||
def __init__(self, n_components=300, **kwargs):
|
||||
self.n_components = n_components
|
||||
self.learner = LogisticRegression(C=C, class_weight=class_weight, max_iter=1000)
|
||||
self.learner = LogisticRegression(**kwargs)
|
||||
|
||||
def get_params(self):
|
||||
params = {'n_components': self.n_components}
|
||||
|
@ -19,20 +23,20 @@ class PCALR:
|
|||
del params['n_components']
|
||||
self.learner.set_params(**params)
|
||||
|
||||
def fit(self, documents, labels):
|
||||
def fit(self, X, y):
|
||||
self.pca = TruncatedSVD(self.n_components)
|
||||
embedded = self.pca.fit_transform(documents, labels)
|
||||
self.learner.fit(embedded, labels)
|
||||
embedded = self.pca.fit_transform(X, y)
|
||||
self.learner.fit(embedded, y)
|
||||
self.classes_ = self.learner.classes_
|
||||
return self
|
||||
|
||||
def predict(self, documents):
|
||||
embedded = self.transform(documents)
|
||||
def predict(self, X):
|
||||
embedded = self.transform(X)
|
||||
return self.learner.predict(embedded)
|
||||
|
||||
def predict_proba(self, documents):
|
||||
embedded = self.transform(documents)
|
||||
def predict_proba(self, X):
|
||||
embedded = self.transform(X)
|
||||
return self.learner.predict_proba(embedded)
|
||||
|
||||
def transform(self, documents):
|
||||
return self.pca.transform(documents)
|
||||
def transform(self, X):
|
||||
return self.pca.transform(X)
|
||||
|
|
|
@ -17,7 +17,7 @@ def mae(prevs, prevs_hat):
|
|||
|
||||
|
||||
def ae(p, p_hat):
|
||||
assert p.shape == p_hat.shape, 'wrong shape'
|
||||
assert p.shape == p_hat.shape, f'wrong shape {p.shape} vs. {p_hat.shape}'
|
||||
return abs(p_hat-p).mean(axis=-1)
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
import random
|
||||
|
||||
import torch
|
||||
from torch.nn import MSELoss
|
||||
|
@ -18,12 +19,15 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
tr_iter_per_poch=200,
|
||||
va_iter_per_poch=21,
|
||||
lr=1e-3,
|
||||
lstm_hidden_size=64,
|
||||
lstm_nlayers=1,
|
||||
lstm_hidden_size=128,
|
||||
lstm_nlayers=2,
|
||||
ff_layers=[1024, 512],
|
||||
bidirectional=True,
|
||||
qdrop_p=0.5,
|
||||
patience=10, checkpointpath='../checkpoint/quanet.dat', device='cuda'):
|
||||
patience=10,
|
||||
checkpointdir='../checkpoint',
|
||||
checkpointname=None,
|
||||
device='cuda'):
|
||||
assert hasattr(learner, 'transform'), \
|
||||
f'the learner {learner.__class__.__name__} does not seem to be able to produce document embeddings ' \
|
||||
f'since it does not implement the method "transform"'
|
||||
|
@ -45,8 +49,13 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
}
|
||||
|
||||
self.patience = patience
|
||||
self.checkpointpath = checkpointpath
|
||||
os.makedirs(Path(checkpointpath).parent, exist_ok=True)
|
||||
os.makedirs(checkpointdir, exist_ok=True)
|
||||
if checkpointname is None:
|
||||
local_random = random.Random()
|
||||
random_code = '-'.join(str(local_random.randint(0, 1000000)) for _ in range(5))
|
||||
checkpointname = 'QuaNet-'+random_code
|
||||
self.checkpointdir = checkpointdir
|
||||
self.checkpoint = os.path.join(checkpointdir, checkpointname)
|
||||
self.device = torch.device(device)
|
||||
|
||||
self.__check_params_colision(self.quanet_params, self.learner.get_params())
|
||||
|
@ -102,7 +111,7 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
self.optim = torch.optim.Adam(self.quanet.parameters(), lr=self.lr)
|
||||
early_stop = EarlyStop(self.patience, lower_is_better=True)
|
||||
|
||||
checkpoint = self.checkpointpath
|
||||
checkpoint = self.checkpoint
|
||||
|
||||
for epoch_i in range(1, self.n_epochs):
|
||||
self.epoch(train_data, train_posteriors, self.tr_iter, epoch_i, early_stop, train=True)
|
||||
|
@ -124,7 +133,7 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
label_predictions = np.argmax(posteriors, axis=-1)
|
||||
prevs_estim = []
|
||||
for quantifier in self.quantifiers.values():
|
||||
predictions = posteriors if isprobabilistic(quantifier) else label_predictions
|
||||
predictions = posteriors if quantifier.probabilistic else label_predictions
|
||||
prevs_estim.extend(quantifier.aggregate(predictions))
|
||||
|
||||
# add the class-conditional predictions P(y'i|yj) from ACC and PACC
|
||||
|
@ -139,7 +148,10 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
quant_estims = self.get_aggregative_estims(posteriors)
|
||||
self.quanet.eval()
|
||||
with torch.no_grad():
|
||||
prevalence = self.quanet.forward(embeddings, posteriors, quant_estims).item()
|
||||
prevalence = self.quanet.forward(embeddings, posteriors, quant_estims)
|
||||
if self.device == torch.device('cuda'):
|
||||
prevalence = prevalence.cpu()
|
||||
prevalence = prevalence.numpy().flatten()
|
||||
return prevalence
|
||||
|
||||
def epoch(self, data: LabelledCollection, posteriors, iterations, epoch, early_stop, train):
|
||||
|
@ -179,7 +191,7 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
|
||||
def set_params(self, **parameters):
|
||||
learner_params={}
|
||||
for key, val in parameters:
|
||||
for key, val in parameters.items():
|
||||
if key in self.quanet_params:
|
||||
self.quanet_params[key]=val
|
||||
else:
|
||||
|
@ -194,6 +206,14 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
raise ValueError(f'the use of parameters {intersection} is ambiguous sine those can refer to '
|
||||
f'the parameters of QuaNet or the learner {self.learner.__class__.__name__}')
|
||||
|
||||
def clean_checkpoint(self):
|
||||
os.remove(self.checkpoint)
|
||||
|
||||
def clean_checkpoint_dir(self):
|
||||
import shutil
|
||||
shutil.rmtree(self.checkpointdir, ignore_errors=True)
|
||||
|
||||
|
||||
|
||||
class QuaNetModule(torch.nn.Module):
|
||||
def __init__(self,
|
||||
|
@ -274,3 +294,5 @@ class QuaNetModule(torch.nn.Module):
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue