adding weight regularization option

mehrad 2019-03-13 14:19:41 -07:00 committed by Giovanni Campagna
parent f65d939f9a
commit 292419b299
2 changed files with 5 additions and 4 deletions


@@ -108,6 +108,7 @@ def parse(argv):
     parser.add_argument('--optimizer', default='adam', type=str, help='Adam or SGD')
     parser.add_argument('--no_transformer_lr', action='store_false', dest='transformer_lr', help='turns off the transformer learning rate strategy')
     parser.add_argument('--sgd_lr', default=1.0, type=float, help='learning rate for SGD (if not using Adam)')
+    parser.add_argument('--weight_decay', default=0.0, type=float, help='weight L2 regularization')
     parser.add_argument('--load', default=None, type=str, help='path to checkpoint to load model from inside args.save')
     parser.add_argument('--resume', action='store_true', help='whether to resume training with past optimizers')
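
The new --weight_decay flag defaults to 0.0, so existing configurations keep their current behavior unless the option is set explicitly. A minimal, self-contained sketch of how the parsed value is meant to reach the optimizer (the toy model and argument values are illustrative, not code from this repository):

    import argparse
    import torch

    parser = argparse.ArgumentParser()
    parser.add_argument('--weight_decay', default=0.0, type=float, help='weight L2 regularization')
    parser.add_argument('--sgd_lr', default=1.0, type=float, help='learning rate for SGD (if not using Adam)')
    args = parser.parse_args(['--weight_decay', '1e-5'])

    model = torch.nn.Linear(16, 4)  # hypothetical stand-in for the real model
    # The flag is simply forwarded to the optimizer, as init_opt() does below.
    opt = torch.optim.SGD(model.parameters(), lr=args.sgd_lr, weight_decay=args.weight_decay)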


@@ -205,7 +205,7 @@ def train(args, model, opt, train_iters, train_iterations, field, rank=0, world_
         round_progress = f'round_{rnd}:' if rounds else ''
         # validate
         deca_score = None
         if (val_every is not None and
                 ((iteration % args.val_every == 0 % args.val_every) or
@@ -384,11 +384,11 @@ def init_opt(args, model):
     opt = None
     if 'adam' in args.optimizer.lower():
         if args.transformer_lr:
-            opt = torch.optim.Adam(model.params, lr=args.lr_rate, betas=(0.9, 0.98), eps=1e-9)
+            opt = torch.optim.Adam(model.params, lr=args.lr_rate, betas=(0.9, 0.98), eps=1e-9, weight_decay=args.weight_decay)
         else:
-            opt = torch.optim.Adam(model.params, lr=args.lr_rate, betas=(args.beta0, 0.999))
+            opt = torch.optim.Adam(model.params, lr=args.lr_rate, betas=(args.beta0, 0.999), weight_decay=args.weight_decay)
     else:
-        opt = torch.optim.SGD(model.params, lr=args.sgd_lr)
+        opt = torch.optim.SGD(model.params, lr=args.sgd_lr, weight_decay=args.weight_decay)
     return opt
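
For reference, PyTorch's weight_decay argument implements classic (coupled) L2 regularization: at each step the optimizer adds weight_decay * param to the gradient before updating, which for plain SGD is equivalent to adding 0.5 * weight_decay * ||w||^2 to the loss. For Adam this coupled penalty is not the same as decoupled weight decay as in AdamW (shipped in later PyTorch releases as torch.optim.AdamW). A small sketch checking the SGD equivalence (illustrative values only, not part of this commit):

    import torch

    torch.manual_seed(0)
    w = torch.randn(3)
    wd, lr = 0.01, 0.1

    # Optimizer path: one SGD step with weight_decay
    p1 = w.clone().requires_grad_(True)
    opt = torch.optim.SGD([p1], lr=lr, weight_decay=wd)
    (p1 ** 2).sum().backward()   # stand-in loss
    opt.step()

    # Manual path: fold the L2 penalty into the loss instead
    p2 = w.clone().requires_grad_(True)
    ((p2 ** 2).sum() + 0.5 * wd * (p2 ** 2).sum()).backward()
    with torch.no_grad():
        p2 -= lr * p2.grad

    print(torch.allclose(p1, p2))   # True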