refactoring

parent c0c116fd66
commit 1cd9ec251a

Notes.txt | 11
@@ -3,4 +3,13 @@ a) some better than Ruder's, where there is one extra classification layer (that is
b) some "simplified" ones that are worse than Ruder's because I removed that additional layer
I also saw that it improved with l2(phi(x)), so I have left it that way
Now I am going to try adding that additional layer as the last step in phi(x) <-- running
Then I want to try imposing the regularization on all the layers before the classification...

The l2 thing is a requirement of supervised contrastive learning (SCL)
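A minimal sketch of what this could look like: an optional extra linear layer as the last step of phi(x), followed by l2 normalization. The wrapper below is hypothetical (not the repository's code); it only assumes the base projector exposes output_size, as FFProjection does.

    import torch.nn as nn
    import torch.nn.functional as F


    class NormalizedProjector(nn.Module):
        """Optionally adds one extra linear layer as the last step of phi(x) and
        l2-normalizes the output, as supervised contrastive learning expects."""

        def __init__(self, base_projector, output_size, extra_layer=True):
            super().__init__()
            self.base = base_projector
            self.extra = nn.Linear(base_projector.output_size, output_size) if extra_layer else nn.Identity()
            self.output_size = output_size if extra_layer else base_projector.output_size

        def forward(self, x):
            phi = self.extra(self.base(x))
            return F.normalize(phi, p=2, dim=-1)  # l2(phi(x)): unit-norm embeddings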
The problem in applying SCL is understanding what a "crop" means for text, and in particular in AA. It could simply
be equivalent to a "fragment", i.e., one kind of inductive bias is that a fragment of a text by an author
should have a representation similar to another fragment of the same text. We need to understand well how to generate them,
so that the fragments are characterizing (which probably means imposing a certain minimum length).
We also need to understand how to handle overlaps between fragments.
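A minimal sketch of that reading of "crop" as "fragment" (hypothetical code, not from this repo): a sliding window over a tokenized document, where frag_len imposes the minimum extension and stride controls the overlap between consecutive fragments.

    def fragments(tokens, frag_len=100, stride=50):
        """Slice a tokenized document into fragments of frag_len tokens; stride < frag_len
        yields overlapping fragments, stride == frag_len yields disjoint ones."""
        if len(tokens) <= frag_len:
            return [tokens]
        return [tokens[i:i + frag_len] for i in range(0, len(tokens) - frag_len + 1, stride)]


    # e.g. fragments(list(range(250)), frag_len=100, stride=50) returns 4 overlapping fragments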

One idea for a title would be: "AA is to Classification as SCL is to SAV", or AA = Classif - SCL + SAV
|

TODO.txt | 30

@@ -1,3 +1,31 @@
Recap Feb. 2021:
- Adapt everything to testing a classic neural training for AA (i.e., projector+classifier training) vs. applying Supervised
  Contrastive Learning (SCL) as a pretraining step for solving SAV, and then training a linear classifier with
  the projector network frozen. Reassess the work in terms of SAV and make connections with KTA and SVM. Maybe claim
  that SCL+SVM is the way to go (see the sketch after this list).
- Compare (Attribution):
    - S. Ruder's systems
    - My system (projector+classifier layer) as a reimplementation of S. Ruder's systems
    - Projector trained via SCL + Classifier layer trained alone.
    - Projector trained via SCL + SVM Classifier.
    - Projector trained via KTA + SVM Classifier.
- Compare (SAV):
    - My system (projector+binary-classifier layer)
    - Projector trained via SCL + Binary Classifier layer trained alone.
    - Projector trained via SCL + SVM Classifier.
    - Projector trained via KTA + SVM Classifier.
    - Other systems (maybe Diff-Vectors, maybe Impostors, maybe distance-based)
- Additional experiments:
    - show the kernel matrix
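Below, the sketch referred to in the first item: a minimal, hypothetical SCL+SVM pipeline (not code from this repo) that freezes the SCL-pretrained projector and fits a linear SVM on the projected features. It assumes a DataLoader yielding (xi, yi) batches as in classifiers.py, and uses sklearn's LinearSVC as the SVM step.

    import numpy as np
    import torch
    from sklearn.svm import LinearSVC
    from sklearn.metrics import f1_score


    @torch.no_grad()
    def extract_features(projector, data_loader):
        """Run the frozen projector over all batches; returns (n,d) features and labels."""
        projector.eval()
        feats, labels = [], []
        for xi, yi in data_loader:
            feats.append(projector(xi).cpu().numpy())
            labels.append(yi.cpu().numpy())
        return np.concatenate(feats), np.concatenate(labels)


    def scl_then_svm(cls, tr_loader, te_loader):
        """Assumes cls.supervised_contrastive_learning(...) has already been run;
        freezes phi and fits a linear SVM on top of the projected training set."""
        Ztr, ytr = extract_features(cls.projector, tr_loader)
        Zte, yte = extract_features(cls.projector, te_loader)
        svm = LinearSVC().fit(Ztr, ytr)
        return f1_score(yte, svm.predict(Zte), average='macro')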

Future:
- Test also in general TC? there are some torch datasets in torchtext that could simplify things... but that would
  blur the idea of SCL-SAV

Code:
- redo dataset in terms of pytorch's data_loader

---------------------
Things to clarify:

about the network:

@@ -23,4 +51,6 @@ maybe I have to review the validation of the sav-loss; since it is batched, it m
SAV: how should the range of k(xi,xj) be interpreted? how to decide on a threshold value for returning -1 or +1?
I guess the best thing to do is to learn a simple threshold, one feed-forward 1-to-1
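A minimal sketch of that 1-to-1 feed-forward threshold (hypothetical, not repository code): a single nn.Linear(1, 1) maps the raw kernel score k(xi,xj) to a same-author logit, so its sign gives the -1/+1 decision and the bias plays the role of the learned threshold.

    import torch.nn as nn


    class SAVThreshold(nn.Module):
        """One feed-forward 1-to-1 layer: learns a scale and a threshold (bias) over k(xi,xj)."""

        def __init__(self):
            super().__init__()
            self.ff = nn.Linear(1, 1)

        def forward(self, k):
            # k: tensor of shape (n,) holding inner products <phi(xi), phi(xj)>
            return self.ff(k.view(-1, 1)).view(-1)  # logits; sign > 0 means "same author"


    # training would minimize a binary loss against the same-author labels, e.g.:
    # loss = nn.BCEWithLogitsLoss()(SAVThreshold()(k_scores), same_author.float())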

plot the kernel matrix as an imshow, with rows/cols arranged by authors, and check whether the KTA that SCL yields
is better than that obtained using a traditional training for attribution.
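A possible sketch of that check (hypothetical helper, not repository code), assuming phi is an (n, d) numpy array of l2-normalized embeddings and y the author labels: it shows K with documents grouped by author and returns the kernel-target alignment between K and the ideal kernel.

    import numpy as np
    import matplotlib.pyplot as plt


    def plot_kernel_by_author(phi, y):
        """Show K = phi phi^T with rows/cols grouped by author and return the KTA score."""
        order = np.argsort(y)                                        # group same-author documents together
        K = phi[order] @ phi[order].T                                # kernel matrix
        Y = (y[order][:, None] == y[order][None, :]).astype(float)   # ideal kernel
        kta = (K * Y).sum() / (np.linalg.norm(K) * np.linalg.norm(Y))  # kernel-target alignment
        plt.imshow(K, cmap='viridis')
        plt.title(f'KTA = {kta:.3f}')
        plt.colorbar()
        plt.show()
        return kta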

@@ -1,14 +0,0 @@
#!/bin/bash
conda activate torch

dataset=enron
for authors in 10 50 ; do
  for alpha in 1 0.999 0.99 0.9 0.5 ; do
    python main.py --dataset $dataset -A $authors -s 0 -o ../results_$dataset.csv --alpha $alpha
  done
done

dataset=imdb62
for alpha in 1 0.999 0.99 0.9 0.5 ; do
  python main.py --dataset $dataset -A -1 -s 0 -o ../results_$dataset.csv --alpha $alpha
done
@@ -9,17 +9,13 @@ import pickle
|
|||
class LabelledCorpus:
|
||||
|
||||
def __init__(self, documents, labels):
|
||||
if not isinstance(documents, np.ndarray): documents = np.asarray(documents, dtype=str)
|
||||
if not isinstance(labels, np.ndarray): labels = np.asarray(labels)
|
||||
if not isinstance(documents, np.ndarray):
|
||||
documents = np.asarray(documents, dtype=object) #dtype=str occupies too much in memory and is not needed
|
||||
if not isinstance(labels, np.ndarray):
|
||||
labels = np.asarray(labels)
|
||||
self.data = documents
|
||||
self.target = labels
|
||||
|
||||
def _tolist(self):
|
||||
self.data = self.data.tolist()
|
||||
|
||||
def _toarray(self):
|
||||
self.data = np.asarray(self.data, dtype=str)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
@@ -41,17 +37,11 @@ class AuthorshipDataset(ABC):
|
|||
if pickle_path and os.path.exists(pickle_path):
|
||||
print(f'loading dataset image in {pickle_path}')
|
||||
dataset = pickle.load(open(pickle_path, 'rb'))
|
||||
dataset.train._toarray()
|
||||
dataset.test._toarray()
|
||||
else:
|
||||
dataset = loader(**kwargs)
|
||||
if pickle_path:
|
||||
print(f'dumping dataset in {pickle_path} for faster load')
|
||||
dataset.train._tolist()
|
||||
dataset.test._tolist()
|
||||
pickle.dump(dataset, open(pickle_path, 'wb'), pickle.HIGHEST_PROTOCOL)
|
||||
dataset.train._toarray()
|
||||
dataset.test._toarray()
|
||||
return dataset
|
||||
|
||||
def __init__(self, data_path, n_authors=-1, docs_by_author=-1, n_open_set_authors=0, random_state=42):
|
||||
|
@@ -62,13 +52,9 @@ class AuthorshipDataset(ABC):
|
|||
np.random.seed(random_state)
|
||||
|
||||
self._check_n_authors(n_authors, n_open_set_authors)
|
||||
|
||||
self.train, self.test, self.target_names = self._fetch_and_split()
|
||||
|
||||
self._assure_docs_by_author(docs_by_author)
|
||||
|
||||
self._reduce_authors_documents(n_authors, docs_by_author, n_open_set_authors)
|
||||
|
||||
self._remove_label_gaps()
|
||||
|
||||
super().__init__()
|
||||
|
|
|
@@ -18,7 +18,7 @@ class Imdb62(AuthorshipDataset):
|
|||
def _fetch_and_split(self):
|
||||
file = open(self.data_path,'rt', encoding= "utf-8").readlines()
|
||||
splits = [line.split('\t') for line in file]
|
||||
reviews = np.asarray([split[4]+' '+split[5] for split in splits])
|
||||
reviews = [split[4]+' '+split[5] for split in splits]
|
||||
|
||||
authors=[]
|
||||
authors_ids = dict()
|
||||
|
|
|
@@ -19,7 +19,6 @@ class Victorian(AuthorshipDataset):
|
|||
csv_reader = csv.reader(file, delimiter = ',')
|
||||
next(csv_reader)
|
||||
for row in csv_reader:
|
||||
# if row[0]!='text':
|
||||
data.append(row[0])
|
||||
labels.append(int(row[1]))
|
||||
|
||||
|
|

src/main.py | 50
@@ -5,7 +5,7 @@ from data.fetch_blogs import Blogs
|
|||
from data.fetch_imdb62 import Imdb62
|
||||
from data.fetch_enron_mail import EnronMail
|
||||
from index import Index
|
||||
from model.classifiers import AuthorshipAttributionClassifier, SameAuthorClassifier, FullAuthorClassifier
|
||||
from model.classifiers import AuthorshipAttributionClassifier #, SameAuthorClassifier, FullAuthorClassifier
|
||||
from data.fetch_victorian import Victorian
|
||||
from evaluation import evaluation
|
||||
import torch
|
||||
|
@@ -16,11 +16,7 @@ import os
|
|||
import sys
|
||||
|
||||
|
||||
def main(opt):
|
||||
|
||||
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
print(f'running on {device}')
|
||||
|
||||
def load_dataset(opt):
|
||||
# dataset load
|
||||
if opt.dataset == 'enron':
|
||||
loader = EnronMail
|
||||
|
@@ -39,13 +35,24 @@ def main(opt):
|
|||
pickle_path = None
|
||||
if opt.pickle:
|
||||
pickle_path = f'{opt.pickle}/{dataset_name}.pickle'
|
||||
dataset = AuthorshipDataset.load(loader,
|
||||
pickle_path=pickle_path,
|
||||
data_path=data_path,
|
||||
n_authors=opt.authors,
|
||||
docs_by_author=opt.documents,
|
||||
random_state=opt.seed
|
||||
)
|
||||
dataset = AuthorshipDataset.load(
|
||||
loader,
|
||||
pickle_path=pickle_path,
|
||||
data_path=data_path,
|
||||
n_authors=opt.authors,
|
||||
docs_by_author=opt.documents,
|
||||
random_state=opt.seed
|
||||
)
|
||||
return dataset_name, dataset
|
||||
|
||||
|
||||
|
||||
def main(opt):
|
||||
|
||||
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
print(f'running on {device}')
|
||||
|
||||
dataset_name, dataset = load_dataset(opt)
|
||||
|
||||
# dataset indexing
|
||||
Xtr, ytr = dataset.train.data, dataset.train.target
|
||||
|
@@ -61,12 +68,6 @@ def main(opt):
|
|||
pad_index = index.add_word('PADTOKEN')
|
||||
print(f'vocabulary size={index.vocabulary_size()}')
|
||||
|
||||
#shuffle1 = np.random.permutation(Xte.shape[0])
|
||||
#shuffle2 = np.random.permutation(Xte.shape[0])
|
||||
#x1, y1 = Xte[shuffle1], yte[shuffle1]
|
||||
#x2, y2 = Xte[shuffle2], yte[shuffle2]
|
||||
#paired_y = y1==y2
|
||||
|
||||
# attribution
|
||||
print('Attribution')
|
||||
phi = Phi(
|
||||
|
@@ -93,12 +94,19 @@ def main(opt):
|
|||
else:
|
||||
method = opt.name
|
||||
|
||||
cls.supervised_contrastive_learning(Xtr, ytr,
|
||||
batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr,
|
||||
log=f'{opt.log}/{method}-{dataset_name}.csv',
|
||||
checkpointpath=opt.checkpoint)
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
# train
|
||||
val_microf1 = cls.fit(Xtr, ytr,
|
||||
batch_size=opt.batchsize, epochs=opt.epochs, alpha=opt.alpha, lr=opt.lr,
|
||||
log=f'{opt.log}/{method}-{dataset_name}.csv',
|
||||
checkpointpath=opt.checkpoint
|
||||
)
|
||||
)
|
||||
|
||||
# test
|
||||
yte_ = cls.predict(Xte)
|
||||
|
@@ -154,7 +162,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('-e', '--epochs', help='Max number of epochs', type=int, default=250)
|
||||
parser.add_argument('-A', '--authors', help='Number of authors (-1 to select all)', type=int, default=-1)
|
||||
parser.add_argument('-D', '--documents', help='Number of documents per author (-1 to select all)', type=int, default=-1)
|
||||
parser.add_argument('-s', '--seed', help='Random seed', type=int, default=-1)
|
||||
parser.add_argument('-s', '--seed', help='Random seed', type=int, default=0)
|
||||
parser.add_argument('-o', '--output', help='File where to write test results', default='../results.csv')
|
||||
parser.add_argument('-l', '--log', help='Log dir where to output training and validation losses', default='../log')
|
||||
parser.add_argument('-P', '--pickle', help='If specified, pickles a copy of the dataset for faster reload. '
|
||||
|
|
|
@@ -6,8 +6,11 @@ from sklearn.metrics import accuracy_score, f1_score
|
|||
from tqdm import tqdm
|
||||
import math
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
from losses import SupConLoss1View
|
||||
from model.early_stop import EarlyStop
|
||||
from model.layers import FFProjection
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
class AuthorshipAttributionClassifier(nn.Module):
|
||||
|
@@ -17,33 +20,35 @@ class AuthorshipAttributionClassifier(nn.Module):
|
|||
self.ff = FFProjection(input_size=projector.output_size,
|
||||
hidden_sizes=[],
|
||||
output_size=num_authors).to(device)
|
||||
self.padder = Padding(pad_index=pad_index, max_length=pad_length, dynamic=True, pad_at_end=False, device=device)
|
||||
self.pad_index = pad_index
|
||||
self.pad_length = pad_length
|
||||
self.device = device
|
||||
|
||||
def fit(self, X, y, batch_size, epochs, patience=10, lr=0.001, val_prop=0.1, alpha=1., log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
|
||||
assert 0 <= alpha <= 1, 'wrong range, alpha must be in [0,1]'
|
||||
early_stop = EarlyStop(patience)
|
||||
batcher = Batch(batch_size=batch_size, n_epochs=epochs)
|
||||
|
||||
#batcher = TwoClassBatch(batch_size=batch_size, n_epochs=epochs, steps_per_epoch=X.shape[0]//batch_size)
|
||||
batcher_val = Batch(batch_size=batch_size, n_epochs=epochs, shuffle=False)
|
||||
criterion = torch.nn.CrossEntropyLoss().to(self.device)
|
||||
savcriterion = torch.nn.BCEWithLogitsLoss().to(self.device)
|
||||
optim = torch.optim.Adam(self.parameters(), lr=lr)
|
||||
|
||||
X, Xval, y, yval = train_test_split(X, y, test_size=val_prop, stratify=y)
|
||||
|
||||
tr_data = IndexedDataset(X, y, self.pad_length, self.pad_index, self.device)
|
||||
val_data = IndexedDataset(Xval, yval, self.pad_length, self.pad_index, self.device)
|
||||
|
||||
with open(log, 'wt') as foo:
|
||||
print()
|
||||
foo.write('epoch\ttr-loss\tval-loss\tval-acc\tval-Mf1\tval-mf1\n')
|
||||
tr_loss, val_loss = -1, -1
|
||||
pbar = tqdm(range(1, batcher.n_epochs+1))
|
||||
pbar = tqdm(range(1, epochs + 1))
|
||||
for epoch in pbar:
|
||||
# training
|
||||
self.train()
|
||||
losses, attr_losses, sav_losses = [], [], []
|
||||
for xi, yi in batcher.epoch(X, y):
|
||||
for xi, yi in tr_data.asDataLoader(batch_size, shuffle=True):
|
||||
optim.zero_grad()
|
||||
xi = self.padder.transform(xi)
|
||||
phi = self.projector(xi)
|
||||
|
||||
loss_attr = loss_sav = 0
|
||||
|
@@ -93,23 +98,25 @@ class AuthorshipAttributionClassifier(nn.Module):
|
|||
|
||||
# validation
|
||||
self.eval()
|
||||
predictions, losses = [], []
|
||||
for xi, yi in batcher_val.epoch(Xval, yval):
|
||||
xi = self.padder.transform(xi)
|
||||
logits = self.forward(xi)
|
||||
loss = criterion(logits, torch.as_tensor(yi).to(self.device))
|
||||
losses.append(loss.item())
|
||||
logits = nn.functional.log_softmax(logits, dim=1)
|
||||
prediction = tensor2numpy(torch.argmax(logits, dim=1).view(-1))
|
||||
predictions.append(prediction)
|
||||
val_loss = np.mean(losses)
|
||||
predictions = np.concatenate(predictions)
|
||||
acc = accuracy_score(yval, predictions)
|
||||
macrof1 = f1_score(yval, predictions, average='macro')
|
||||
microf1 = f1_score(yval, predictions, average='micro')
|
||||
with torch.no_grad():
|
||||
predictions, losses = [], []
|
||||
# for xi, yi in batcher_val.epoch(Xval, yval):
|
||||
for xi, yi in val_data.asDataLoader(batch_size, shuffle=False):
|
||||
# xi = self.padder.transform(xi)
|
||||
logits = self.forward(xi)
|
||||
loss = criterion(logits, torch.as_tensor(yi).to(self.device))
|
||||
losses.append(loss.item())
|
||||
logits = nn.functional.log_softmax(logits, dim=1)
|
||||
prediction = tensor2numpy(torch.argmax(logits, dim=1).view(-1))
|
||||
predictions.append(prediction)
|
||||
val_loss = np.mean(losses)
|
||||
predictions = np.concatenate(predictions)
|
||||
acc = accuracy_score(yval, predictions)
|
||||
macrof1 = f1_score(yval, predictions, average='macro')
|
||||
microf1 = f1_score(yval, predictions, average='micro')
|
||||
|
||||
foo.write(f'{epoch}\t{tr_loss:.8f}\t{val_loss:.8f}\t{acc:.3f}\t{macrof1:.3f}\t{microf1:.3f}\n')
|
||||
foo.flush()
|
||||
foo.write(f'{epoch}\t{tr_loss:.8f}\t{val_loss:.8f}\t{acc:.3f}\t{macrof1:.3f}\t{microf1:.3f}\n')
|
||||
foo.flush()
|
||||
|
||||
early_stop(microf1, epoch)
|
||||
if early_stop.IMPROVED:
|
||||
|
@@ -120,16 +127,82 @@ class AuthorshipAttributionClassifier(nn.Module):
|
|||
self.load_state_dict(torch.load(checkpointpath))
|
||||
return early_stop.best_score
|
||||
|
||||
def supervised_contrastive_learning(self, X, y, batch_size, epochs, patience=10, lr=0.001, val_prop=0.1, alpha=1., log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
|
||||
assert 0 <= alpha <= 1, 'wrong range, alpha must be in [0,1]'
|
||||
early_stop = EarlyStop(patience)
|
||||
|
||||
criterion = SupConLoss1View().to(self.device)
|
||||
optim = torch.optim.Adam(self.parameters(), lr=lr)
|
||||
|
||||
X, Xval, y, yval = train_test_split(X, y, test_size=val_prop, stratify=y)
|
||||
|
||||
tr_data = IndexedDataset(X, y, self.pad_length, self.pad_index, self.device)
|
||||
val_data = IndexedDataset(Xval, yval, self.pad_length, self.pad_index, self.device)
|
||||
|
||||
with open(log, 'wt') as foo:
|
||||
print()
|
||||
foo.write('epoch\ttr-loss\tval-loss\tval-acc\tval-Mf1\tval-mf1\n')
|
||||
tr_loss, val_loss = -1, -1
|
||||
pbar = tqdm(range(1, epochs + 1))
|
||||
for epoch in pbar:
|
||||
# training
|
||||
self.train()
|
||||
losses = []
|
||||
for xi, yi in tr_data.asDataLoader(batch_size, shuffle=True):
|
||||
optim.zero_grad()
|
||||
phi = self.projector(xi)
|
||||
contrastive_loss = criterion(phi, torch.as_tensor(yi).to(self.device))
|
||||
contrastive_loss.backward()
|
||||
optim.step()
|
||||
losses.append(contrastive_loss.item())
|
||||
tr_loss = np.mean(losses)
|
||||
pbar.set_description(f'training epoch={epoch} '
|
||||
f'loss={tr_loss:.5f} '
|
||||
f'val_loss={val_loss:.5f} '
|
||||
f'patience={early_stop.patience}/{early_stop.patience_limit}')
|
||||
|
||||
# validation
|
||||
# self.eval()
|
||||
# with torch.no_grad():
|
||||
# predictions, losses = [], []
|
||||
# for xi, yi in val_data.asDataLoader(batch_size, shuffle=False):
|
||||
# phi = self.projector(xi)
|
||||
# contrastive_loss = criterion(phi, torch.as_tensor(yi).to(self.device))
|
||||
#
|
||||
# logits = self.forward(xi)
|
||||
# loss = criterion(logits, torch.as_tensor(yi).to(self.device))
|
||||
# losses.append(loss.item())
|
||||
# logits = nn.functional.log_softmax(logits, dim=1)
|
||||
# prediction = tensor2numpy(torch.argmax(logits, dim=1).view(-1))
|
||||
# predictions.append(prediction)
|
||||
# val_loss = np.mean(losses)
|
||||
# predictions = np.concatenate(predictions)
|
||||
# acc = accuracy_score(yval, predictions)
|
||||
# macrof1 = f1_score(yval, predictions, average='macro')
|
||||
# microf1 = f1_score(yval, predictions, average='micro')
|
||||
#
|
||||
# foo.write(f'{epoch}\t{tr_loss:.8f}\t{val_loss:.8f}\t{acc:.3f}\t{macrof1:.3f}\t{microf1:.3f}\n')
|
||||
# foo.flush()
|
||||
|
||||
# early_stop(microf1, epoch)
|
||||
# if early_stop.IMPROVED:
|
||||
# torch.save(self.state_dict(), checkpointpath)
|
||||
# elif early_stop.STOP:
|
||||
# break
|
||||
print(f'training ended; loading best model parameters in {checkpointpath} for epoch {early_stop.best_epoch}')
|
||||
self.load_state_dict(torch.load(checkpointpath))
|
||||
return early_stop.best_score
|
||||
|
||||
def predict(self, x, batch_size=100):
|
||||
self.eval()
|
||||
batcher = Batch(batch_size=batch_size, n_epochs=1, shuffle=False)
|
||||
te_data = IndexedDataset(x, None, self.pad_length, self.pad_index, self.device)
|
||||
predictions = []
|
||||
for xi in tqdm(batcher.epoch(x), desc='test'):
|
||||
xi = self.padder.transform(xi)
|
||||
logits = self.forward(xi)
|
||||
logits = nn.functional.log_softmax(logits, dim=1)
|
||||
prediction = tensor2numpy(torch.argmax(logits, dim=1).view(-1))
|
||||
predictions.append(prediction)
|
||||
with torch.no_grad():
|
||||
for xi, yi in te_data.asDataLoader(batch_size, shuffle=False):
|
||||
logits = self.forward(xi)
|
||||
logits = nn.functional.log_softmax(logits, dim=1)
|
||||
prediction = tensor2numpy(torch.argmax(logits, dim=1).view(-1))
|
||||
predictions.append(prediction)
|
||||
return np.concatenate(predictions)
|
||||
|
||||
def forward(self, x):
|
||||
|
@@ -168,134 +241,133 @@ def choose_sav_pairs(y, npairs):
|
|||
|
||||
|
||||
|
||||
class SameAuthorClassifier(nn.Module):
|
||||
def __init__(self, projector, num_authors, pad_index, pad_length=500, device='cpu'):
|
||||
super(SameAuthorClassifier, self).__init__()
|
||||
self.projector = projector.to(device)
|
||||
self.padder = Padding(pad_index=pad_index, max_length=pad_length, dynamic=True, pad_at_end=False, device=device)
|
||||
self.device = device
|
||||
|
||||
def fit(self, X, y, batch_size, epochs, lr=0.001, steps_per_epoch=100):
|
||||
self.train()
|
||||
batcher = TwoClassBatch(batch_size=batch_size, n_epochs=epochs, steps_per_epoch=steps_per_epoch)
|
||||
optim = torch.optim.Adam(self.parameters(), lr=lr)
|
||||
|
||||
pbar = tqdm(range(batcher.n_epochs))
|
||||
for epoch in pbar:
|
||||
losses = []
|
||||
for xi, yi in batcher.epoch(X, y):
|
||||
optim.zero_grad()
|
||||
xi = self.padder.transform(xi)
|
||||
phi = self.projector(xi)
|
||||
#normalize phi to have norm 1? maybe better as the last step of projector
|
||||
kernel = torch.matmul(phi, phi.T)
|
||||
ideal_kernel = torch.as_tensor(1 * (np.outer(1 + yi, 1 / (yi + 1)) == 1)).to(self.device)
|
||||
loss = KernelAlignmentLoss(kernel, ideal_kernel)
|
||||
loss.backward()
|
||||
#clip_gradient(model)
|
||||
optim.step()
|
||||
losses.append(loss.item())
|
||||
pbar.set_description(f'training epoch={epoch} loss={np.mean(losses):.5f}')
|
||||
|
||||
def predict(self, x, z, batch_size=100):
|
||||
self.eval()
|
||||
batcher = Batch(batch_size=batch_size, n_epochs=1, shuffle=False)
|
||||
predictions = []
|
||||
for xi, zi in tqdm(batcher.epoch(x, z), desc='test'):
|
||||
xi = self.padder.transform(xi)
|
||||
zi = self.padder.transform(zi)
|
||||
inners = self.forward(xi, zi)
|
||||
prediction = tensor2numpy(inners) > 0.5 # is this correct? should it be > 0 and the ideal kernel in field {-1,+1}?
|
||||
predictions.append(prediction)
|
||||
return np.concatenate(predictions)
|
||||
|
||||
def forward(self, x, z):
|
||||
assert x.shape == z.shape, 'shape mismatch between matrices x and z'
|
||||
phi_x = self.projector(x)
|
||||
phi_z = self.projector(z)
|
||||
rows, cols = phi_x.shape
|
||||
pairwise_inners = torch.bmm(phi_x.view(rows, 1, cols), phi_z.view(rows, cols, 1)).squeeze()
|
||||
return pairwise_inners
|
||||
# class SameAuthorClassifier(nn.Module):
|
||||
# def __init__(self, projector, num_authors, pad_index, pad_length=500, device='cpu'):
|
||||
# super(SameAuthorClassifier, self).__init__()
|
||||
# self.projector = projector.to(device)
|
||||
# self.padder = Padding(pad_index=pad_index, max_length=pad_length, dynamic=True, pad_at_end=False, device=device)
|
||||
# self.device = device
|
||||
#
|
||||
# def fit(self, X, y, batch_size, epochs, lr=0.001, steps_per_epoch=100):
|
||||
# self.train()
|
||||
# batcher = TwoClassBatch(batch_size=batch_size, n_epochs=epochs, steps_per_epoch=steps_per_epoch)
|
||||
# optim = torch.optim.Adam(self.parameters(), lr=lr)
|
||||
#
|
||||
# pbar = tqdm(range(batcher.n_epochs))
|
||||
# for epoch in pbar:
|
||||
# losses = []
|
||||
# for xi, yi in batcher.epoch(X, y):
|
||||
# optim.zero_grad()
|
||||
# xi = self.padder.transform(xi)
|
||||
# phi = self.projector(xi)
|
||||
# #normalize phi to have norm 1? maybe better as the last step of projector
|
||||
# kernel = torch.matmul(phi, phi.T)
|
||||
# ideal_kernel = torch.as_tensor(1 * (np.outer(1 + yi, 1 / (yi + 1)) == 1)).to(self.device)
|
||||
# loss = KernelAlignmentLoss(kernel, ideal_kernel)
|
||||
# loss.backward()
|
||||
# #clip_gradient(model)
|
||||
# optim.step()
|
||||
# losses.append(loss.item())
|
||||
# pbar.set_description(f'training epoch={epoch} loss={np.mean(losses):.5f}')
|
||||
#
|
||||
# def predict(self, x, z, batch_size=100):
|
||||
# self.eval()
|
||||
# batcher = Batch(batch_size=batch_size, n_epochs=1, shuffle=False)
|
||||
# predictions = []
|
||||
# for xi, zi in tqdm(batcher.epoch(x, z), desc='test'):
|
||||
# xi = self.padder.transform(xi)
|
||||
# zi = self.padder.transform(zi)
|
||||
# inners = self.forward(xi, zi)
|
||||
# prediction = tensor2numpy(inners) > 0.5 # is this correct? should it be > 0 and the ideal kernel in field {-1,+1}?
|
||||
# predictions.append(prediction)
|
||||
# return np.concatenate(predictions)
|
||||
#
|
||||
# def forward(self, x, z):
|
||||
# assert x.shape == z.shape, 'shape mismatch between matrices x and z'
|
||||
# phi_x = self.projector(x)
|
||||
# phi_z = self.projector(z)
|
||||
# rows, cols = phi_x.shape
|
||||
# pairwise_inners = torch.bmm(phi_x.view(rows, 1, cols), phi_z.view(rows, cols, 1)).squeeze()
|
||||
# return pairwise_inners
|
||||
|
||||
|
||||
class FullAuthorClassifier(nn.Module):
|
||||
def __init__(self, projector, num_authors, pad_index, pad_length=500, device='cpu'):
|
||||
super(FullAuthorClassifier, self).__init__()
|
||||
self.projector = projector.to(device)
|
||||
self.ff = FFProjection(input_size=projector.space_dimensions(),
|
||||
hidden_sizes=[1024],
|
||||
output_size=num_authors).to(device)
|
||||
self.padder = Padding(pad_index=pad_index, max_length=pad_length, dynamic=True, pad_at_end=False, device=device)
|
||||
self.device = device
|
||||
|
||||
def fit(self, X, y, batch_size, epochs, lr=0.001, steps_per_epoch=100):
|
||||
self.train()
|
||||
batcher = TwoClassBatch(batch_size=batch_size, n_epochs=epochs, steps_per_epoch=steps_per_epoch)
|
||||
criterion = torch.nn.CrossEntropyLoss().to(self.device)
|
||||
optim = torch.optim.Adam(self.parameters(), lr=lr)
|
||||
alpha = 0.5
|
||||
|
||||
pbar = tqdm(range(batcher.n_epochs))
|
||||
for epoch in pbar:
|
||||
losses, sav_losses, attr_losses = [], [], []
|
||||
for xi, yi in batcher.epoch(X, y):
|
||||
optim.zero_grad()
|
||||
xi = self.padder.transform(xi)
|
||||
phi = self.projector(xi)
|
||||
#normalize phi to have norm 1? maybe better as the last step of projector
|
||||
|
||||
#sav-loss
|
||||
kernel = torch.matmul(phi, phi.T)
|
||||
ideal_kernel = torch.as_tensor(1 * (np.outer(1 + yi, 1 / (yi + 1)) == 1)).to(self.device)
|
||||
sav_loss = KernelAlignmentLoss(kernel, ideal_kernel)
|
||||
sav_losses.append(sav_loss.item())
|
||||
|
||||
#attr-loss
|
||||
logits = self.ff(phi)
|
||||
attr_loss = criterion(logits, torch.as_tensor(yi).to(self.device))
|
||||
attr_losses.append(attr_loss.item())
|
||||
|
||||
#loss
|
||||
loss = (alpha)*sav_loss + (1-alpha)*attr_loss
|
||||
losses.append(loss.item())
|
||||
|
||||
loss.backward()
|
||||
#clip_gradient(model)
|
||||
optim.step()
|
||||
pbar.set_description(
|
||||
f'training epoch={epoch} '
|
||||
f'sav-loss={np.mean(sav_losses):.5f} '
|
||||
f'attr-loss={np.mean(attr_losses):.5f} '
|
||||
f'loss={np.mean(losses):.5f}'
|
||||
)
|
||||
|
||||
def predict_sav(self, x, z, batch_size=100):
|
||||
self.eval()
|
||||
batcher = Batch(batch_size=batch_size, n_epochs=1, shuffle=False)
|
||||
predictions = []
|
||||
for xi, zi in tqdm(batcher.epoch(x, z), desc='test'):
|
||||
xi = self.padder.transform(xi)
|
||||
zi = self.padder.transform(zi)
|
||||
phi_xi = self.projector(xi)
|
||||
phi_zi = self.projector(zi)
|
||||
rows, cols = phi_xi.shape
|
||||
pairwise_inners = torch.bmm(phi_xi.view(rows, 1, cols), phi_zi.view(rows, cols, 1)).squeeze()
|
||||
prediction = tensor2numpy(pairwise_inners) > 0.5 # is this correct? should it be > 0 and the ideal kernel in field {-1,+1}?
|
||||
predictions.append(prediction)
|
||||
return np.concatenate(predictions)
|
||||
|
||||
def predict_labels(self, x, batch_size=100):
|
||||
self.eval()
|
||||
batcher = Batch(batch_size=batch_size, n_epochs=1, shuffle=False)
|
||||
predictions = []
|
||||
for xi in tqdm(batcher.epoch(x), desc='test'):
|
||||
xi = self.padder.transform(xi)
|
||||
phi = self.projector(xi)
|
||||
logits = self.ff(phi)
|
||||
prediction = tensor2numpy( torch.argmax(logits, dim=1).view(-1))
|
||||
predictions.append(prediction)
|
||||
return np.concatenate(predictions)
|
||||
|
||||
# class FullAuthorClassifier(nn.Module):
|
||||
# def __init__(self, projector, num_authors, pad_index, pad_length=500, device='cpu'):
|
||||
# super(FullAuthorClassifier, self).__init__()
|
||||
# self.projector = projector.to(device)
|
||||
# self.ff = FFProjection(input_size=projector.space_dimensions(),
|
||||
# hidden_sizes=[1024],
|
||||
# output_size=num_authors).to(device)
|
||||
# self.padder = Padding(pad_index=pad_index, max_length=pad_length, dynamic=True, pad_at_end=False, device=device)
|
||||
# self.device = device
|
||||
#
|
||||
# def fit(self, X, y, batch_size, epochs, lr=0.001, steps_per_epoch=100):
|
||||
# self.train()
|
||||
# batcher = TwoClassBatch(batch_size=batch_size, n_epochs=epochs, steps_per_epoch=steps_per_epoch)
|
||||
# criterion = torch.nn.CrossEntropyLoss().to(self.device)
|
||||
# optim = torch.optim.Adam(self.parameters(), lr=lr)
|
||||
# alpha = 0.5
|
||||
#
|
||||
# pbar = tqdm(range(batcher.n_epochs))
|
||||
# for epoch in pbar:
|
||||
# losses, sav_losses, attr_losses = [], [], []
|
||||
# for xi, yi in batcher.epoch(X, y):
|
||||
# optim.zero_grad()
|
||||
# xi = self.padder.transform(xi)
|
||||
# phi = self.projector(xi)
|
||||
# #normalize phi to have norm 1? maybe better as the last step of projector
|
||||
#
|
||||
# #sav-loss
|
||||
# kernel = torch.matmul(phi, phi.T)
|
||||
# ideal_kernel = torch.as_tensor(1 * (np.outer(1 + yi, 1 / (yi + 1)) == 1)).to(self.device)
|
||||
# sav_loss = KernelAlignmentLoss(kernel, ideal_kernel)
|
||||
# sav_losses.append(sav_loss.item())
|
||||
#
|
||||
# #attr-loss
|
||||
# logits = self.ff(phi)
|
||||
# attr_loss = criterion(logits, torch.as_tensor(yi).to(self.device))
|
||||
# attr_losses.append(attr_loss.item())
|
||||
#
|
||||
# #loss
|
||||
# loss = (alpha)*sav_loss + (1-alpha)*attr_loss
|
||||
# losses.append(loss.item())
|
||||
#
|
||||
# loss.backward()
|
||||
# #clip_gradient(model)
|
||||
# optim.step()
|
||||
# pbar.set_description(
|
||||
# f'training epoch={epoch} '
|
||||
# f'sav-loss={np.mean(sav_losses):.5f} '
|
||||
# f'attr-loss={np.mean(attr_losses):.5f} '
|
||||
# f'loss={np.mean(losses):.5f}'
|
||||
# )
|
||||
#
|
||||
# def predict_sav(self, x, z, batch_size=100):
|
||||
# self.eval()
|
||||
# batcher = Batch(batch_size=batch_size, n_epochs=1, shuffle=False)
|
||||
# predictions = []
|
||||
# for xi, zi in tqdm(batcher.epoch(x, z), desc='test'):
|
||||
# xi = self.padder.transform(xi)
|
||||
# zi = self.padder.transform(zi)
|
||||
# phi_xi = self.projector(xi)
|
||||
# phi_zi = self.projector(zi)
|
||||
# rows, cols = phi_xi.shape
|
||||
# pairwise_inners = torch.bmm(phi_xi.view(rows, 1, cols), phi_zi.view(rows, cols, 1)).squeeze()
|
||||
# prediction = tensor2numpy(pairwise_inners) > 0.5 # is this correct? should it be > 0 and the ideal kernel in field {-1,+1}?
|
||||
# predictions.append(prediction)
|
||||
# return np.concatenate(predictions)
|
||||
#
|
||||
# def predict_labels(self, x, batch_size=100):
|
||||
# self.eval()
|
||||
# batcher = Batch(batch_size=batch_size, n_epochs=1, shuffle=False)
|
||||
# predictions = []
|
||||
# for xi in tqdm(batcher.epoch(x), desc='test'):
|
||||
# xi = self.padder.transform(xi)
|
||||
# phi = self.projector(xi)
|
||||
# logits = self.ff(phi)
|
||||
# prediction = tensor2numpy( torch.argmax(logits, dim=1).view(-1))
|
||||
# predictions.append(prediction)
|
||||
# return np.concatenate(predictions)
|
||||
|
||||
#def KernelAlignmentLoss(K, Y):
|
||||
# n_el = K.shape[0]*K.shape[1]
|
||||
|
@@ -304,92 +376,89 @@ class FullAuthorClassifier(nn.Module):
|
|||
# return loss
|
||||
|
||||
|
||||
|
||||
class Batch:
|
||||
def __init__(self, batch_size, n_epochs=1, shuffle=True):
|
||||
self.batch_size = batch_size
|
||||
self.n_epochs = n_epochs
|
||||
self.shuffle = shuffle
|
||||
self.current_epoch = 0
|
||||
|
||||
def epoch(self, *args):
|
||||
lengths = list(map(len, args))
|
||||
assert max(lengths) == min(lengths), 'inconsistent sizes in args'
|
||||
n_batches = math.ceil(lengths[0] / self.batch_size)
|
||||
offset = 0
|
||||
if self.shuffle:
|
||||
index = np.random.permutation(len(args[0]))
|
||||
args = [arg[index] for arg in args]
|
||||
for b in range(n_batches):
|
||||
batch_idx = slice(offset, offset+self.batch_size)
|
||||
batch = [arg[batch_idx] for arg in args]
|
||||
yield batch if len(batch) > 1 else batch[0]
|
||||
offset += self.batch_size
|
||||
self.current_epoch += 1
|
||||
|
||||
|
||||
class TwoClassBatch:
|
||||
"""
|
||||
given a X and y (multi-label) produces batches of elements of X, y for two classes (e.g., c1, c2)
|
||||
of equal size, i.e., the batch is [(x1,c1), ..., (xn,c1), (xn+1,c2), ..., (x2n,c2)]
|
||||
"""
|
||||
def __init__(self, batch_size, n_epochs, steps_per_epoch):
|
||||
self.batch_size = batch_size
|
||||
self.n_epochs = n_epochs
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
self.current_epoch = 0
|
||||
if self.batch_size % 2 != 0:
|
||||
raise ValueError('warning, batch size is not even')
|
||||
|
||||
def epoch(self, X, y):
|
||||
n_el = len(y)
|
||||
assert X.shape[0] == n_el, 'inconsistent sizes in X, y'
|
||||
classes = np.unique(y)
|
||||
groups = {ci: X[y==ci] for ci in classes}
|
||||
class_prevalences = [len(groups[ci])/n_el for ci in classes]
|
||||
n_choices = self.batch_size // 2
|
||||
|
||||
for b in range(self.steps_per_epoch):
|
||||
class1, class2 = np.random.choice(classes, p=class_prevalences, size=2, replace=False)
|
||||
X1 = np.random.choice(groups[class1], size=n_choices)
|
||||
X2 = np.random.choice(groups[class2], size=n_choices)
|
||||
X_batch = np.concatenate([X1,X2])
|
||||
y_batch = np.repeat([class1, class2], repeats=[n_choices,n_choices])
|
||||
yield X_batch, y_batch
|
||||
self.current_epoch += 1
|
||||
|
||||
|
||||
class Padding:
|
||||
def __init__(self, pad_index, max_length, dynamic=True, pad_at_end=True, device='cpu'):
|
||||
"""
|
||||
:param pad_index: the index representing the PAD token
|
||||
:param max_length: the length that defines the padding
|
||||
:param dynamic: if True (default) pads at min(max_length, max_local_length) where max_local_length is the
|
||||
length of the longest example
|
||||
:param pad_at_end: if True, the pad tokens are added at the end of the lists, if otherwise they are added
|
||||
at the beginning
|
||||
"""
|
||||
self.pad = pad_index
|
||||
self.max_length = max_length
|
||||
self.dynamic = dynamic
|
||||
self.pad_at_end = pad_at_end
|
||||
self.device = device
|
||||
|
||||
def transform(self, X):
|
||||
"""
|
||||
:param X: a list of lists of indexes (integers)
|
||||
:return: a ndarray of shape (n,m) where n is the number of elements in X and m is the pad length (the maximum
|
||||
in elements of X if dynamic, or self.max_length if otherwise)
|
||||
"""
|
||||
X = [x[:self.max_length] for x in X]
|
||||
lengths = list(map(len, X))
|
||||
pad_length = min(max(lengths), self.max_length) if self.dynamic else self.max_length
|
||||
if self.pad_at_end:
|
||||
padded = [x + [self.pad] * (pad_length - x_len) for x, x_len in zip(X, lengths)]
|
||||
else:
|
||||
padded = [[self.pad] * (pad_length - x_len) + x for x, x_len in zip(X, lengths)]
|
||||
return torch.from_numpy(np.asarray(padded, dtype=int)).to(self.device)
|
||||
# class TwoClassBatch:
|
||||
# """
|
||||
# given a X and y (multi-label) produces batches of elements of X, y for two classes (e.g., c1, c2)
|
||||
# of equal size, i.e., the batch is [(x1,c1), ..., (xn,c1), (xn+1,c2), ..., (x2n,c2)]
|
||||
# """
|
||||
# def __init__(self, batch_size, n_epochs, steps_per_epoch):
|
||||
# self.batch_size = batch_size
|
||||
# self.n_epochs = n_epochs
|
||||
# self.steps_per_epoch = steps_per_epoch
|
||||
# self.current_epoch = 0
|
||||
# if self.batch_size % 2 != 0:
|
||||
# raise ValueError('warning, batch size is not even')
|
||||
#
|
||||
# def epoch(self, X, y):
|
||||
# n_el = len(y)
|
||||
# assert X.shape[0] == n_el, 'inconsistent sizes in X, y'
|
||||
# classes = np.unique(y)
|
||||
# groups = {ci: X[y==ci] for ci in classes}
|
||||
# class_prevalences = [len(groups[ci])/n_el for ci in classes]
|
||||
# n_choices = self.batch_size // 2
|
||||
#
|
||||
# for b in range(self.steps_per_epoch):
|
||||
# class1, class2 = np.random.choice(classes, p=class_prevalences, size=2, replace=False)
|
||||
# X1 = np.random.choice(groups[class1], size=n_choices)
|
||||
# X2 = np.random.choice(groups[class2], size=n_choices)
|
||||
# X_batch = np.concatenate([X1,X2])
|
||||
# y_batch = np.repeat([class1, class2], repeats=[n_choices,n_choices])
|
||||
# yield X_batch, y_batch
|
||||
# self.current_epoch += 1
|
||||
|
||||
|
||||
def tensor2numpy(t):
|
||||
return t.to('cpu').detach().numpy()
|
||||
return t.to('cpu').detach().numpy()
|
||||
|
||||
|
||||
# ------------
|
||||
|
||||
class IndexedDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, X, y, MAX_LENGTH, padindex, device, pad_at_end=False):
|
||||
self.X = X
|
||||
self.y = y
|
||||
self.MAX_LENGTH = MAX_LENGTH
|
||||
self.padindex = padindex
|
||||
self.device = device
|
||||
self.pad_at_end = pad_at_end
|
||||
|
||||
def __len__(self):
|
||||
return len(self.X)
|
||||
|
||||
@property
|
||||
def islabelled(self):
|
||||
return self.y is not None
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.islabelled:
|
||||
return self.X[index], self.y[index]
|
||||
else:
|
||||
return self.X[index]
|
||||
|
||||
def collate_pad_fn(self, batch):
|
||||
"""
|
||||
:param batch: a list of lists of indexes (integers)
|
||||
:return: a torch.tensor of shape (n,m) where n is the number of elements in X_batch and m is the pad length
|
||||
(the maximum in elements of X_batch)
|
||||
"""
|
||||
if self.islabelled:
|
||||
X, y = list(zip(*batch))
|
||||
else:
|
||||
X = batch
|
||||
lengths = list(map(len, X))
|
||||
pad_length = min(max(lengths), self.MAX_LENGTH)
|
||||
X = [x[:pad_length] for x in X]
|
||||
if self.pad_at_end:
|
||||
padded = [x + [self.padindex] * (pad_length - x_len) for x, x_len in zip(X, lengths)]
|
||||
else:
|
||||
padded = [[self.padindex] * (pad_length - x_len) + x for x, x_len in zip(X, lengths)]
|
||||
|
||||
X = torch.from_numpy(np.asarray(padded, dtype=int)).to(self.device)
|
||||
if self.islabelled:
|
||||
y = torch.from_numpy(np.asarray(y)).to(self.device)
|
||||
return X, y
|
||||
else:
|
||||
return X
|
||||
|
||||
def asDataLoader(self, batch_size, shuffle):
|
||||
return torch.utils.data.DataLoader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self.collate_pad_fn)
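A short usage sketch of this new DataLoader-based path (illustrative only; Xtr, ytr, pad_index, device and cls are assumed to be defined as in src/main.py):

    # collate_pad_fn pads each batch dynamically and moves it to the target device
    tr_data = IndexedDataset(Xtr, ytr, MAX_LENGTH=500, padindex=pad_index, device=device)
    for xi, yi in tr_data.asDataLoader(batch_size=64, shuffle=True):
        logits = cls.forward(xi)  # xi: LongTensor of shape (batch, pad_length)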
|
||||
|
|
|
@@ -70,36 +70,36 @@ class FFProjection(nn.Module):
|
|||
|
||||
|
||||
# deprecated
|
||||
class RNNProjection(nn.Module):
|
||||
def __init__(self, vocab_size, hidden_size, output_size, device='cpu'):
|
||||
super(RNNProjection, self).__init__()
|
||||
self.output_size = output_size
|
||||
self.hidden_size = hidden_size
|
||||
self.vocab_size = vocab_size
|
||||
self.num_layers=1
|
||||
self.num_directions=1
|
||||
self.device = device
|
||||
|
||||
self.embedding = nn.Embedding(vocab_size, hidden_size).to(device)
|
||||
self.rnn = nn.GRU(
|
||||
input_size=hidden_size,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
bidirectional=(self.num_directions == 2),
|
||||
batch_first=True
|
||||
).to(device)
|
||||
self.projection = nn.Linear(self.num_layers * self.num_directions * self.hidden_size, output_size).to(device)
|
||||
|
||||
def init_hidden(self, batch_size):
|
||||
return torch.zeros(self.num_layers * self.num_directions, batch_size, self.hidden_size).to(self.device)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size = x.shape[0]
|
||||
x = self.embedding(x)
|
||||
output, hn = self.rnn(x, self.init_hidden(batch_size))
|
||||
hn = hn.view(self.num_layers, self.num_directions, batch_size, self.hidden_size)
|
||||
hn = hn.permute(2, 0, 1, 3).reshape(batch_size, -1)
|
||||
return self.projection(hn)
|
||||
|
||||
def space_dimensions(self):
|
||||
return self.output_size
|
||||
# class RNNProjection(nn.Module):
|
||||
# def __init__(self, vocab_size, hidden_size, output_size, device='cpu'):
|
||||
# super(RNNProjection, self).__init__()
|
||||
# self.output_size = output_size
|
||||
# self.hidden_size = hidden_size
|
||||
# self.vocab_size = vocab_size
|
||||
# self.num_layers=1
|
||||
# self.num_directions=1
|
||||
# self.device = device
|
||||
#
|
||||
# self.embedding = nn.Embedding(vocab_size, hidden_size).to(device)
|
||||
# self.rnn = nn.GRU(
|
||||
# input_size=hidden_size,
|
||||
# hidden_size=hidden_size,
|
||||
# num_layers=self.num_layers,
|
||||
# bidirectional=(self.num_directions == 2),
|
||||
# batch_first=True
|
||||
# ).to(device)
|
||||
# self.projection = nn.Linear(self.num_layers * self.num_directions * self.hidden_size, output_size).to(device)
|
||||
#
|
||||
# def init_hidden(self, batch_size):
|
||||
# return torch.zeros(self.num_layers * self.num_directions, batch_size, self.hidden_size).to(self.device)
|
||||
#
|
||||
# def forward(self, x):
|
||||
# batch_size = x.shape[0]
|
||||
# x = self.embedding(x)
|
||||
# output, hn = self.rnn(x, self.init_hidden(batch_size))
|
||||
# hn = hn.view(self.num_layers, self.num_directions, batch_size, self.hidden_size)
|
||||
# hn = hn.permute(2, 0, 1, 3).reshape(batch_size, -1)
|
||||
# return self.projection(hn)
|
||||
#
|
||||
# def space_dimensions(self):
|
||||
# return self.output_size
|