williamFalcon 2019-07-28 05:39:25 -07:00
commit 5db28899aa
9 changed files with 226 additions and 42 deletions

View File

@ -276,7 +276,6 @@ tensorboard --logdir /some/path
###### Training loop
- [Accumulate gradients](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#accumulated-gradients)
- [Anneal Learning rate](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#anneal-learning-rate)
- [Force training for min or max epochs](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#force-training-for-min-or-max-epochs)
- [Force disable early stop](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#force-disable-early-stop)
- [Gradient Clipping](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#gradient-clipping)

View File

@ -225,26 +225,27 @@ def validation_end(self, outputs):
def configure_optimizers(self)
```
Set up as many optimizers as you need. Normally you'd need one. But in the case of GANs or something more esoteric you might have multiple.
Lightning will call .backward() and .step() on each one. If you use 16 bit precision it will also handle that.
Set up as many optimizers and (optionally) learning rate schedulers as you need. Normally you'd need one, but for GANs or other more esoteric setups you might have multiple.
Lightning will call .backward() and .step() on each optimizer during training, and .step() on each scheduler once per epoch. If you use 16-bit precision it will also handle that.
##### Return
List - List of optimizers
Tuple - List of optimizers and list of schedulers
**Example**
``` {.python}
# most cases
def configure_optimizers(self):
opt = Adam(lr=0.01)
return [opt]
opt = Adam(self.parameters(), lr=0.01)
return [opt], []
# gan example
# gan example, with scheduler for discriminator
def configure_optimizers(self):
generator_opt = Adam(lr=0.01)
disriminator_opt = Adam(lr=0.02)
return [generator_opt, disriminator_opt]
generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
discriminator_sched = CosineAnnealingLR(discriminator_opt, T_max=10)
return [generator_opt, discriminator_opt], [discriminator_sched]
```
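For the common case of a single optimizer with one scheduler, a minimal sketch (mirroring the template model updated elsewhere in this commit; `Adam` and `CosineAnnealingLR` are the usual classes from `torch.optim` and `torch.optim.lr_scheduler`):
``` {.python}
# one optimizer, one scheduler
def configure_optimizers(self):
    optimizer = Adam(self.parameters(), lr=0.01)
    scheduler = CosineAnnealingLR(optimizer, T_max=10)
    return [optimizer], [scheduler]
```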
---
@ -431,4 +432,4 @@ def add_model_specific_args(parent_parser, root_dir):
parser.opt_list('--batch_size', default=256, type=int, options=[32, 64, 128, 256], tunable=False)
parser.opt_list('--optimizer_name', default='adam', type=str, options=['adam'], tunable=False)
return parser
```

View File

@ -11,17 +11,6 @@ Accumulated gradients runs K small batches of size N before doing a backwards pa
trainer = Trainer(accumulate_grad_batches=1)
```
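To accumulate over more than one batch, a small sketch with the same flag (the value 4 is just an illustration):
``` {.python}
# accumulate gradients over 4 batches before each optimizer step
# (effective batch size = 4 * batch_size)
trainer = Trainer(accumulate_grad_batches=4)
```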
---
#### Anneal Learning rate
Cut the learning rate by 10 at every epoch listed in this list.
``` {.python}
# DEFAULT (don't anneal)
trainer = Trainer(lr_scheduler_milestones=None)
# cut LR by 10 at 100, 200, and 300 epochs
trainer = Trainer(lr_scheduler_milestones='100, 200, 300')
```
---
#### Force training for min or max epochs
It can be useful to force training for a minimum number of epochs or to cap it at a maximum number of epochs; see the sketch below.
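A minimal sketch, assuming the `min_nb_epochs` and `max_nb_epochs` Trainer arguments that appear in the Trainer changes in this commit:
``` {.python}
# train for at least 1 epoch, but never more than 1000
trainer = Trainer(min_nb_epochs=1, max_nb_epochs=1000)
```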

View File

@ -66,7 +66,6 @@ one could be a seq-2-seq model, both (optionally) ran by the same trainer file.
###### Training loop
- [Accumulate gradients](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#accumulated-gradients)
- [Anneal Learning rate](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#anneal-learning-rate)
- [Force training for min or max epochs](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#force-training-for-min-or-max-epochs)
- [Force disable early stop](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#force-disable-early-stop)
- [Gradient Clipping](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#gradient-clipping)

View File

@ -163,7 +163,8 @@ class LightningTemplateModel(LightningModule):
:return: list of optimizers and list of schedulers
"""
optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
return [optimizer]
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
return [optimizer], [scheduler]
def __dataloader(self, train):
# init data generators
@ -220,7 +221,6 @@ class LightningTemplateModel(LightningModule):
# parser.set_defaults(gradient_clip=5.0)
# network params
parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False)
parser.add_argument('--in_features', default=28*28, type=int)
parser.add_argument('--out_features', default=10, type=int)
parser.add_argument('--hidden_dim', default=50000, type=int) # use 500 for CPU, 50000 for GPU to see speed difference

View File

@ -0,0 +1,204 @@
import torch.nn as nn
import numpy as np
from pytorch_lightning import LightningModule
from test_tube import HyperOptArgumentParser
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import torch
import torch.nn.functional as F
class ExampleModel1(LightningModule):
"""
Sample model to show how to define a template
"""
def __init__(self, hparams):
# init superclass
super(ExampleModel1, self).__init__(hparams)
self.batch_size = hparams.batch_size
# build model
self.__build_model()
# ---------------------
# MODEL SETUP
# ---------------------
def __build_model(self):
"""
Layout model
:return:
"""
self.c_d1 = nn.Linear(in_features=self.hparams.in_features, out_features=self.hparams.hidden_dim)
self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim)
self.c_d1_drop = nn.Dropout(self.hparams.drop_prob)
self.c_d2 = nn.Linear(in_features=self.hparams.hidden_dim, out_features=self.hparams.out_features)
# ---------------------
# TRAINING
# ---------------------
def forward(self, x):
x = self.c_d1(x)
x = torch.tanh(x)
x = self.c_d1_bn(x)
x = self.c_d1_drop(x)
x = self.c_d2(x)
logits = F.log_softmax(x, dim=1)
return logits
def loss(self, labels, logits):
nll = F.nll_loss(logits, labels)
return nll
def training_step(self, data_batch):
"""
Called inside the training loop
:param data_batch:
:return:
"""
# forward pass
x, y = data_batch
x = x.view(x.size(0), -1)
y_hat = self.forward(x)
# calculate loss
loss_val = self.loss(y, y_hat)
tqdm_dic = {'jefe': 1}
return loss_val, tqdm_dic
def validation_step(self, data_batch):
"""
Called inside the validation loop
:param data_batch:
:return:
"""
x, y = data_batch
x = x.view(x.size(0), -1)
y_hat = self.forward(x)
loss_val = self.loss(y, y_hat)
# acc
labels_hat = torch.argmax(y_hat, dim=1)
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
output = {'y_hat': y_hat, 'val_loss': loss_val.item(), 'val_acc': val_acc}
return output
def validation_end(self, outputs):
"""
Called at the end of validation to aggregate outputs
:param outputs: list of individual outputs of each validation step
:return:
"""
val_loss_mean = 0
accs = []
for output in outputs:
val_loss_mean += output['val_loss']
accs.append(output['val_acc'])
val_loss_mean /= len(outputs)
tqdm_dic = {'val_loss': val_loss_mean, 'val_acc': np.mean(accs)}
return tqdm_dic
def update_tng_log_metrics(self, logs):
return logs
# ---------------------
# MODEL SAVING
# ---------------------
def get_save_dict(self):
checkpoint = {
'state_dict': self.state_dict(),
}
return checkpoint
def load_model_specific(self, checkpoint):
self.load_state_dict(checkpoint['state_dict'])
# ---------------------
# TRAINING SETUP
# ---------------------
def configure_optimizers(self):
"""
return whatever optimizers and (optionally) schedulers we want here
:return: list of optimizers and list of schedulers
"""
optimizer = self.choose_optimizer(self.hparams.optimizer_name, self.parameters(), {'lr': self.hparams.learning_rate}, 'optimizer')
self.optimizers = [optimizer]
self.schedulers = []
return self.optimizers, self.schedulers
def __dataloader(self, train):
# init data generators
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root=self.hparams.data_root, train=train, transform=transform, download=True)
loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True
)
return loader
@data_loader
def tng_dataloader(self):
if self._tng_dataloader is None:
try:
self._tng_dataloader = self.__dataloader(train=True)
except Exception as e:
print(e)
raise e
return self._tng_dataloader
@property
def val_dataloader(self):
if self._val_dataloader is None:
try:
self._val_dataloader = self.__dataloader(train=False)
except Exception as e:
print(e)
raise e
return self._val_dataloader
@property
def test_dataloader(self):
if self._test_dataloader is None:
try:
self._test_dataloader = self.__dataloader(train=False)
except Exception as e:
print(e)
raise e
return self._test_dataloader
@staticmethod
def add_model_specific_args(parent_parser):
parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
# param overwrites
# parser.set_defaults(gradient_clip=5.0)
# network params
parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False)
parser.add_argument('--in_features', default=28*28, type=int)
parser.add_argument('--hidden_dim', default=500, type=int)
parser.add_argument('--out_features', default=10, type=int)
# data
parser.add_argument('--data_root', default='/Users/williamfalcon/Developer/personal/research_lib/research_proj/datasets/mnist', type=str)
# training params (opt)
parser.opt_list('--learning_rate', default=0.001, type=float, options=[0.0001, 0.0005, 0.001, 0.005],
tunable=False)
parser.opt_list('--batch_size', default=256, type=int, options=[32, 64, 128, 256], tunable=False)
parser.opt_list('--optimizer_name', default='adam', type=str, options=['adam'], tunable=False)
return parser
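For context, a hedged sketch of how a template like `ExampleModel1` might be driven end to end; the `Trainer` import path, the `grid_search` strategy, and the one-epoch limit are illustrative assumptions, not part of this commit:
``` {.python}
from test_tube import HyperOptArgumentParser
from pytorch_lightning import Trainer  # import path assumed

# build hparams from the model's own argument parser
parent_parser = HyperOptArgumentParser(strategy='grid_search')
parser = ExampleModel1.add_model_specific_args(parent_parser)
hparams = parser.parse_args()  # override --data_root on the command line for your machine

# fit the model
model = ExampleModel1(hparams)
trainer = Trainer(max_nb_epochs=1)
trainer.fit(model)
```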

View File

@ -10,7 +10,6 @@ import re
import torch
from torch.utils.data.distributed import DistributedSampler
from torch.optim.lr_scheduler import MultiStepLR
import torch.multiprocessing as mp
import torch.distributed as dist
import numpy as np
@ -71,7 +70,6 @@ class Trainer(TrainerIO):
train_percent_check=1.0, val_percent_check=1.0, test_percent_check=1.0,
val_check_interval=0.95,
log_save_interval=100, add_log_row_interval=10,
lr_scheduler_milestones=None,
distributed_backend='dp',
use_amp=False,
print_nan_grads=False,
@ -104,7 +102,6 @@ class Trainer(TrainerIO):
:param val_check_interval:
:param log_save_interval:
:param add_log_row_interval:
:param lr_scheduler_milestones:
:param distributed_backend: 'dp' to use DataParallel, 'ddp' to use DistributedDataParallel
:param use_amp:
:param print_nan_grads:
@ -141,7 +138,6 @@ class Trainer(TrainerIO):
self.early_stop_callback = early_stop_callback
self.min_nb_epochs = min_nb_epochs
self.nb_sanity_val_steps = nb_sanity_val_steps
self.lr_scheduler_milestones = [] if lr_scheduler_milestones is None else [int(x.strip()) for x in lr_scheduler_milestones.split(',')]
self.lr_schedulers = []
self.amp_level = amp_level
self.print_nan_grads = print_nan_grads
@ -443,7 +439,7 @@ class Trainer(TrainerIO):
# CHOOSE OPTIMIZER
# filter out the weights that were done on gpu so we can load on good old cpus
self.optimizers = model.configure_optimizers()
self.optimizers, self.lr_schedulers = model.configure_optimizers()
self.__run_pretrain_routine(model)
@ -455,7 +451,7 @@ class Trainer(TrainerIO):
# CHOOSE OPTIMIZER
# filter out the weights that were done on gpu so we can load on good old cpus
self.optimizers = model.configure_optimizers()
self.optimizers, self.lr_schedulers = model.configure_optimizers()
model.cuda(self.data_parallel_device_ids[0])
@ -509,7 +505,7 @@ class Trainer(TrainerIO):
# CHOOSE OPTIMIZER
# filter out the weights that were done on gpu so we can load on good old cpus
self.optimizers = model.configure_optimizers()
self.optimizers, self.lr_schedulers = model.configure_optimizers()
# MODEL
# copy model to each gpu
@ -589,12 +585,6 @@ class Trainer(TrainerIO):
# init training constants
self.__layout_bookeeping()
# add lr schedulers
if self.lr_scheduler_milestones is not None:
for optimizer in self.optimizers:
scheduler = MultiStepLR(optimizer, self.lr_scheduler_milestones)
self.lr_schedulers.append(scheduler)
# print model summary
if self.proc_rank == 0 and self.print_weights_summary:
ref_model.summarize()
@ -628,8 +618,9 @@ class Trainer(TrainerIO):
# run all epochs
for epoch_nb in range(self.current_epoch, self.max_nb_epochs):
# update the lr scheduler
for lr_scheduler in self.lr_schedulers:
lr_scheduler.step()
if self.lr_schedulers is not None:
for lr_scheduler in self.lr_schedulers:
lr_scheduler.step()
model = self.__get_model()
model.current_epoch = epoch_nb

View File

@ -58,7 +58,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
def configure_optimizers(self):
"""
Return array of optimizers
Return a list of optimizers and a list of schedulers (could be empty)
:return:
"""
raise NotImplementedError
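For reference, a minimal sketch of an override of this hook (the no-scheduler case, matching the test model updated below):
``` {.python}
def configure_optimizers(self):
    # one optimizer, no schedulers
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    return [optimizer], []
```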

View File

@ -179,8 +179,9 @@ class LightningTestModel(LightningModule):
return whatever optimizers we want here
:return: list of optimizers and list of schedulers (empty here)
"""
# try no scheduler for this model (testing purposes)
optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
return [optimizer]
return [optimizer], []
def __dataloader(self, train):
# init data generators