updated argparse

This commit is contained in:
Andrea Pedrotti 2023-03-14 11:54:40 +01:00
parent 5e41b4517a
commit fece8d059e
1 changed files with 5 additions and 5 deletions

10
main.py
View File

@ -58,7 +58,7 @@ def main(args):
wce=args.wce, wce=args.wce,
# Transformer VGF params -------------- # Transformer VGF params --------------
textual_transformer=args.textual_transformer, textual_transformer=args.textual_transformer,
textual_transformer_name=args.transformer_name, textual_transformer_name=args.textual_trf_name,
batch_size=args.batch_size, batch_size=args.batch_size,
eval_batch_size=args.eval_batch_size, eval_batch_size=args.eval_batch_size,
epochs=args.epochs, epochs=args.epochs,
@ -70,14 +70,14 @@ def main(args):
device=args.device, device=args.device,
# Visual Transformer VGF params -------------- # Visual Transformer VGF params --------------
visual_transformer=args.visual_transformer, visual_transformer=args.visual_transformer,
visual_transformer_name=args.visual_transformer_name, visual_transformer_name=args.visual_trf_name,
# batch_size=args.batch_size, # batch_size=args.batch_size,
# epochs=args.epochs, # epochs=args.epochs,
# lr=args.lr, # lr=args.lr,
# patience=args.patience, # patience=args.patience,
# evaluate_step=args.evaluate_step, # evaluate_step=args.evaluate_step,
# device="cuda", # device="cuda",
# General params ---------------------- # General params ---------------------
probabilistic=args.features, probabilistic=args.features,
aggfunc=args.aggfunc, aggfunc=args.aggfunc,
optimc=args.optimc, optimc=args.optimc,
@ -133,7 +133,7 @@ if __name__ == "__main__":
parser.add_argument("--features", action="store_false") parser.add_argument("--features", action="store_false")
parser.add_argument("--aggfunc", type=str, default="mean") parser.add_argument("--aggfunc", type=str, default="mean")
# transformer parameters --------------- # transformer parameters ---------------
parser.add_argument("--transformer_name", type=str, default="mbert") parser.add_argument("--textual_trf_name", type=str, default="mbert")
parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--eval_batch_size", type=int, default=128) parser.add_argument("--eval_batch_size", type=int, default=128)
parser.add_argument("--epochs", type=int, default=100) parser.add_argument("--epochs", type=int, default=100)
@ -143,7 +143,7 @@ if __name__ == "__main__":
parser.add_argument("--patience", type=int, default=5) parser.add_argument("--patience", type=int, default=5)
parser.add_argument("--evaluate_step", type=int, default=10) parser.add_argument("--evaluate_step", type=int, default=10)
# Visual Transformer parameters -------------- # Visual Transformer parameters --------------
parser.add_argument("--visual_transformer_name", type=str, default="vit") parser.add_argument("--visual_trf_name", type=str, default="vit")
args = parser.parse_args() args = parser.parse_args()