Add independent training of the linear classifier layer and a comparison against an SVM classifier
This commit is contained in:
parent
e1047c2beb
commit
2b61722890
2
TODO.txt
2
TODO.txt
|
@ -9,12 +9,14 @@ Recap Feb. 2021:
|
||||||
- Projector trained via SCL + Classifier layer trained alone.
|
- Projector trained via SCL + Classifier layer trained alone.
|
||||||
- Projector trained via SCL + SVM Classifier.
|
- Projector trained via SCL + SVM Classifier.
|
||||||
- Projector trained via KTA + SVM Classifier.
|
- Projector trained via KTA + SVM Classifier.
|
||||||
|
- Comparator or Siamese networks for SAV + Classifier layer.
|
||||||
- Compare (SAV):
|
- Compare (SAV):
|
||||||
- My system (projector+binary-classifier layer)
|
- My system (projector+binary-classifier layer)
|
||||||
- Projector trained via SCL + Binary Classifier layer trained alone.
|
- Projector trained via SCL + Binary Classifier layer trained alone.
|
||||||
- Projector trained via SCL + SVM Classifier.
|
- Projector trained via SCL + SVM Classifier.
|
||||||
- Projector trained via KTA + SVM Classifier.
|
- Projector trained via KTA + SVM Classifier.
|
||||||
- Other systems (maybe Diff-Vectors, maybe Impostors, maybe distance-based)
|
- Other systems (maybe Diff-Vectors, maybe Impostors, maybe distance-based)
|
||||||
|
- Comparator or Siamese networks for SAV.
|
||||||
- Additional experiments:
|
- Additional experiments:
|
||||||
- show the kernel matrix
|
- show the kernel matrix
|
||||||
|
|
||||||
|
|
|
@ -135,9 +135,27 @@ class SupConLoss1View(nn.Module):
|
||||||
upper_diag = torch.triu_indices(batch_size,batch_size,+1)
|
upper_diag = torch.triu_indices(batch_size,batch_size,+1)
|
||||||
cross_upper = cross[upper_diag[0], upper_diag[1]]
|
cross_upper = cross[upper_diag[0], upper_diag[1]]
|
||||||
mask_upper = mask[upper_diag[0], upper_diag[1]]
|
mask_upper = mask[upper_diag[0], upper_diag[1]]
|
||||||
pos = mask_upper.sum()
|
#pos = mask_upper.sum()
|
||||||
# weight = torch.from_numpy(np.asarray([1-pos, pos], dtype=float)).to(device)
|
# weight = torch.from_numpy(np.asarray([1-pos, pos], dtype=float)).to(device)
|
||||||
return torch.nn.functional.binary_cross_entropy_with_logits(cross_upper, mask_upper)
|
#return torch.nn.functional.binary_cross_entropy_with_logits(cross_upper, mask_upper)
|
||||||
|
#print('mask min-max:', mask.min(), mask.max())
|
||||||
|
#print('cross min-max:', cross.min(), cross.max())
|
||||||
|
#return torch.norm(cross-mask, p='fro') # <-- diagonal signal (trivial) should be too strong
|
||||||
|
pos_loss = mse(cross_upper, mask_upper, label=1)
|
||||||
|
neg_loss = mse(cross_upper, mask_upper, label=0)
|
||||||
|
#return neg_loss, pos_loss
|
||||||
|
#balanced_loss = pos_loss + neg_loss
|
||||||
|
#return balanced_loss
|
||||||
|
return torch.mean((cross_upper-mask_upper)**2), neg_loss, pos_loss
|
||||||
|
|
||||||
|
|
||||||
|
def mse(input, target, label):
|
||||||
|
input = input[target==label]
|
||||||
|
if label==0:
|
||||||
|
return torch.mean(input**2)
|
||||||
|
else:
|
||||||
|
return torch.mean((1-input)**2)
|
||||||
|
#return torch.mean((input[index] - target[index]) ** 2)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -153,39 +171,37 @@ class SupConLoss1View(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# # compute logits
|
||||||
|
# anchor_dot_contrast = torch.div(torch.matmul(features, features.T),self.temperature)
|
||||||
# compute logits
|
# # for numerical stability
|
||||||
anchor_dot_contrast = torch.div(torch.matmul(features, features.T),self.temperature)
|
# # logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
|
||||||
# for numerical stability
|
# # logits = anchor_dot_contrast - logits_max.detach()
|
||||||
# logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
|
# logits = anchor_dot_contrast
|
||||||
# logits = anchor_dot_contrast - logits_max.detach()
|
#
|
||||||
logits = anchor_dot_contrast
|
# # mask-out self-contrast cases
|
||||||
|
# # logits_mask = torch.scatter(
|
||||||
# mask-out self-contrast cases
|
# # torch.ones_like(mask),
|
||||||
# logits_mask = torch.scatter(
|
# # 1,
|
||||||
# torch.ones_like(mask),
|
# # torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
|
||||||
# 1,
|
# # 0
|
||||||
# torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
|
# # )
|
||||||
# 0
|
# # mask = mask * logits_mask
|
||||||
# )
|
# logits_mask = torch.ones_like(mask)
|
||||||
# mask = mask * logits_mask
|
# logits_mask.fill_diagonal_(0)
|
||||||
logits_mask = torch.ones_like(mask)
|
# mask.fill_diagonal_(0)
|
||||||
logits_mask.fill_diagonal_(0)
|
#
|
||||||
mask.fill_diagonal_(0)
|
# # compute log_prob
|
||||||
|
# exp_logits = torch.exp(logits) * logits_mask
|
||||||
# compute log_prob
|
# log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
|
||||||
exp_logits = torch.exp(logits) * logits_mask
|
#
|
||||||
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
|
# # compute mean of log-likelihood over positive
|
||||||
|
# div = mask.sum(1)
|
||||||
# compute mean of log-likelihood over positive
|
# div=torch.clamp(div, min=1)
|
||||||
div = mask.sum(1)
|
# mean_log_prob_pos = (mask * log_prob).sum(1) / div
|
||||||
div=torch.clamp(div, min=1)
|
#
|
||||||
mean_log_prob_pos = (mask * log_prob).sum(1) / div
|
# # loss
|
||||||
|
# loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
|
||||||
# loss
|
# # loss = loss.view(anchor_count, batch_size).mean()
|
||||||
loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
|
# loss = loss.view(-1, batch_size).mean()
|
||||||
# loss = loss.view(anchor_count, batch_size).mean()
|
#
|
||||||
loss = loss.view(-1, batch_size).mean()
|
# return loss
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
39
src/main.py
39
src/main.py
|
@ -1,5 +1,8 @@
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from sklearn.model_selection import train_test_split, GridSearchCV
|
||||||
|
from sklearn.svm import LinearSVC
|
||||||
|
|
||||||
from data.AuthorshipDataset import AuthorshipDataset
|
from data.AuthorshipDataset import AuthorshipDataset
|
||||||
from data.fetch_blogs import Blogs
|
from data.fetch_blogs import Blogs
|
||||||
from data.fetch_imdb62 import Imdb62
|
from data.fetch_imdb62 import Imdb62
|
||||||
|
@ -94,22 +97,33 @@ def main(opt):
|
||||||
else:
|
else:
|
||||||
method = opt.name
|
method = opt.name
|
||||||
|
|
||||||
cls.supervised_contrastive_learning(Xtr, ytr,
|
if opt.mode=='savlin':
|
||||||
batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr,
|
Xtr_, Xval_, ytr_, yval_ = train_test_split(Xtr, ytr, test_size=0.1, stratify=ytr)
|
||||||
log=f'{opt.log}/{method}-{dataset_name}.csv',
|
cls.supervised_contrastive_learning(Xtr_, ytr_, Xval_, yval_,
|
||||||
checkpointpath=opt.checkpoint)
|
batch_size=opt.batchsize, epochs=opt.epochs, lr=opt.lr,
|
||||||
|
log=f'{opt.log}/{method}-{dataset_name}.csv',
|
||||||
|
checkpointpath=opt.checkpoint)
|
||||||
|
val_microf1 = cls.train_linear_classifier(Xtr_, ytr_, Xval_, yval_,
|
||||||
|
batch_size=opt.batchsize, epochs=opt.epochs, lr=opt.lr,
|
||||||
|
log=f'{opt.log}/{method}-{dataset_name}.csv',
|
||||||
|
checkpointpath=opt.checkpoint)
|
||||||
|
svm = GridSearchCV(LinearSVC(), param_grid={'C':np.logspace(-2,3,6), 'class_weight':['balanced',None]}, n_jobs=-1)
|
||||||
|
svm.fit(cls.project(Xtr), ytr)
|
||||||
|
yte_ = svm.predict(cls.project(Xte))
|
||||||
|
acc, macrof1, microf1 = evaluation(yte, yte_)
|
||||||
|
print(f'svm: acc={acc:.3f} macrof1={macrof1:.3f} microf1={microf1:.3f}')
|
||||||
|
elif opt.mode=='attr':
|
||||||
|
# train
|
||||||
|
val_microf1 = cls.fit(Xtr, ytr,
|
||||||
|
batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr,
|
||||||
|
log=f'{opt.log}/{method}-{dataset_name}.csv',
|
||||||
|
checkpointpath=opt.checkpoint
|
||||||
|
)
|
||||||
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
# train
|
|
||||||
val_microf1 = cls.fit(Xtr, ytr,
|
|
||||||
batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr,
|
|
||||||
log=f'{opt.log}/{method}-{dataset_name}.csv',
|
|
||||||
checkpointpath=opt.checkpoint
|
|
||||||
)
|
|
||||||
|
|
||||||
# test
|
# test
|
||||||
yte_ = cls.predict(Xte)
|
yte_ = cls.predict(Xte)
|
||||||
|
print('network prediction')
|
||||||
acc, macrof1, microf1 = evaluation(yte, yte_)
|
acc, macrof1, microf1 = evaluation(yte, yte_)
|
||||||
|
|
||||||
results = Results(opt.output)
|
results = Results(opt.output)
|
||||||
|
@ -174,6 +188,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('-n', '--name', help='Name of the model', default='auto')
|
parser.add_argument('-n', '--name', help='Name of the model', default='auto')
|
||||||
requiredNamed = parser.add_argument_group('required named arguments')
|
requiredNamed = parser.add_argument_group('required named arguments')
|
||||||
requiredNamed.add_argument('-d', '--dataset', help='Name of the dataset', required=True, type=str)
|
requiredNamed.add_argument('-d', '--dataset', help='Name of the dataset', required=True, type=str)
|
||||||
|
requiredNamed.add_argument('-m', '--mode', help='training mode', choices=['attr', 'savlin'], required=True, type=str)
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
|
|
||||||
assert opt.dataset in ['enron', 'imdb62', 'blogs', 'victorian'], 'unknown dataset'
|
assert opt.dataset in ['enron', 'imdb62', 'blogs', 'victorian'], 'unknown dataset'
|
||||||
|
|
|
@ -41,7 +41,7 @@ class AuthorshipAttributionClassifier(nn.Module):
|
||||||
with open(log, 'wt') as foo:
|
with open(log, 'wt') as foo:
|
||||||
print()
|
print()
|
||||||
foo.write('epoch\ttr-loss\tval-loss\tval-acc\tval-Mf1\tval-mf1\n')
|
foo.write('epoch\ttr-loss\tval-loss\tval-acc\tval-Mf1\tval-mf1\n')
|
||||||
tr_loss, val_loss = -1, -1
|
tr_loss = val_loss = acc = macrof1 = microf1 = -1
|
||||||
pbar = tqdm(range(1, epochs + 1))
|
pbar = tqdm(range(1, epochs + 1))
|
||||||
for epoch in pbar:
|
for epoch in pbar:
|
||||||
# training
|
# training
|
||||||
|
@ -93,12 +93,12 @@ class AuthorshipAttributionClassifier(nn.Module):
|
||||||
f'loss={tr_loss:.5f} '
|
f'loss={tr_loss:.5f} '
|
||||||
f'attr-loss={np.mean(attr_losses):.5f} '
|
f'attr-loss={np.mean(attr_losses):.5f} '
|
||||||
f'sav-loss={np.mean(sav_losses):.5f} '
|
f'sav-loss={np.mean(sav_losses):.5f} '
|
||||||
f'val_loss={val_loss:.5f} '
|
f'val_loss={val_loss:.5f} val_acc={acc:.4f} macrof1={macrof1:.4f} microf1={microf1:.4f}'
|
||||||
f'patience={early_stop.patience}/{early_stop.patience_limit}')
|
f'patience={early_stop.patience}/{early_stop.patience_limit}')
|
||||||
|
|
||||||
# validation
|
# validation
|
||||||
self.eval()
|
self.eval()
|
||||||
with torch.no_grad:
|
with torch.no_grad():
|
||||||
predictions, losses = [], []
|
predictions, losses = [], []
|
||||||
# for xi, yi in batcher_val.epoch(Xval, yval):
|
# for xi, yi in batcher_val.epoch(Xval, yval):
|
||||||
for xi, yi in val_data.asDataLoader(batch_size, shuffle=False):
|
for xi, yi in val_data.asDataLoader(batch_size, shuffle=False):
|
||||||
|
@ -127,14 +127,11 @@ class AuthorshipAttributionClassifier(nn.Module):
|
||||||
self.load_state_dict(torch.load(checkpointpath))
|
self.load_state_dict(torch.load(checkpointpath))
|
||||||
return early_stop.best_score
|
return early_stop.best_score
|
||||||
|
|
||||||
def supervised_contrastive_learning(self, X, y, batch_size, epochs, patience=10, lr=0.001, val_prop=0.1, alpha=1., log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
|
def supervised_contrastive_learning(self, X, y, Xval, yval, batch_size, epochs, patience=10, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
|
||||||
assert 0 <= alpha <= 1, 'wrong range, alpha must be in [0,1]'
|
|
||||||
early_stop = EarlyStop(patience)
|
early_stop = EarlyStop(patience)
|
||||||
|
|
||||||
criterion = SupConLoss1View().to(self.device)
|
criterion = SupConLoss1View().to(self.device)
|
||||||
optim = torch.optim.Adam(self.parameters(), lr=lr)
|
optim = torch.optim.Adam(self.projector.parameters(), lr=lr)
|
||||||
|
|
||||||
X, Xval, y, yval = train_test_split(X, y, test_size=val_prop, stratify=y)
|
|
||||||
|
|
||||||
tr_data = IndexedDataset(X, y, self.pad_length, self.pad_index, self.device)
|
tr_data = IndexedDataset(X, y, self.pad_length, self.pad_index, self.device)
|
||||||
val_data = IndexedDataset(Xval, yval, self.pad_length, self.pad_index, self.device)
|
val_data = IndexedDataset(Xval, yval, self.pad_length, self.pad_index, self.device)
|
||||||
|
@ -142,53 +139,108 @@ class AuthorshipAttributionClassifier(nn.Module):
|
||||||
with open(log, 'wt') as foo:
|
with open(log, 'wt') as foo:
|
||||||
print()
|
print()
|
||||||
foo.write('epoch\ttr-loss\tval-loss\tval-acc\tval-Mf1\tval-mf1\n')
|
foo.write('epoch\ttr-loss\tval-loss\tval-acc\tval-Mf1\tval-mf1\n')
|
||||||
tr_loss, val_loss = -1, -1
|
tr_loss, val_loss, neg_losses_val, pos_losses_val = -1, -1, -1, -1
|
||||||
pbar = tqdm(range(1, epochs + 1))
|
pbar = tqdm(range(1, epochs + 1))
|
||||||
for epoch in pbar:
|
for epoch in pbar:
|
||||||
# training
|
# training
|
||||||
self.train()
|
self.train()
|
||||||
losses = []
|
losses, pos_losses, neg_losses = [], [], []
|
||||||
for xi, yi in tr_data.asDataLoader(batch_size, shuffle=True):
|
for xi, yi in tr_data.asDataLoader(batch_size, shuffle=True):
|
||||||
|
#while True:
|
||||||
optim.zero_grad()
|
optim.zero_grad()
|
||||||
phi = self.projector(xi)
|
phi = self.projector(xi)
|
||||||
contrastive_loss = criterion(phi, torch.as_tensor(yi).to(self.device))
|
#contrastive_loss = criterion(phi, torch.as_tensor(yi).to(self.device))
|
||||||
|
contrastive_loss, neg_loss, pos_loss = criterion(phi, torch.as_tensor(yi).to(self.device))
|
||||||
|
#contrastive_loss = neg_loss+pos_loss
|
||||||
contrastive_loss.backward()
|
contrastive_loss.backward()
|
||||||
optim.step()
|
optim.step()
|
||||||
losses.append(contrastive_loss.item())
|
losses.append(contrastive_loss.item())
|
||||||
|
neg_losses.append(neg_loss.item())
|
||||||
|
pos_losses.append(pos_loss.item())
|
||||||
tr_loss = np.mean(losses)
|
tr_loss = np.mean(losses)
|
||||||
|
|
||||||
pbar.set_description(f'training epoch={epoch} '
|
pbar.set_description(f'training epoch={epoch} '
|
||||||
f'loss={tr_loss:.5f} '
|
f'loss={tr_loss:.5f} [neg={np.mean(neg_losses):.5f}, pos={np.mean(pos_losses):.5f}] '
|
||||||
f'val_loss={val_loss:.5f} '
|
f'val_loss={val_loss:.5f} [neg={np.mean(neg_losses_val):.5f}, pos={np.mean(pos_losses_val):.5f}] '
|
||||||
f'patience={early_stop.patience}/{early_stop.patience_limit}')
|
f'patience={early_stop.patience}/{early_stop.patience_limit}')
|
||||||
|
|
||||||
# validation
|
# validation
|
||||||
# self.eval()
|
self.eval()
|
||||||
# with torch.no_grad:
|
with torch.no_grad():
|
||||||
# predictions, losses = [], []
|
losses, pos_losses_val, neg_losses_val = [], [], []
|
||||||
# for xi, yi in val_data.asDataLoader(batch_size, shuffle=False):
|
for xi, yi in val_data.asDataLoader(batch_size, shuffle=False):
|
||||||
# phi = self.projector(xi)
|
phi = self.projector(xi)
|
||||||
# contrastive_loss = criterion(phi, torch.as_tensor(yi).to(self.device))
|
contrastive_loss, neg_loss, pos_loss = criterion(phi, torch.as_tensor(yi).to(self.device))
|
||||||
#
|
#contrastive_loss = neg_loss + pos_loss
|
||||||
# logits = self.forward(xi)
|
losses.append(contrastive_loss.item())
|
||||||
# loss = criterion(logits, torch.as_tensor(yi).to(self.device))
|
neg_losses_val.append(neg_loss.item())
|
||||||
# losses.append(loss.item())
|
pos_losses_val.append(pos_loss.item())
|
||||||
# logits = nn.functional.log_softmax(logits, dim=1)
|
val_loss = np.mean(losses)
|
||||||
# prediction = tensor2numpy(torch.argmax(logits, dim=1).view(-1))
|
|
||||||
# predictions.append(prediction)
|
|
||||||
# val_loss = np.mean(losses)
|
|
||||||
# predictions = np.concatenate(predictions)
|
|
||||||
# acc = accuracy_score(yval, predictions)
|
|
||||||
# macrof1 = f1_score(yval, predictions, average='macro')
|
|
||||||
# microf1 = f1_score(yval, predictions, average='micro')
|
|
||||||
#
|
|
||||||
# foo.write(f'{epoch}\t{tr_loss:.8f}\t{val_loss:.8f}\t{acc:.3f}\t{macrof1:.3f}\t{microf1:.3f}\n')
|
|
||||||
# foo.flush()
|
|
||||||
|
|
||||||
# early_stop(microf1, epoch)
|
early_stop(val_loss, epoch)
|
||||||
# if early_stop.IMPROVED:
|
if early_stop.IMPROVED:
|
||||||
# torch.save(self.state_dict(), checkpointpath)
|
torch.save(self.state_dict(), checkpointpath)
|
||||||
# elif early_stop.STOP:
|
elif early_stop.STOP:
|
||||||
# break
|
break
|
||||||
|
print(f'training ended; loading best model parameters in {checkpointpath} for epoch {early_stop.best_epoch}')
|
||||||
|
self.load_state_dict(torch.load(checkpointpath))
|
||||||
|
return early_stop.best_score
|
||||||
|
|
||||||
|
def train_linear_classifier(self, X, y, Xval, yval, batch_size, epochs, patience=10, lr=0.001, log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
|
||||||
|
early_stop = EarlyStop(patience)
|
||||||
|
|
||||||
|
criterion = torch.nn.CrossEntropyLoss().to(self.device)
|
||||||
|
optim = torch.optim.Adam(self.ff.parameters(), lr=lr)
|
||||||
|
|
||||||
|
tr_data = IndexedDataset(X, y, self.pad_length, self.pad_index, self.device)
|
||||||
|
val_data = IndexedDataset(Xval, yval, self.pad_length, self.pad_index, self.device)
|
||||||
|
|
||||||
|
tr_loss = val_loss = acc = macrof1 = microf1 = -1
|
||||||
|
pbar = tqdm(range(1, epochs + 1))
|
||||||
|
for epoch in pbar:
|
||||||
|
# training
|
||||||
|
self.train()
|
||||||
|
losses = []
|
||||||
|
for xi, yi in tr_data.asDataLoader(batch_size, shuffle=True):
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
phi = self.projector(xi)
|
||||||
|
logits = self.ff(phi.detach())
|
||||||
|
|
||||||
|
optim.zero_grad()
|
||||||
|
loss = criterion(logits, torch.as_tensor(yi).to(self.device))
|
||||||
|
loss.backward()
|
||||||
|
optim.step()
|
||||||
|
|
||||||
|
losses.append(loss.item())
|
||||||
|
tr_loss = np.mean(losses)
|
||||||
|
pbar.set_description(f'training epoch={epoch} '
|
||||||
|
f'loss={tr_loss:.5f} '
|
||||||
|
f'val_loss={val_loss:.5f} val_acc={acc:.4f} macrof1={macrof1:.4f} microf1={microf1:.4f}'
|
||||||
|
f'patience={early_stop.patience}/{early_stop.patience_limit}')
|
||||||
|
|
||||||
|
# validation
|
||||||
|
self.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
predictions, losses = [], []
|
||||||
|
for xi, yi in val_data.asDataLoader(batch_size, shuffle=False):
|
||||||
|
logits = self.forward(xi)
|
||||||
|
loss = criterion(logits, torch.as_tensor(yi).to(self.device))
|
||||||
|
losses.append(loss.item())
|
||||||
|
logits = nn.functional.log_softmax(logits, dim=1)
|
||||||
|
prediction = tensor2numpy(torch.argmax(logits, dim=1).view(-1))
|
||||||
|
predictions.append(prediction)
|
||||||
|
val_loss = np.mean(losses)
|
||||||
|
predictions = np.concatenate(predictions)
|
||||||
|
acc = accuracy_score(yval, predictions)
|
||||||
|
macrof1 = f1_score(yval, predictions, average='macro')
|
||||||
|
microf1 = f1_score(yval, predictions, average='micro')
|
||||||
|
|
||||||
|
early_stop(microf1, epoch)
|
||||||
|
if early_stop.IMPROVED:
|
||||||
|
torch.save(self.state_dict(), checkpointpath)
|
||||||
|
elif early_stop.STOP:
|
||||||
|
break
|
||||||
print(f'training ended; loading best model parameters in {checkpointpath} for epoch {early_stop.best_epoch}')
|
print(f'training ended; loading best model parameters in {checkpointpath} for epoch {early_stop.best_epoch}')
|
||||||
self.load_state_dict(torch.load(checkpointpath))
|
self.load_state_dict(torch.load(checkpointpath))
|
||||||
return early_stop.best_score
|
return early_stop.best_score
|
||||||
|
@ -197,14 +249,25 @@ class AuthorshipAttributionClassifier(nn.Module):
|
||||||
self.eval()
|
self.eval()
|
||||||
te_data = IndexedDataset(x, None, self.pad_length, self.pad_index, self.device)
|
te_data = IndexedDataset(x, None, self.pad_length, self.pad_index, self.device)
|
||||||
predictions = []
|
predictions = []
|
||||||
with torch.no_grad:
|
with torch.no_grad():
|
||||||
for xi, yi in te_data.asDataLoader(batch_size, shuffle=False):
|
for xi in te_data.asDataLoader(batch_size, shuffle=False):
|
||||||
logits = self.forward(xi)
|
logits = self.forward(xi)
|
||||||
logits = nn.functional.log_softmax(logits, dim=1)
|
logits = nn.functional.log_softmax(logits, dim=1)
|
||||||
prediction = tensor2numpy(torch.argmax(logits, dim=1).view(-1))
|
prediction = tensor2numpy(torch.argmax(logits, dim=1).view(-1))
|
||||||
predictions.append(prediction)
|
predictions.append(prediction)
|
||||||
return np.concatenate(predictions)
|
return np.concatenate(predictions)
|
||||||
|
|
||||||
|
def project(self, x, batch_size=100):
|
||||||
|
self.eval()
|
||||||
|
te_data = IndexedDataset(x, None, self.pad_length, self.pad_index, self.device)
|
||||||
|
predictions = []
|
||||||
|
with torch.no_grad():
|
||||||
|
for xi in te_data.asDataLoader(batch_size, shuffle=False):
|
||||||
|
phi = tensor2numpy(self.projector(xi))
|
||||||
|
predictions.append(phi)
|
||||||
|
return np.concatenate(predictions)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
phi = self.projector(x)
|
phi = self.projector(x)
|
||||||
return self.ff(phi)
|
return self.ff(phi)
|
||||||
|
|
Loading…
Reference in New Issue