fixing quanet
This commit is contained in:
parent
e1790b3d9d
commit
99132c8166
|
@ -17,6 +17,9 @@ import torch
|
|||
import shutil
|
||||
|
||||
|
||||
DEBUG = False
|
||||
|
||||
|
||||
def quantification_models():
|
||||
def newLR():
|
||||
return LogisticRegression(max_iter=1000, solver='lbfgs', n_jobs=-1)
|
||||
|
@ -25,23 +28,34 @@ def quantification_models():
|
|||
svmperf_params = {'C': __C_range}
|
||||
|
||||
# methods tested in Gao & Sebastiani 2016
|
||||
yield 'cc', CC(newLR()), lr_params
|
||||
yield 'acc', ACC(newLR()), lr_params
|
||||
yield 'pcc', PCC(newLR()), lr_params
|
||||
yield 'pacc', PACC(newLR()), lr_params
|
||||
yield 'sld', EMQ(newLR()), lr_params
|
||||
yield 'svmq', OneVsAll(SVMQ(args.svmperfpath)), svmperf_params
|
||||
yield 'svmkld', OneVsAll(SVMKLD(args.svmperfpath)), svmperf_params
|
||||
yield 'svmnkld', OneVsAll(SVMNKLD(args.svmperfpath)), svmperf_params
|
||||
|
||||
# methods added
|
||||
yield 'svmmae', OneVsAll(SVMAE(args.svmperfpath)), svmperf_params
|
||||
yield 'svmmrae', OneVsAll(SVMRAE(args.svmperfpath)), svmperf_params
|
||||
yield 'hdy', OneVsAll(HDy(newLR())), lr_params
|
||||
# yield 'cc', CC(newLR()), lr_params
|
||||
# yield 'acc', ACC(newLR()), lr_params
|
||||
# yield 'pcc', PCC(newLR()), lr_params
|
||||
# yield 'pacc', PACC(newLR()), lr_params
|
||||
# yield 'sld', EMQ(newLR()), lr_params
|
||||
# yield 'svmq', OneVsAll(SVMQ(args.svmperfpath)), svmperf_params
|
||||
# yield 'svmkld', OneVsAll(SVMKLD(args.svmperfpath)), svmperf_params
|
||||
# yield 'svmnkld', OneVsAll(SVMNKLD(args.svmperfpath)), svmperf_params
|
||||
#
|
||||
# # methods added
|
||||
# yield 'svmmae', OneVsAll(SVMAE(args.svmperfpath)), svmperf_params
|
||||
# yield 'svmmrae', OneVsAll(SVMRAE(args.svmperfpath)), svmperf_params
|
||||
# yield 'hdy', OneVsAll(HDy(newLR())), lr_params
|
||||
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
print(f'Running QuaNet in {device}')
|
||||
yield 'quanet', QuaNet(PCALR(**newLR().get_params()), settings.SAMPLE_SIZE, tr_iter_per_poch=500, va_iter_per_poch=100, checkpointdir=args.checkpointdir, device=device), lr_params
|
||||
if DEBUG:
|
||||
lr_params={'C':[1,10]}
|
||||
yield 'quanet', QuaNet(PCALR(**newLR().get_params()), settings.SAMPLE_SIZE,
|
||||
lstm_hidden_size=32, lstm_nlayers=1,
|
||||
tr_iter_per_poch=50, va_iter_per_poch=10,
|
||||
patience=3,
|
||||
checkpointdir=args.checkpointdir, device=device), lr_params
|
||||
else:
|
||||
yield 'quanet', QuaNet(PCALR(**newLR().get_params()), settings.SAMPLE_SIZE,
|
||||
patience=5,
|
||||
tr_iter_per_poch=500, va_iter_per_poch=100,
|
||||
checkpointdir=args.checkpointdir, device=device), lr_params
|
||||
|
||||
param_mod_sel={'sample_size':settings.SAMPLE_SIZE, 'n_prevpoints':21, 'n_repetitions':5}
|
||||
#yield 'epaccmaeptr', EPACC(newLR(), param_grid=lr_params, optim='mae', policy='ptr', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None
|
||||
|
@ -123,13 +137,14 @@ def run(experiment):
|
|||
)
|
||||
model_selection.fit(benchmark_devel.training, benchmark_devel.test)
|
||||
model = model_selection.best_model()
|
||||
best_params=model_selection.best_params_
|
||||
best_params = model_selection.best_params_
|
||||
|
||||
# model evaluation
|
||||
test_names = [dataset_name] if dataset_name != 'semeval' else ['semeval13', 'semeval14', 'semeval15']
|
||||
for test_no, test_name in enumerate(test_names):
|
||||
benchmark_eval = qp.datasets.fetch_twitter(test_name, for_model_selection=False, min_df=5, pickle=True)
|
||||
if test_no == 0:
|
||||
print('fitting the selected model')
|
||||
# fits the model only the first time
|
||||
model.fit(benchmark_eval.training)
|
||||
|
||||
|
|
|
@ -60,6 +60,7 @@ def artificial_sampling_prediction(
|
|||
estim_prevalence = quantification_func(sample.instances)
|
||||
return true_prevalence, estim_prevalence
|
||||
|
||||
print('predicting')
|
||||
pbar = tqdm(indexes, desc='[artificial sampling protocol] predicting') if verbose else indexes
|
||||
results = Parallel(n_jobs=n_jobs)(
|
||||
delayed(_predict_prevalences)(index) for index in pbar
|
||||
|
|
|
@ -70,9 +70,22 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
:return: self
|
||||
"""
|
||||
# split: 40% for training classification, 40% for training quapy, and 20% for validating quapy
|
||||
self.learner, unused_data = \
|
||||
training_helper(self.learner, data, fit_learner, ensure_probabilistic=True, val_split=0.6)
|
||||
#self.learner, unused_data = \
|
||||
# training_helper(self.learner, data, fit_learner, ensure_probabilistic=True, val_split=0.6)
|
||||
classifier_data, unused_data = data.split_stratified(0.4)
|
||||
train_data, valid_data = unused_data.split_stratified(0.66) # 0.66 split of 60% makes 40% and 20%
|
||||
self.learner.fit(*classifier_data.Xy)
|
||||
|
||||
# estimate the hard and soft stats tpr and fpr of the classifier
|
||||
self.tr_prev = data.prevalence()
|
||||
|
||||
self.quantifiers = {
|
||||
'cc': CC(self.learner).fit(classifier_data, fit_learner=False),
|
||||
'acc': ACC(self.learner).fit(classifier_data, fit_learner=True),
|
||||
'pcc': PCC(self.learner).fit(classifier_data, fit_learner=False),
|
||||
'pacc': PACC(self.learner).fit(classifier_data, fit_learner=True),
|
||||
'emq': EMQ(self.learner).fit(classifier_data, fit_learner=False),
|
||||
}
|
||||
|
||||
# compute the posterior probabilities of the instances
|
||||
valid_posteriors = self.learner.predict_proba(valid_data.instances)
|
||||
|
@ -82,17 +95,6 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
valid_data.instances = self.learner.transform(valid_data.instances)
|
||||
train_data.instances = self.learner.transform(train_data.instances)
|
||||
|
||||
# estimate the hard and soft stats tpr and fpr of the classifier
|
||||
self.tr_prev = data.prevalence()
|
||||
|
||||
self.quantifiers = {
|
||||
'cc': CC(self.learner).fit(data, fit_learner=False),
|
||||
'acc': ACC(self.learner).fit(data, fit_learner=False),
|
||||
'pcc': PCC(self.learner).fit(data, fit_learner=False),
|
||||
'pacc': PACC(self.learner).fit(data, fit_learner=False),
|
||||
'emq': EMQ(self.learner).fit(data, fit_learner=False),
|
||||
}
|
||||
|
||||
self.status = {
|
||||
'tr-loss': -1,
|
||||
'va-loss': -1,
|
||||
|
@ -124,7 +126,7 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
print(f'training ended by patience exhausted; loading best model parameters in {checkpoint} '
|
||||
f'for epoch {early_stop.best_epoch}')
|
||||
self.quanet.load_state_dict(torch.load(checkpoint))
|
||||
self.epoch(valid_data, valid_posteriors, self.va_iter, epoch_i, early_stop, train=True)
|
||||
#self.epoch(valid_data, valid_posteriors, self.va_iter, epoch_i, early_stop, train=True)
|
||||
break
|
||||
|
||||
return self
|
||||
|
|
Loading…
Reference in New Issue