changed wandb logging to a global level to keep track of all the VGFs and overall gFun
This commit is contained in:
parent
f32b9227ae
commit
56faaf2615
|
@ -30,18 +30,18 @@ def verbosity_eval(epoch, print_eval):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def format_langkey_wandb(lang_dict):
|
def format_langkey_wandb(lang_dict, vgf_name):
|
||||||
log_dict = {}
|
log_dict = {}
|
||||||
for metric, l_dict in lang_dict.items():
|
for metric, l_dict in lang_dict.items():
|
||||||
for lang, value in l_dict.items():
|
for lang, value in l_dict.items():
|
||||||
log_dict[f"language metric/{metric}/{lang}"] = value
|
log_dict[f"{vgf_name}/language metric/{metric}/{lang}"] = value
|
||||||
return log_dict
|
return log_dict
|
||||||
|
|
||||||
|
|
||||||
def format_average_wandb(avg_dict):
|
def format_average_wandb(avg_dict, vgf_name):
|
||||||
log_dict = {}
|
log_dict = {}
|
||||||
for metric, value in avg_dict.items():
|
for metric, value in avg_dict.items():
|
||||||
log_dict[f"average metric/{metric}"] = value
|
log_dict[f"{vgf_name}/average metric/{metric}"] = value
|
||||||
return log_dict
|
return log_dict
|
||||||
|
|
||||||
|
|
||||||
|
@ -213,14 +213,6 @@ class Trainer:
|
||||||
for k, v in _config.items():
|
for k, v in _config.items():
|
||||||
print(f"\t{k}: {v}")
|
print(f"\t{k}: {v}")
|
||||||
|
|
||||||
wandb_logger = wandb.init(
|
|
||||||
project="gfun",
|
|
||||||
entity="andreapdr",
|
|
||||||
name=f"{_config['model name']} lr: {_config['learning rate']} scheduler: {_config['scheduler']}",
|
|
||||||
config=_config,
|
|
||||||
reinit=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
train_loss = self.train_epoch(train_dataloader, epoch)
|
train_loss = self.train_epoch(train_dataloader, epoch)
|
||||||
|
|
||||||
|
@ -233,11 +225,11 @@ class Trainer:
|
||||||
n_jobs=self.n_jobs,
|
n_jobs=self.n_jobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
wandb_logger.log(
|
wandb.log(
|
||||||
{
|
{
|
||||||
"loss/val": eval_loss,
|
f"{self.vgf_name}/loss/val": eval_loss,
|
||||||
**format_langkey_wandb(lang_metrics),
|
**format_langkey_wandb(lang_metrics, self.vgf_name),
|
||||||
**format_average_wandb(avg_metrics),
|
**format_average_wandb(avg_metrics, self.vgf_name),
|
||||||
},
|
},
|
||||||
commit=False,
|
commit=False,
|
||||||
)
|
)
|
||||||
|
@ -260,10 +252,12 @@ class Trainer:
|
||||||
if self.scheduler is not None:
|
if self.scheduler is not None:
|
||||||
self.scheduler.step(avg_metrics[self.monitored_metric])
|
self.scheduler.step(avg_metrics[self.monitored_metric])
|
||||||
|
|
||||||
wandb_logger.log(
|
wandb.log(
|
||||||
{
|
{
|
||||||
"loss/train": train_loss,
|
f"{self.vgf_name}/loss/train": train_loss,
|
||||||
"learning rate": self.optimizer.param_groups[0]["lr"],
|
f"{self.vgf_name}/learning rate": self.optimizer.param_groups[0][
|
||||||
|
"lr"
|
||||||
|
],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -274,7 +268,7 @@ class Trainer:
|
||||||
|
|
||||||
def train_epoch(self, dataloader, epoch):
|
def train_epoch(self, dataloader, epoch):
|
||||||
self.model.train()
|
self.model.train()
|
||||||
epoch_losses = []
|
batch_losses = []
|
||||||
for b_idx, (x, y, lang) in enumerate(dataloader):
|
for b_idx, (x, y, lang) in enumerate(dataloader):
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
y_hat = self.model(x.to(self.device))
|
y_hat = self.model(x.to(self.device))
|
||||||
|
@ -284,13 +278,13 @@ class Trainer:
|
||||||
loss = self.loss_fn(y_hat, y.to(self.device))
|
loss = self.loss_fn(y_hat, y.to(self.device))
|
||||||
loss.backward()
|
loss.backward()
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
epoch_losses.append(loss.item())
|
batch_losses.append(loss.item()) # TODO: is this still on gpu?
|
||||||
if (epoch + 1) % PRINT_ON_EPOCH == 0:
|
if (epoch + 1) % PRINT_ON_EPOCH == 0:
|
||||||
if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
|
if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
|
||||||
print(
|
print(
|
||||||
f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {np.mean(epoch_losses):.4f}"
|
f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {np.mean(batch_losses):.4f}"
|
||||||
)
|
)
|
||||||
return np.mean(epoch_losses)
|
return np.mean(batch_losses)
|
||||||
|
|
||||||
def evaluate(self, dataloader, print_eval=True, n_jobs=-1):
|
def evaluate(self, dataloader, print_eval=True, n_jobs=-1):
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
80
main.py
80
main.py
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
import wandb
|
||||||
|
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||||
|
|
||||||
|
@ -11,19 +12,38 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
|
||||||
|
|
||||||
"""
|
"""
|
||||||
TODO:
|
TODO:
|
||||||
- [!] add support for mT5
|
- Transformers VGFs:
|
||||||
- [!] log on wandb also the other VGF results + final results
|
- save/load for MT5ForSqeuenceClassification
|
||||||
- [!] CLS dataset is loading only "books" domain data
|
- freeze params method
|
||||||
- [!] documents should be trimmed to the same length (?)
|
- log on step rather than epoch?
|
||||||
- [!] overall gfun results logger
|
- General:
|
||||||
- add documentations sphinx
|
[!] zero-shot setup
|
||||||
- [!] zero-shot setup
|
- CLS dataset is loading only "books" domain data
|
||||||
- FFNN posterior-probabilities' dependent
|
- log on wandb also the other VGF results + final results
|
||||||
- re-init langs when loading VGFs?
|
- documents should be trimmed to the same length (for SVMs we are using way too long tokens)
|
||||||
- [!] experiment with weight init of Attention-aggregator
|
- Attention Aggregator:
|
||||||
|
- experiment with weight init of Attention-aggregator
|
||||||
|
- FFNN posterior-probabilities' dependent
|
||||||
|
- Docs:
|
||||||
|
- add documentations sphinx
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_config_name(args):
|
||||||
|
config_name = ""
|
||||||
|
if args.posteriors:
|
||||||
|
config_name += "P+"
|
||||||
|
if args.wce:
|
||||||
|
config_name += "W+"
|
||||||
|
if args.multilingual:
|
||||||
|
config_name += "M+"
|
||||||
|
if args.textual_transformer:
|
||||||
|
config_name += f"TT_{args.textual_trf_name}+"
|
||||||
|
if args.visual_transformer:
|
||||||
|
config_name += f"VT_{args.visual_trf_name}+"
|
||||||
|
return config_name.rstrip("+")
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
dataset = get_dataset(args.dataset, args)
|
dataset = get_dataset(args.dataset, args)
|
||||||
lX, lY = dataset.training()
|
lX, lY = dataset.training()
|
||||||
|
@ -86,27 +106,53 @@ def main(args):
|
||||||
n_jobs=args.n_jobs,
|
n_jobs=args.n_jobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# gfun.get_config()
|
wandb.init(
|
||||||
|
project="gfun", name=f"gFun-{get_config_name(args)}"
|
||||||
|
) # TODO: Add config to log
|
||||||
gfun.fit(lX, lY)
|
gfun.fit(lX, lY)
|
||||||
|
|
||||||
if args.load_trained is None and not args.nosave:
|
if args.load_trained is None and not args.nosave:
|
||||||
gfun.save(save_first_tier=True, save_meta=True)
|
gfun.save(save_first_tier=True, save_meta=True)
|
||||||
|
|
||||||
# print("- Computing evaluation on training set")
|
|
||||||
# preds = gfun.transform(lX)
|
|
||||||
# train_eval = evaluate(lY, preds)
|
|
||||||
# log_eval(train_eval, phase="train")
|
|
||||||
|
|
||||||
timetr = time()
|
timetr = time()
|
||||||
print(f"- training completed in {timetr - tinit:.2f} seconds")
|
print(f"- training completed in {timetr - tinit:.2f} seconds")
|
||||||
|
|
||||||
gfun_preds = gfun.transform(lX_te)
|
gfun_preds = gfun.transform(lX_te)
|
||||||
test_eval = evaluate(lY_te, gfun_preds, clf_type=args.clf_type, n_jobs=args.n_jobs)
|
test_eval = evaluate(lY_te, gfun_preds, clf_type=args.clf_type, n_jobs=args.n_jobs)
|
||||||
log_eval(test_eval, phase="test", clf_type=args.clf_type)
|
avg_metrics_gfun, lang_metrics_gfun = log_eval(
|
||||||
|
test_eval, phase="test", clf_type=args.clf_type
|
||||||
|
)
|
||||||
|
|
||||||
timeval = time()
|
timeval = time()
|
||||||
print(f"- testing completed in {timeval - timetr:.2f} seconds")
|
print(f"- testing completed in {timeval - timetr:.2f} seconds")
|
||||||
|
|
||||||
|
def log_barplot_wandb(gfun_res, title_affix="per langauge"):
|
||||||
|
if title_affix == "per language":
|
||||||
|
for metric, lang_values in gfun_res.items():
|
||||||
|
data = [[lang, v] for lang, v in lang_values.items()]
|
||||||
|
table = wandb.Table(data=data, columns=["lang", f"{metric}"])
|
||||||
|
wandb.log(
|
||||||
|
{
|
||||||
|
f"gFun/language {metric}": wandb.plot.bar(
|
||||||
|
table, "lang", metric, title=f"{metric} {title_affix}"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
data = [[metric, value] for metric, value in gfun_res.items()]
|
||||||
|
table = wandb.Table(data=data, columns=["metric", "value"])
|
||||||
|
wandb.log(
|
||||||
|
{
|
||||||
|
f"gFun/average metric": wandb.plot.bar(
|
||||||
|
table, "metric", "value", title=f"metric {title_affix}"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
wandb.log(gfun_res)
|
||||||
|
|
||||||
|
log_barplot_wandb(lang_metrics_gfun, title_affix="per language")
|
||||||
|
log_barplot_wandb(avg_metrics_gfun, title_affix="averages")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
|
|
Loading…
Reference in New Issue