diff --git a/decanlp/arguments.py b/decanlp/arguments.py
index a3222fab..cc61900e 100644
--- a/decanlp/arguments.py
+++ b/decanlp/arguments.py
@@ -108,6 +108,7 @@ def parse(argv):
     parser.add_argument('--optimizer', default='adam', type=str, help='Adam or SGD')
     parser.add_argument('--no_transformer_lr', action='store_false', dest='transformer_lr', help='turns off the transformer learning rate strategy')
     parser.add_argument('--sgd_lr', default=1.0, type=float, help='learning rate for SGD (if not using Adam)')
+    parser.add_argument('--weight_decay', default=0.0, type=float, help='weight decay (L2 regularization)')
     parser.add_argument('--load', default=None, type=str, help='path to checkpoint to load model from inside args.save')
     parser.add_argument('--resume', action='store_true', help='whether to resume training with past optimizers')
diff --git a/decanlp/train.py b/decanlp/train.py
index 54bab2b1..2b0a4ce6 100644
--- a/decanlp/train.py
+++ b/decanlp/train.py
@@ -205,7 +205,7 @@ def train(args, model, opt, train_iters, train_iterations, field, rank=0, world_
                     round_progress = f'round_{rnd}:' if rounds else ''

                     # validate
-
+                    deca_score = None
                     if (val_every is not None and
                         ((iteration % args.val_every == 0 % args.val_every) or
@@ -384,11 +384,11 @@ def init_opt(args, model):
     opt = None
     if 'adam' in args.optimizer.lower():
         if args.transformer_lr:
-            opt = torch.optim.Adam(model.params, lr=args.lr_rate, betas=(0.9, 0.98), eps=1e-9)
+            opt = torch.optim.Adam(model.params, lr=args.lr_rate, betas=(0.9, 0.98), eps=1e-9, weight_decay=args.weight_decay)
         else:
-            opt = torch.optim.Adam(model.params, lr=args.lr_rate, betas=(args.beta0, 0.999))
+            opt = torch.optim.Adam(model.params, lr=args.lr_rate, betas=(args.beta0, 0.999), weight_decay=args.weight_decay)
     else:
-        opt = torch.optim.SGD(model.params, lr=args.sgd_lr)
+        opt = torch.optim.SGD(model.params, lr=args.sgd_lr, weight_decay=args.weight_decay)
     return opt
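
For context, a minimal, self-contained sketch of what this patch enables, not the project's actual code: a --weight_decay flag parsed by argparse and forwarded to torch.optim.Adam / torch.optim.SGD the same way the patched init_opt() does. The build_optimizer helper, the torch.nn.Linear stand-in model, the plain --transformer_lr store_true flag (the real CLI uses --no_transformer_lr with store_false), and the default values for lr_rate and beta0 are illustrative assumptions; only the --weight_decay flag and the optimizer calls mirror the diff.

# Hypothetical sketch (not part of the patch): shows how the new
# --weight_decay argument reaches the PyTorch optimizers.
import argparse

import torch


def build_optimizer(args, params):
    # Stand-in for init_opt(); `params` takes the place of model.params.
    if 'adam' in args.optimizer.lower():
        if args.transformer_lr:
            return torch.optim.Adam(params, lr=args.lr_rate, betas=(0.9, 0.98),
                                    eps=1e-9, weight_decay=args.weight_decay)
        return torch.optim.Adam(params, lr=args.lr_rate, betas=(args.beta0, 0.999),
                                weight_decay=args.weight_decay)
    return torch.optim.SGD(params, lr=args.sgd_lr, weight_decay=args.weight_decay)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--optimizer', default='adam', type=str, help='Adam or SGD')
    parser.add_argument('--transformer_lr', action='store_true')
    parser.add_argument('--lr_rate', default=1e-3, type=float)
    parser.add_argument('--beta0', default=0.9, type=float)
    parser.add_argument('--sgd_lr', default=1.0, type=float)
    parser.add_argument('--weight_decay', default=0.0, type=float,
                        help='weight decay (L2 regularization)')
    args = parser.parse_args()

    model = torch.nn.Linear(4, 2)  # stand-in for the real model
    opt = build_optimizer(args, model.parameters())
    print(opt)  # with the default 0.0, behaviour is unchanged from before the patch

With the default of 0.0 the optimizers behave exactly as before; passing a non-zero value (e.g. --weight_decay 1e-4) turns on PyTorch's built-in regularization. Note that torch.optim.Adam applies weight_decay as a classic L2 penalty on the gradients rather than the decoupled weight decay of AdamW.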