# gfun_multimodal/gfun/vgfs/commons.py

import os
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers.modeling_outputs import SequenceClassifierOutput

from evaluation.evaluate import evaluate, log_eval

PRINT_ON_EPOCH = 10


def _normalize(lX, l2=True):
    """L2-normalize each language's document matrix row-wise (no-op if l2=False)."""
    return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX


def XdotM(X, M, sif):
    """Project the document matrix X onto the embedding matrix M; optionally
    remove the first principal component (SIF-style post-processing)."""
    E = X.dot(M)
    if sif:
        E = remove_pc(E, npc=1)
    return E
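# Example (illustrative; variable names and shapes are assumptions): with a
# (n_docs x vocab) tf-idf matrix and a (vocab x embed_dim) word-embedding
# matrix, XdotM returns weighted-sum document embeddings:
#   doc_embeds = XdotM(X_tfidf, M_embeddings, sif=True)   # -> (n_docs x embed_dim)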


def remove_pc(X, npc=1):
    """
    Remove the projection on the principal components
    :param X: X[i,:] is a data point
    :param npc: number of principal components to remove
    :return: XX[i, :] is the data point after removing its projection
    """
    pc = compute_pc(X, npc)
    if npc == 1:
        XX = X - X.dot(pc.transpose()) * pc
    else:
        XX = X - X.dot(pc.transpose()).dot(pc)
    return XX


def compute_pc(X, npc=1):
    """
    Compute the principal components.
    :param X: X[i,:] is a data point
    :param npc: number of principal components to remove
    :return: component_[i,:] is the i-th pc
    """
    if isinstance(X, np.matrix):
        X = np.asarray(X)
    svd = TruncatedSVD(n_components=npc, n_iter=7, random_state=0)
    svd.fit(X)
    return svd.components_
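# Example (toy data, illustrative): the removed direction is the top singular
# vector, so the cleaned matrix is (numerically) orthogonal to it:
#   X = np.random.rand(100, 300)
#   pc = compute_pc(X, npc=1)          # shape (1, 300)
#   X_clean = remove_pc(X, npc=1)      # shape (100, 300)
#   np.allclose(X_clean.dot(pc.T), 0)  # ~True up to numerical error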


def predict(logits, classification_type="multilabel"):
    """
    Convert soft predictions (logits) to hard {0, 1} predictions.
    """
    if classification_type == "multilabel":
        prediction = torch.sigmoid(logits) > 0.5
    elif classification_type == "singlelabel":
        prediction = torch.argmax(logits, dim=1).view(-1, 1)
    else:
        raise ValueError(f"Unknown classification type: {classification_type}")
    return prediction.detach().cpu().numpy()
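# Example (illustrative): multilabel thresholds the sigmoid at 0.5, singlelabel
# returns the argmax as a column vector:
#   predict(torch.tensor([[2.0, -1.0]]))                 # -> [[ True, False]]
#   predict(torch.tensor([[2.0, -1.0]]), "singlelabel")  # -> [[0]]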


class TfidfVectorizerMultilingual:
    """Thin wrapper fitting one independent TfidfVectorizer per language."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def fit(self, lX, ly=None):
        self.langs = sorted(lX.keys())
        self.vectorizer = {
            l: TfidfVectorizer(**self.kwargs).fit(lX[l]) for l in self.langs
        }
        return self

    def transform(self, lX):
        return {l: self.vectorizer[l].transform(lX[l]) for l in self.langs}

    def fit_transform(self, lX, ly=None):
        return self.fit(lX, ly).transform(lX)

    def vocabulary(self, l=None):
        if l is None:
            return {l: self.vectorizer[l].vocabulary_ for l in self.langs}
        else:
            return self.vectorizer[l].vocabulary_

    def get_analyzer(self, l=None):
        if l is None:
            return {l: self.vectorizer[l].build_analyzer() for l in self.langs}
        else:
            return self.vectorizer[l].build_analyzer()
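# Example usage (illustrative corpus): vocabularies and idf statistics are kept
# language-specific, one vectorizer per language:
#   vec = TfidfVectorizerMultilingual(sublinear_tf=True, min_df=1)
#   lX = vec.fit_transform({"en": ["hello world"], "it": ["ciao mondo"]})
#   lX["en"].shape   # (1, |V_en|)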


class Trainer:
    def __init__(
        self,
        model,
        optimizer_name,
        device,
        loss_fn,
        lr,
        print_steps,
        evaluate_step,
        patience,
        experiment_name,
        checkpoint_path,
    ):
        self.device = device
        self.model = model.to(device)
        self.optimizer = self.init_optimizer(optimizer_name, lr)
        self.evaluate_steps = evaluate_step
        self.loss_fn = loss_fn.to(device)
        self.print_steps = print_steps
        self.experiment_name = experiment_name
        self.patience = patience
        self.earlystopping = EarlyStopping(
            patience=patience,
            checkpoint_path=checkpoint_path,
            verbose=True,
            experiment_name=experiment_name,
        )

    def init_optimizer(self, optimizer_name, lr):
        if optimizer_name.lower() == "adamw":
            return AdamW(self.model.parameters(), lr=lr)
        else:
            raise ValueError(f"Optimizer {optimizer_name} not supported")
    def train(self, train_dataloader, eval_dataloader, epochs=10):
        print(
            f"""- Training params for {self.experiment_name}:
            - epochs: {epochs}
            - learning rate: {self.optimizer.defaults['lr']}
            - train batch size: {train_dataloader.batch_size}
            - eval batch size: {eval_dataloader.batch_size}
            - max len: {train_dataloader.dataset.X.shape[-1]}
            - patience: {self.earlystopping.patience}\n"""
        )
        for epoch in range(epochs):
            self.train_epoch(train_dataloader, epoch)
            if (epoch + 1) % self.evaluate_steps == 0:
                metric_watcher = self.evaluate(eval_dataloader)
                stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
                if stop:
                    print(
                        f"- restoring best model from epoch {self.earlystopping.best_epoch}"
                    )
                    self.model = self.earlystopping.load_model(self.model).to(
                        self.device
                    )
                    break
        # Final pass ("last swipe") over the evaluation set before saving.
        # TODO: maybe a lower lr?
        print("\n- last swipe on eval set")
        self.train_epoch(eval_dataloader, epoch=epoch)
        self.earlystopping.save_model(self.model)
        return self.model
    def train_epoch(self, dataloader, epoch):
        self.model.train()
        for b_idx, (x, y, lang) in enumerate(dataloader):
            self.optimizer.zero_grad()
            y_hat = self.model(x.to(self.device))
            if isinstance(y_hat, SequenceClassifierOutput):
                loss = self.loss_fn(y_hat.logits, y.to(self.device))
            else:
                loss = self.loss_fn(y_hat, y.to(self.device))
            loss.backward()
            self.optimizer.step()
            if (epoch + 1) % PRINT_ON_EPOCH == 0:
                if b_idx % self.print_steps == 0:
                    print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
        return self
    def evaluate(self, dataloader):
        self.model.eval()
        lY = defaultdict(list)
        lY_hat = defaultdict(list)
        with torch.no_grad():
            for b_idx, (x, y, lang) in enumerate(dataloader):
                y_hat = self.model(x.to(self.device))
                if isinstance(y_hat, SequenceClassifierOutput):
                    loss = self.loss_fn(y_hat.logits, y.to(self.device))
                    predictions = predict(y_hat.logits, classification_type="multilabel")
                else:
                    loss = self.loss_fn(y_hat, y.to(self.device))
                    predictions = predict(y_hat, classification_type="multilabel")
                for l, _true, _pred in zip(lang, y, predictions):
                    lY[l].append(_true.detach().cpu().numpy())
                    lY_hat[l].append(_pred)
        for lang in lY:
            lY[lang] = np.vstack(lY[lang])
            lY_hat[lang] = np.vstack(lY_hat[lang])
        l_eval = evaluate(lY, lY_hat)
        average_metrics = log_eval(l_eval, phase="validation")
        return average_metrics[0]  # macro-F1
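# Example wiring (illustrative; names and hyper-parameters are assumptions): the
# Trainer expects dataloaders that yield (x, y, lang) triples:
#   trainer = Trainer(model, optimizer_name="adamw", device="cuda",
#                     loss_fn=nn.BCEWithLogitsLoss(), lr=1e-4, print_steps=50,
#                     evaluate_step=1, patience=5, experiment_name="demo",
#                     checkpoint_path="checkpoints")
#   model = trainer.train(train_dataloader, eval_dataloader, epochs=20)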


class EarlyStopping:
    def __init__(
        self,
        patience,
        checkpoint_path,
        experiment_name,
        min_delta=0,
        verbose=True,
    ):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = 0
        self.best_epoch = None
        self.verbose = verbose
        self.checkpoint_path = checkpoint_path
        self.experiment_name = experiment_name

    def __call__(self, validation, model, epoch):
        if validation > self.best_score:
            print(
                f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
            )
            self.best_score = validation
            self.counter = 0
            self.best_epoch = epoch
            self.save_model(model)
        elif validation < (self.best_score + self.min_delta):
            self.counter += 1
            print(
                f"- earlystopping: Validation score did not improve ({validation:.3f} vs. best {self.best_score:.3f}), current patience: {self.patience - self.counter}"
            )
            if self.counter >= self.patience:
                if self.verbose:
                    print(f"- earlystopping: Early stopping at epoch {epoch}")
                return True
        return False

    def save_model(self, model):
        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
        os.makedirs(_checkpoint_dir, exist_ok=True)
        model.save_pretrained(_checkpoint_dir)

    def load_model(self, model):
        _checkpoint_dir = os.path.join(self.checkpoint_path, self.experiment_name)
        return model.from_pretrained(_checkpoint_dir)
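# Example (illustrative): the object is called once per evaluation with the
# watched metric (here macro-F1); it returns True once `patience` evaluations
# pass without improvement:
#   stopper = EarlyStopping(patience=5, checkpoint_path="checkpoints", experiment_name="demo")
#   if stopper(macro_f1, model, epoch):   # macro_f1, model, epoch are assumed names
#       model = stopper.load_model(model)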


class AttentionModule(nn.Module):
    def __init__(self, embed_dim, num_heads, out_dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.linear = nn.Linear(embed_dim, out_dim)

    def forward(self, X):
        # Self-attention over the input followed by a linear projection to the
        # output (label) space.
        attn_out, attn_weights = self.attn(query=X, key=X, value=X)
        out = self.linear(attn_out)
        return out

    def transform(self, X):
        # Return the attended representation without the final projection.
        attn_out, attn_weights = self.attn(query=X, key=X, value=X)
        return attn_out

    def save_pretrained(self, path):
        torch.save(self.state_dict(), f"{path}.pt")

    def from_pretrained(self, path):
        # Counterpart to save_pretrained(); used by EarlyStopping.load_model().
        self.load_state_dict(torch.load(f"{path}.pt"))
        return self


class AttentionAggregator:
    def __init__(self, embed_dim, out_dim, epochs, num_heads=1, device="cpu"):
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.device = device
        self.epochs = epochs
        self.attn = AttentionModule(embed_dim, num_heads, out_dim).to(self.device)

    def fit(self, X, Y):
        print("- fitting Attention-based aggregating function")
        hstacked_X = self.stack(X)

        dataset = AggregatorDatasetTorch(hstacked_X, Y)
        tra_dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

        experiment_name = "attention_aggregator"
        trainer = Trainer(
            self.attn,
            optimizer_name="adamW",
            lr=1e-3,
            loss_fn=torch.nn.CrossEntropyLoss(),
            print_steps=100,
            evaluate_step=1000,
            patience=10,
            experiment_name=experiment_name,
            device=self.device,
            checkpoint_path="models/aggregator",
        )
        trainer.train(
            train_dataloader=tra_dataloader,
            eval_dataloader=tra_dataloader,
            epochs=self.epochs,
        )
        return self
    def transform(self, X):
        h_stacked = self.stack(X)
        dataset = AggregatorDatasetTorch(h_stacked, lY=None, split="whole")
        dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

        _embeds = []
        l_embeds = defaultdict(list)

        self.attn.eval()
        with torch.no_grad():
            for input_ids, lang in dataloader:
                input_ids = input_ids.to(self.device)
                out = self.attn.transform(input_ids)
                _embeds.append((out.cpu().numpy(), lang))

        # Regroup the batched outputs by language.
        for embed, lang in _embeds:
            for sample_embed, sample_lang in zip(embed, lang):
                l_embeds[sample_lang].append(sample_embed)

        l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}
        return l_embeds

    def stack(self, data):
        hstack = self._hstack(data)
        return hstack

    def _hstack(self, data):
        # Horizontally concatenate the per-view projections, language by language.
        _langs = data[0].keys()
        l_projections = {}
        for l in _langs:
            l_projections[l] = torch.tensor(
                np.hstack([view[l] for view in data]), dtype=torch.float32
            )
        return l_projections

    def _vstack(self, data):
        return torch.vstack(data)
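# Example wiring (illustrative; view and label names are assumptions): the
# aggregator consumes a list of per-view {lang: matrix} dicts, horizontally
# stacks them per language, and learns an attention-based combination:
#   agg = AttentionAggregator(embed_dim=sum_of_view_dims, out_dim=n_classes, epochs=5)
#   agg.fit([lX_view1, lX_view2], lY)
#   l_embeds = agg.transform([lX_view1, lX_view2])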


class AggregatorDatasetTorch(Dataset):
    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        # Stack every language's data into a single tensor and keep a parallel
        # list recording which language each row came from.
        self.X = torch.vstack([data for data in self.lX.values()])
        if self.split != "whole":
            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
        self.langs = sum(
            ([lang] * len(data) for lang, data in self.lX.items()), []
        )
        return self

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]
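

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; toy data and assumed settings):
    # exercises the multilingual tf-idf wrapper, the PCA utilities, and the
    # prediction helper without touching the training components.
    docs = {
        "en": ["a small english corpus", "another english document"],
        "it": ["un piccolo corpus italiano", "un altro documento italiano"],
    }
    vec = TfidfVectorizerMultilingual(sublinear_tf=True)
    lX = vec.fit_transform(docs)
    print({lang: X.shape for lang, X in lX.items()})

    X = np.random.rand(8, 4)
    print(remove_pc(X, npc=1).shape)  # principal-component removal keeps the shape

    logits = torch.randn(4, 3)
    print(predict(logits, classification_type="multilabel"))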