adding weight regularization option
parent f65d939f9a
commit 292419b299
@@ -108,6 +108,7 @@ def parse(argv):
     parser.add_argument('--optimizer', default='adam', type=str, help='Adam or SGD')
     parser.add_argument('--no_transformer_lr', action='store_false', dest='transformer_lr', help='turns off the transformer learning rate strategy')
     parser.add_argument('--sgd_lr', default=1.0, type=float, help='learning rate for SGD (if not using Adam)')
+    parser.add_argument('--weight_decay', default=0.0, type=float, help='weight L2 regularization')
 
     parser.add_argument('--load', default=None, type=str, help='path to checkpoint to load model from inside args.save')
     parser.add_argument('--resume', action='store_true', help='whether to resume training with past optimizers')
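For reference, a minimal standalone sketch of how the new --weight_decay flag behaves with argparse; the values below are illustrative and not part of the commit:

import argparse

# standalone sketch of the new flag; the default of 0.0 means no L2 penalty is applied
parser = argparse.ArgumentParser()
parser.add_argument('--weight_decay', default=0.0, type=float, help='weight L2 regularization')

args = parser.parse_args(['--weight_decay', '1e-4'])
print(args.weight_decay)  # 0.0001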
@@ -205,7 +205,7 @@ def train(args, model, opt, train_iters, train_iterations, field, rank=0, world_
             round_progress = f'round_{rnd}:' if rounds else ''
 
             # validate
             deca_score = None
             if (val_every is not None and
                 ((iteration % args.val_every == 0 % args.val_every) or
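As context for the unchanged condition above, 0 % args.val_every is always 0 for a positive interval, so the check reduces to validating whenever the iteration count is a multiple of val_every. A tiny illustrative sketch of that schedule, with made-up values:

val_every = 100  # illustrative interval, not the repository's default

for iteration in range(1, 301):
    # same test as in the diff: iteration % val_every == 0 % val_every
    if iteration % val_every == 0:
        print(f'validate at iteration {iteration}')  # prints at 100, 200, 300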
@@ -384,11 +384,11 @@ def init_opt(args, model):
     opt = None
     if 'adam' in args.optimizer.lower():
         if args.transformer_lr:
-            opt = torch.optim.Adam(model.params, lr=args.lr_rate, betas=(0.9, 0.98), eps=1e-9)
+            opt = torch.optim.Adam(model.params, lr=args.lr_rate, betas=(0.9, 0.98), eps=1e-9, weight_decay=args.weight_decay)
         else:
-            opt = torch.optim.Adam(model.params, lr=args.lr_rate, betas=(args.beta0, 0.999))
+            opt = torch.optim.Adam(model.params, lr=args.lr_rate, betas=(args.beta0, 0.999), weight_decay=args.weight_decay)
     else:
-        opt = torch.optim.SGD(model.params, lr=args.sgd_lr)
+        opt = torch.optim.SGD(model.params, lr=args.sgd_lr, weight_decay=args.weight_decay,)
     return opt
 
 
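A minimal, self-contained sketch of what the new weight_decay argument does in the PyTorch optimizers constructed here: the optimizer adds weight_decay * param to each parameter's gradient before its update, which is standard L2 regularization. The one-parameter model below is illustrative only, not the repository's model class:

import torch

# illustrative single-parameter 'model'
w = torch.nn.Parameter(torch.ones(3))

# same construction pattern as init_opt, with the new weight_decay argument
opt = torch.optim.Adam([w], lr=1e-3, betas=(0.9, 0.98), eps=1e-9, weight_decay=0.1)

loss = (w ** 2).sum()
loss.backward()
opt.step()  # gradient is treated as grad + weight_decay * w before the Adam update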