commit a3732cff1e (parent cc49ffd152)

testing kta

src/main.py (20 lines changed)
--- a/src/main.py
+++ b/src/main.py
@@ -9,7 +9,6 @@ import torch
 from model.transformations import CNNProjection
 import sys
 
-
 hidden_size=32
 channels_out=128
 output_size=1024
@@ -18,12 +17,19 @@ pad_length=3000
 batch_size=50
 n_epochs=256
 bigrams=False
+n_authors=-1
+docs_by_author=-1
 
-#hidden_size=16
-#output_size=32
-#pad_length=100
-#batch_size=10
-#n_epochs=20
+debug=False
+if debug:
+    print(('*'*20)+' DEBUG MODE ' + ('*'*20))
+    hidden_size=16
+    output_size=32
+    pad_length=100
+    batch_size=10
+    n_epochs=20
+    n_authors = 5
+    docs_by_author = 10
 
 if torch.cuda.is_available():
     device = torch.device('cuda')
@@ -32,7 +38,7 @@ else:
 print(f'running on {device}')
 
 #dataset = Victorian(data_path='../../authorship_analysis/data/victoria', n_authors=5, docs_by_author=25)
-dataset = Imdb62(data_path='../../authorship_analysis/data/imdb62/imdb62.txt', n_authors=-1, docs_by_author=-1)
+dataset = Imdb62(data_path='../../authorship_analysis/data/imdb62/imdb62.txt', n_authors=n_authors, docs_by_author=docs_by_author)
 Xtr, ytr = dataset.train.data, dataset.train.target
 Xte, yte = dataset.test.data, dataset.test.target
 A = np.unique(ytr)

@@ -18,8 +18,8 @@ class AuthorshipAttributionClassifier(nn.Module):
         self.device = device
 
     def fit(self, X, y, batch_size, epochs, lr=0.001, val_prop=0.1, log='../log/tmp.csv'):
-        #batcher = Batch(batch_size=batch_size, n_epochs=epochs)
-        batcher = TwoClassBatch(batch_size=batch_size, n_epochs=epochs, steps_per_epoch=X.shape[0]//batch_size)
+        batcher = Batch(batch_size=batch_size, n_epochs=epochs)
+        #batcher = TwoClassBatch(batch_size=batch_size, n_epochs=epochs, steps_per_epoch=X.shape[0]//batch_size)
         batcher_val = Batch(batch_size=batch_size, n_epochs=epochs, shuffle=False)
         criterion = torch.nn.CrossEntropyLoss().to(self.device)
         optim = torch.optim.Adam(self.parameters(), lr=lr)
@@ -33,17 +33,32 @@ class AuthorshipAttributionClassifier(nn.Module):
         for epoch in pbar:
             # training
             self.train()
-            losses = []
+            losses, attr_losses, sav_losses = [], [], []
             for xi, yi in batcher.epoch(X, y):
                 optim.zero_grad()
                 xi = self.padder.transform(xi)
-                logits = self.forward(xi)
-                loss = criterion(logits, torch.as_tensor(yi).to(self.device))
+                phi = self.projector(xi)
+
+                logits = self.ff(phi)
+                loss_attr = criterion(logits, torch.as_tensor(yi).to(self.device))
+
+                kernel = torch.matmul(phi, phi.T)
+                ideal_kernel = torch.as_tensor(1 * (np.outer(1 + yi, 1 / (yi + 1)) == 1)).to(self.device)
+                loss_sav = KernelAlignmentLoss(kernel, ideal_kernel)
+
+                loss = loss_attr + loss_sav
+
                 loss.backward()
                 optim.step()
+
+                attr_losses.append(loss_attr.item())
+                sav_losses.append(loss_sav.item())
                 losses.append(loss.item())
             tr_loss = np.mean(losses)
-            pbar.set_description(f'training epoch={epoch} loss={tr_loss:.5f} val_loss={val_loss:.5f}')
+            pbar.set_description(f'training epoch={epoch} '
+                                 f'loss={tr_loss:.5f} '
+                                 f'attr-loss={np.mean(attr_losses):.5f} '
+                                 f'sav-loss={np.mean(sav_losses):.5f} val_loss={val_loss:.5f}')
 
             # validation
             self.eval()
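
The ideal_kernel line in the new training loop is terser than it looks: np.outer(1 + yi, 1 / (yi + 1)) has entry (1 + yi[i]) / (1 + yi[j]), which equals 1 exactly when yi[i] == yi[j], so comparing against 1 yields the same-author indicator matrix for the batch. A quick check with made-up labels (not from the repo):

    import numpy as np

    yi = np.array([0, 1, 1, 2])  # hypothetical batch labels
    ideal = 1 * (np.outer(1 + yi, 1 / (yi + 1)) == 1)
    print(ideal)
    # [[1 0 0 0]
    #  [0 1 1 0]
    #  [0 1 1 0]
    #  [0 0 0 1]]

Note that this relies on exact floating-point equality. It holds for small integer labels like these, but an integer comparison such as 1 * (yi[:, None] == yi[None, :]) builds the same matrix without that fragility.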
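KernelAlignmentLoss itself is not part of this diff, so the sketch below is an assumption rather than the repo's definition. Given the call site (a batch kernel phi @ phi.T scored against a 0/1 ideal kernel) and the commit message "testing kta", a natural reading is the kernel-target alignment of Cristianini et al. (2001) turned into a loss:

    import torch

    # Hedged sketch of KernelAlignmentLoss, assuming it returns 1 minus the
    # normalized Frobenius inner product between the batch kernel K and the
    # ideal same-author kernel K_ideal.
    def KernelAlignmentLoss(K, K_ideal):
        K_ideal = K_ideal.float()                        # 0/1 indicator -> float
        inner = (K * K_ideal).sum()                      # <K, K*>_F
        norms = K.norm(p='fro') * K_ideal.norm(p='fro')  # ||K||_F * ||K*||_F
        return 1 - inner / norms                         # 0 at perfect alignment

Whatever the exact definition, the structural change in fit() is the same: the forward pass is split so that the projection phi feeds both the classification head (loss_attr) and the batch kernel (loss_sav), and the summed loss backpropagates through both objectives.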