fixed bug: we were applying the sigmoid function twice when training the Attention-based aggregator
This commit is contained in:
parent
fc98bc3924
commit
7041f7b651
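
Editor's note: a minimal, self-contained sketch (not part of the commit) of the numerical effect the commit message describes; the tensor values are illustrative only. Squashing outputs through the sigmoid a second time compresses every probability into roughly (0.5, 0.73), so a 0.5 decision threshold flags almost everything as positive and the reported metrics stop reflecting the model's confidence.

import torch

logits = torch.tensor([-4.0, -1.0, 0.0, 1.0, 4.0])

once = torch.sigmoid(logits)   # probabilities spread over (0, 1)
twice = torch.sigmoid(once)    # re-squashing values that are already probabilities

print(once)   # tensor([0.0180, 0.2689, 0.5000, 0.7311, 0.9820])
print(twice)  # tensor([0.5045, 0.5668, 0.6225, 0.6750, 0.7275])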
@@ -178,4 +178,5 @@ cython_debug/
 out/*
 amazon_cateogories.bu.txt
 models/*
 scripts/
+logger/*
@@ -109,9 +109,7 @@ class GlamiDataset:
     def get_label_binarizer(self, labels):
         mlb = LabelBinarizer()
         mlb.fit(labels)
-        print(
-            f"- Label binarizer initialized with the following labels:\n{mlb.classes_}"
-        )
+        print(f"- Label binarizer initialized with {len(mlb.classes_)} labels")
         return mlb

     def binarize_labels(self, labels):
@@ -23,8 +23,9 @@ def evaluate(ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1):
     return {lang: evals[i] for i, lang in enumerate(langs)}


-def log_eval(l_eval, phase="training"):
-    print(f"\n[Results {phase}]")
+def log_eval(l_eval, phase="training", verbose=True):
+    if verbose:
+        print(f"\n[Results {phase}]")
     metrics = []
     for lang in l_eval.keys():
         macrof1, microf1, macrok, microk = l_eval[lang]
@@ -32,9 +33,10 @@ def log_eval(l_eval, phase="training"):
         if phase != "validation":
             print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}")
     averages = np.mean(np.array(metrics), axis=0)
-    print(
-        "Averages: MF1, mF1, MK, mK",
-        np.round(averages, 3),
-        "\n",
-    )
+    if verbose:
+        print(
+            "Averages: MF1, mF1, MK, mK",
+            np.round(averages, 3),
+            "\n",
+        )
     return averages
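
A small usage sketch of the new verbose flag, with hypothetical scores (the per-language tuple layout follows the unpacking shown in the hunk above): for the validation phase with verbose=False nothing is printed, but the averaged metrics are still computed and returned, which is what Trainer.evaluate relies on for its frequent validation checks.

# hypothetical per-language scores: (macro-F1, micro-F1, macro-K, micro-K)
l_eval = {"en": (0.81, 0.84, 0.78, 0.80), "it": (0.76, 0.79, 0.73, 0.75)}

averages = log_eval(l_eval, phase="test")                       # prints header, per-language scores, averages
averages = log_eval(l_eval, phase="validation", verbose=False)  # silent: only returns the averaged metrics
macro_f1 = averages[0]                                          # the value watched by early stopping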
|
|
|
@@ -156,7 +156,7 @@ class GeneralizedFunnelling:
         if "attn" in self.aggfunc:
             attn_stacking = self.aggfunc.split("_")[1]
             self.attn_aggregator = AttentionAggregator(
-                embed_dim=self.get_attn_agg_dim(),
+                embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                 out_dim=self.num_labels,
                 lr=self.lr_transformer,
                 patience=self.patience,
@@ -173,6 +173,7 @@ class GeneralizedFunnelling:
             )

         self._model_id = get_unique_id(
+            self.dataset_name,
             self.posteriors_vgf,
             self.multilingual_vgf,
             self.wce_vgf,
@@ -376,7 +377,7 @@ class GeneralizedFunnelling:
     def _load_transformer(self):
         raise NotImplementedError

-    def get_attn_agg_dim(self, attn_stacking_type="concat"):
+    def get_attn_agg_dim(self, attn_stacking_type):
         if self.probabilistic and "attn" not in self.aggfunc:
             return len(self.first_tier_learners) * self.num_labels
         elif self.probabilistic and "attn" in self.aggfunc:
@@ -398,11 +399,11 @@ def get_params(optimc=False):
     return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}]


-def get_unique_id(posterior, multilingual, wce, transformer, aggfunc):
+def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfunc):
     from datetime import datetime

     now = datetime.now().strftime("%y%m%d")
-    model_id = ""
+    model_id = dataset_name
     model_id += "p" if posterior else ""
     model_id += "m" if multilingual else ""
     model_id += "w" if wce else ""
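
A hedged illustration of the change (the rest of the function is outside this hunk, so the exact suffix is unknown; the dataset name and flag values below are made up): the id now starts with the dataset name instead of an empty string, followed by one letter per enabled view-generating function.

# hypothetical call
model_id = get_unique_id("rcv1-2", posterior=True, multilingual=False, wce=True,
                         transformer=False, aggfunc="mean")
# the id now begins with "rcv1-2pw" rather than "pw", with the remaining
# characters appended by the lines not shown in this hunk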
|
|
|
@@ -126,7 +126,7 @@ class Trainer:
         self.earlystopping = EarlyStopping(
             patience=patience,
             checkpoint_path=checkpoint_path,
-            verbose=True,
+            verbose=False,
             experiment_name=experiment_name,
         )
@@ -149,18 +149,19 @@ class Trainer:
         for epoch in range(epochs):
             self.train_epoch(train_dataloader, epoch)
             if (epoch + 1) % self.evaluate_steps == 0:
-                metric_watcher = self.evaluate(eval_dataloader)
+                print_eval = (epoch + 1) % 25 == 0
+                metric_watcher = self.evaluate(eval_dataloader, print_eval=print_eval)
                 stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
                 if stop:
                     print(
-                        f"- restoring best model from epoch {self.earlystopping.best_epoch}"
+                        f"- restoring best model from epoch {self.earlystopping.best_epoch} with best metric: {self.earlystopping.best_score:3f}"
                     )
                     self.model = self.earlystopping.load_model(self.model).to(
                         self.device
                     )
                     break
-        self.train_epoch(eval_dataloader, epoch=epoch)
         print(f"\n- last swipe on eval set")
+        self.train_epoch(eval_dataloader, epoch=0)
         self.earlystopping.save_model(self.model)
         return self.model
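
A side note on the new logging cadence, sketched with hypothetical numbers and assuming the evaluate_step argument passed to Trainer is stored as self.evaluate_steps: the print_eval check is nested inside the evaluation branch, so verbose validation results appear only at epochs that satisfy both conditions.

# with evaluate_steps == 10 (the value now passed by AttentionAggregator below),
# evaluation runs at epochs 10, 20, 30, ... and full results are printed only
# when (epoch + 1) is also a multiple of 25, i.e. every 50 epochs in practice
evaluated = [e + 1 for e in range(100) if (e + 1) % 10 == 0]
printed = [e for e in evaluated if e % 25 == 0]
print(evaluated)  # [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
print(printed)    # [50, 100]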
@@ -180,7 +181,7 @@ class Trainer:
                 print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
         return self

-    def evaluate(self, dataloader):
+    def evaluate(self, dataloader, print_eval=True):
         self.model.eval()

         lY = defaultdict(list)
@@ -204,7 +205,7 @@ class Trainer:
             lY_hat[lang] = np.vstack(lY_hat[lang])

         l_eval = evaluate(lY, lY_hat)
-        average_metrics = log_eval(l_eval, phase="validation")
+        average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
         return average_metrics[0]  # macro-F1
@@ -228,21 +229,23 @@ class EarlyStopping:

     def __call__(self, validation, model, epoch):
         if validation > self.best_score:
-            print(
-                f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
-            )
+            if self.verbose:
+                print(
+                    f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
+                )
             self.best_score = validation
             self.counter = 0
             self.best_epoch = epoch
+            # print(f"- earlystopping: Saving best model from epoch {epoch}")
             self.save_model(model)
         elif validation < (self.best_score + self.min_delta):
             self.counter += 1
-            print(
-                f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
-            )
+            if self.verbose:
+                print(
+                    f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
+                )
             if self.counter >= self.patience:
-                if self.verbose:
-                    print(f"- earlystopping: Early stopping at epoch {epoch}")
+                print(f"- earlystopping: Early stopping at epoch {epoch}")
                 return True

     def save_model(self, model):
@@ -256,36 +259,35 @@ class EarlyStopping:


 class AttentionModule(nn.Module):
-    def __init__(self, embed_dim, num_heads, h_dim, out_dim):
+    def __init__(self, embed_dim, num_heads, h_dim, out_dim, aggfunc_type):
+        """We are calling sigmoid in the evaluation loop (Trainer.evaluate), so we
+        are not applying it explicitly here at training time. However, we should
+        explicitly squash outputs through the sigmoid at inference (self.transform) (???)
+        """
         super().__init__()
+        self.aggfunc = aggfunc_type
         self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.1)
-        self.layer_norm = nn.LayerNorm(embed_dim)
-        self.linear = nn.Linear(embed_dim, out_dim)
+        # self.layer_norm = nn.LayerNorm(embed_dim)
+        if self.aggfunc == "concat":
+            self.linear = nn.Linear(embed_dim, out_dim)
+        self.sigmoid = nn.Sigmoid()

     def __call__(self, X):
         out, attn_weights = self.attn(query=X, key=X, value=X)
-        out = self.layer_norm(out)
-        out = self.linear(out)
+        # out = self.layer_norm(out)
+        if self.aggfunc == "concat":
+            out = self.linear(out)
         # out = self.sigmoid(out)
         return out
-        # out = self.relu(out)
-        # out = self.linear2(out)
-        # out = self.sigmoid(out)

     def transform(self, X):
-        return self.__call__(X)
-        # out, attn_weights = self.attn(query=X, key=X, value=X)
-        # out = self.layer_norm(out)
-        # out = self.linear(out)
-        # out = self.sigmoid(out)
-        # return out
-        # out = self.relu(out)
-        # out = self.linear2(out)
-        # out = self.sigmoid(out)
+        """explicitly calling sigmoid at inference time"""
+        out, attn_weights = self.attn(query=X, key=X, value=X)
+        out = self.sigmoid(out)
+        return out

     def save_pretrained(self, path):
         torch.save(self, f"{path}.pt")
-        # torch.save(self.state_dict(), f"{path}.pt")

     def from_pretrained(self, path):
         return torch.load(f"{path}.pt")
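
A minimal usage sketch, assuming the "concat" stacking type and hypothetical dimensions, of the split the docstring above describes: the training-time forward pass returns un-squashed outputs for the loss, while transform() applies the sigmoid exactly once at inference.

import torch

# hypothetical sizes; in the repo embed_dim comes from get_attn_agg_dim()
module = AttentionModule(embed_dim=64, num_heads=4, h_dim=128, out_dim=10,
                         aggfunc_type="concat")

# (sequence, batch, embed_dim), the layout nn.MultiheadAttention expects by default
X = torch.randn(8, 1, 64)

logits = module(X)           # training path: no sigmoid applied here
probs = module.transform(X)  # inference path: sigmoid applied once, values in (0, 1)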
|
@@ -316,7 +318,11 @@ class AttentionAggregator:
         self.tr_batch_size = 512
         self.eval_batch_size = 1024
         self.attn = AttentionModule(
-            self.embed_dim, self.num_heads, self.h_dim, self.out_dim
+            self.embed_dim,
+            self.num_heads,
+            self.h_dim,
+            self.out_dim,
+            aggfunc_type=self.stacking_type,
         ).to(self.device)

     def fit(self, X, Y):
@@ -345,7 +351,7 @@ class AttentionAggregator:
             lr=self.lr,
             loss_fn=torch.nn.CrossEntropyLoss(),
             print_steps=25,
-            evaluate_step=50,
+            evaluate_step=10,
             patience=self.patience,
             experiment_name=experiment_name,
             device=self.device,
main.py (14 lines changed)
@@ -17,6 +17,11 @@ TODO:
     - load pre-trained VGFs while retaining ability to train new ones (self.fitted = True in loaded? or smt like that)
     - test split in MultiNews dataset
     - when we load a model and change its config (eg change the agg func, re-train meta), we should store this model as a new one (save it)
+    - FFNN posterior-probabilities' dependent
+    - re-init langs when loading VGFs?
+    - there is a mess with the sigmoid in the Attention aggregator and the evaluation function (predict): we were applying sig() 2 times on the outputs (at pred and at eval)...
+    - [!] the loss of the Attention aggregator seems to be uncorrelated with Macro-F1 on the validation set!
+    - aligner layer (suggestion by G.Puccetti)
 """
@@ -125,15 +130,16 @@ def main(args):
     if args.load_trained is None and not args.nosave:
         gfun.save(save_first_tier=True, save_meta=True)

-    preds = gfun.transform(lX)
+    # print("- Computing evaluation on training set")
+    # preds = gfun.transform(lX)
     # train_eval = evaluate(lY, preds)
     # log_eval(train_eval, phase="train")

     timetr = time()
     print(f"- training completed in {timetr - tinit:.2f} seconds")

-    test_eval = evaluate(lY_te, gfun.transform(lX_te))
+    gfun_preds = gfun.transform(lX_te)
+    test_eval = evaluate(lY_te, gfun_preds)
     log_eval(test_eval, phase="test")

     timeval = time()
@@ -156,7 +162,7 @@ if __name__ == "__main__":
     parser.add_argument("-m", "--multilingual", action="store_true")
     parser.add_argument("-w", "--wce", action="store_true")
     parser.add_argument("-t", "--transformer", action="store_true")
-    parser.add_argument("--n_jobs", type=int, default=1)
+    parser.add_argument("--n_jobs", type=int, default=-1)
     parser.add_argument("--optimc", action="store_true")
     parser.add_argument("--features", action="store_false")
     parser.add_argument("--aggfunc", type=str, default="mean")