fixing quanet
This commit is contained in:
parent
e1790b3d9d
commit
99132c8166
|
@ -17,6 +17,9 @@ import torch
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
|
DEBUG = False
|
||||||
|
|
||||||
|
|
||||||
def quantification_models():
|
def quantification_models():
|
||||||
def newLR():
|
def newLR():
|
||||||
return LogisticRegression(max_iter=1000, solver='lbfgs', n_jobs=-1)
|
return LogisticRegression(max_iter=1000, solver='lbfgs', n_jobs=-1)
|
||||||
|
@ -25,23 +28,34 @@ def quantification_models():
|
||||||
svmperf_params = {'C': __C_range}
|
svmperf_params = {'C': __C_range}
|
||||||
|
|
||||||
# methods tested in Gao & Sebastiani 2016
|
# methods tested in Gao & Sebastiani 2016
|
||||||
yield 'cc', CC(newLR()), lr_params
|
# yield 'cc', CC(newLR()), lr_params
|
||||||
yield 'acc', ACC(newLR()), lr_params
|
# yield 'acc', ACC(newLR()), lr_params
|
||||||
yield 'pcc', PCC(newLR()), lr_params
|
# yield 'pcc', PCC(newLR()), lr_params
|
||||||
yield 'pacc', PACC(newLR()), lr_params
|
# yield 'pacc', PACC(newLR()), lr_params
|
||||||
yield 'sld', EMQ(newLR()), lr_params
|
# yield 'sld', EMQ(newLR()), lr_params
|
||||||
yield 'svmq', OneVsAll(SVMQ(args.svmperfpath)), svmperf_params
|
# yield 'svmq', OneVsAll(SVMQ(args.svmperfpath)), svmperf_params
|
||||||
yield 'svmkld', OneVsAll(SVMKLD(args.svmperfpath)), svmperf_params
|
# yield 'svmkld', OneVsAll(SVMKLD(args.svmperfpath)), svmperf_params
|
||||||
yield 'svmnkld', OneVsAll(SVMNKLD(args.svmperfpath)), svmperf_params
|
# yield 'svmnkld', OneVsAll(SVMNKLD(args.svmperfpath)), svmperf_params
|
||||||
|
#
|
||||||
# methods added
|
# # methods added
|
||||||
yield 'svmmae', OneVsAll(SVMAE(args.svmperfpath)), svmperf_params
|
# yield 'svmmae', OneVsAll(SVMAE(args.svmperfpath)), svmperf_params
|
||||||
yield 'svmmrae', OneVsAll(SVMRAE(args.svmperfpath)), svmperf_params
|
# yield 'svmmrae', OneVsAll(SVMRAE(args.svmperfpath)), svmperf_params
|
||||||
yield 'hdy', OneVsAll(HDy(newLR())), lr_params
|
# yield 'hdy', OneVsAll(HDy(newLR())), lr_params
|
||||||
|
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
print(f'Running QuaNet in {device}')
|
print(f'Running QuaNet in {device}')
|
||||||
yield 'quanet', QuaNet(PCALR(**newLR().get_params()), settings.SAMPLE_SIZE, tr_iter_per_poch=500, va_iter_per_poch=100, checkpointdir=args.checkpointdir, device=device), lr_params
|
if DEBUG:
|
||||||
|
lr_params={'C':[1,10]}
|
||||||
|
yield 'quanet', QuaNet(PCALR(**newLR().get_params()), settings.SAMPLE_SIZE,
|
||||||
|
lstm_hidden_size=32, lstm_nlayers=1,
|
||||||
|
tr_iter_per_poch=50, va_iter_per_poch=10,
|
||||||
|
patience=3,
|
||||||
|
checkpointdir=args.checkpointdir, device=device), lr_params
|
||||||
|
else:
|
||||||
|
yield 'quanet', QuaNet(PCALR(**newLR().get_params()), settings.SAMPLE_SIZE,
|
||||||
|
patience=5,
|
||||||
|
tr_iter_per_poch=500, va_iter_per_poch=100,
|
||||||
|
checkpointdir=args.checkpointdir, device=device), lr_params
|
||||||
|
|
||||||
param_mod_sel={'sample_size':settings.SAMPLE_SIZE, 'n_prevpoints':21, 'n_repetitions':5}
|
param_mod_sel={'sample_size':settings.SAMPLE_SIZE, 'n_prevpoints':21, 'n_repetitions':5}
|
||||||
#yield 'epaccmaeptr', EPACC(newLR(), param_grid=lr_params, optim='mae', policy='ptr', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None
|
#yield 'epaccmaeptr', EPACC(newLR(), param_grid=lr_params, optim='mae', policy='ptr', param_mod_sel=param_mod_sel, n_jobs=settings.ENSEMBLE_N_JOBS), None
|
||||||
|
@ -123,13 +137,14 @@ def run(experiment):
|
||||||
)
|
)
|
||||||
model_selection.fit(benchmark_devel.training, benchmark_devel.test)
|
model_selection.fit(benchmark_devel.training, benchmark_devel.test)
|
||||||
model = model_selection.best_model()
|
model = model_selection.best_model()
|
||||||
best_params=model_selection.best_params_
|
best_params = model_selection.best_params_
|
||||||
|
|
||||||
# model evaluation
|
# model evaluation
|
||||||
test_names = [dataset_name] if dataset_name != 'semeval' else ['semeval13', 'semeval14', 'semeval15']
|
test_names = [dataset_name] if dataset_name != 'semeval' else ['semeval13', 'semeval14', 'semeval15']
|
||||||
for test_no, test_name in enumerate(test_names):
|
for test_no, test_name in enumerate(test_names):
|
||||||
benchmark_eval = qp.datasets.fetch_twitter(test_name, for_model_selection=False, min_df=5, pickle=True)
|
benchmark_eval = qp.datasets.fetch_twitter(test_name, for_model_selection=False, min_df=5, pickle=True)
|
||||||
if test_no == 0:
|
if test_no == 0:
|
||||||
|
print('fitting the selected model')
|
||||||
# fits the model only the first time
|
# fits the model only the first time
|
||||||
model.fit(benchmark_eval.training)
|
model.fit(benchmark_eval.training)
|
||||||
|
|
||||||
|
|
|
@ -60,6 +60,7 @@ def artificial_sampling_prediction(
|
||||||
estim_prevalence = quantification_func(sample.instances)
|
estim_prevalence = quantification_func(sample.instances)
|
||||||
return true_prevalence, estim_prevalence
|
return true_prevalence, estim_prevalence
|
||||||
|
|
||||||
|
print('predicting')
|
||||||
pbar = tqdm(indexes, desc='[artificial sampling protocol] predicting') if verbose else indexes
|
pbar = tqdm(indexes, desc='[artificial sampling protocol] predicting') if verbose else indexes
|
||||||
results = Parallel(n_jobs=n_jobs)(
|
results = Parallel(n_jobs=n_jobs)(
|
||||||
delayed(_predict_prevalences)(index) for index in pbar
|
delayed(_predict_prevalences)(index) for index in pbar
|
||||||
|
|
|
@ -70,9 +70,22 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
:return: self
|
:return: self
|
||||||
"""
|
"""
|
||||||
# split: 40% for training classification, 40% for training quapy, and 20% for validating quapy
|
# split: 40% for training classification, 40% for training quapy, and 20% for validating quapy
|
||||||
self.learner, unused_data = \
|
#self.learner, unused_data = \
|
||||||
training_helper(self.learner, data, fit_learner, ensure_probabilistic=True, val_split=0.6)
|
# training_helper(self.learner, data, fit_learner, ensure_probabilistic=True, val_split=0.6)
|
||||||
|
classifier_data, unused_data = data.split_stratified(0.4)
|
||||||
train_data, valid_data = unused_data.split_stratified(0.66) # 0.66 split of 60% makes 40% and 20%
|
train_data, valid_data = unused_data.split_stratified(0.66) # 0.66 split of 60% makes 40% and 20%
|
||||||
|
self.learner.fit(*classifier_data.Xy)
|
||||||
|
|
||||||
|
# estimate the hard and soft stats tpr and fpr of the classifier
|
||||||
|
self.tr_prev = data.prevalence()
|
||||||
|
|
||||||
|
self.quantifiers = {
|
||||||
|
'cc': CC(self.learner).fit(classifier_data, fit_learner=False),
|
||||||
|
'acc': ACC(self.learner).fit(classifier_data, fit_learner=True),
|
||||||
|
'pcc': PCC(self.learner).fit(classifier_data, fit_learner=False),
|
||||||
|
'pacc': PACC(self.learner).fit(classifier_data, fit_learner=True),
|
||||||
|
'emq': EMQ(self.learner).fit(classifier_data, fit_learner=False),
|
||||||
|
}
|
||||||
|
|
||||||
# compute the posterior probabilities of the instances
|
# compute the posterior probabilities of the instances
|
||||||
valid_posteriors = self.learner.predict_proba(valid_data.instances)
|
valid_posteriors = self.learner.predict_proba(valid_data.instances)
|
||||||
|
@ -82,17 +95,6 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
valid_data.instances = self.learner.transform(valid_data.instances)
|
valid_data.instances = self.learner.transform(valid_data.instances)
|
||||||
train_data.instances = self.learner.transform(train_data.instances)
|
train_data.instances = self.learner.transform(train_data.instances)
|
||||||
|
|
||||||
# estimate the hard and soft stats tpr and fpr of the classifier
|
|
||||||
self.tr_prev = data.prevalence()
|
|
||||||
|
|
||||||
self.quantifiers = {
|
|
||||||
'cc': CC(self.learner).fit(data, fit_learner=False),
|
|
||||||
'acc': ACC(self.learner).fit(data, fit_learner=False),
|
|
||||||
'pcc': PCC(self.learner).fit(data, fit_learner=False),
|
|
||||||
'pacc': PACC(self.learner).fit(data, fit_learner=False),
|
|
||||||
'emq': EMQ(self.learner).fit(data, fit_learner=False),
|
|
||||||
}
|
|
||||||
|
|
||||||
self.status = {
|
self.status = {
|
||||||
'tr-loss': -1,
|
'tr-loss': -1,
|
||||||
'va-loss': -1,
|
'va-loss': -1,
|
||||||
|
@ -124,7 +126,7 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
print(f'training ended by patience exhausted; loading best model parameters in {checkpoint} '
|
print(f'training ended by patience exhausted; loading best model parameters in {checkpoint} '
|
||||||
f'for epoch {early_stop.best_epoch}')
|
f'for epoch {early_stop.best_epoch}')
|
||||||
self.quanet.load_state_dict(torch.load(checkpoint))
|
self.quanet.load_state_dict(torch.load(checkpoint))
|
||||||
self.epoch(valid_data, valid_posteriors, self.va_iter, epoch_i, early_stop, train=True)
|
#self.epoch(valid_data, valid_posteriors, self.va_iter, epoch_i, early_stop, train=True)
|
||||||
break
|
break
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
Loading…
Reference in New Issue