2020-05-05 02:16:54 +00:00
|
|
|
.. testsetup:: *
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from argparse import ArgumentParser, Namespace
|
|
|
|
from pytorch_lightning.trainer.trainer import Trainer
|
|
|
|
from pytorch_lightning.core.lightning import LightningModule
|
|
|
|
import sys
|
|
|
|
sys.argv = ['foo']
|
|
|
|
|
2020-03-03 15:52:16 +00:00
|
|
|
Hyperparameters
|
|
|
|
---------------
|
|
|
|
Lightning has utilities to interact seamlessly with the command line ArgumentParser
|
|
|
|
and plays well with the hyperparameter optimization framework of your choice.
|
|
|
|
|
2020-06-19 06:38:10 +00:00
|
|
|
----------
|
|
|
|
|
2020-04-26 13:20:06 +00:00
|
|
|
ArgumentParser
|
|
|
|
^^^^^^^^^^^^^^
|
|
|
|
Lightning is designed to augment a lot of the functionality of the built-in Python ArgumentParser.
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
.. testcode::
|
2020-04-26 13:20:06 +00:00
|
|
|
|
|
|
|
from argparse import ArgumentParser
|
|
|
|
parser = ArgumentParser()
|
|
|
|
parser.add_argument('--layer_1_dim', type=int, default=128)
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
This allows you to call your program like so:
|
|
|
|
|
|
|
|
.. code-block:: bash
|
|
|
|
|
|
|
|
python trainer.py --layer_1_dim 64
|
|
|
|
|
2020-06-19 06:38:10 +00:00
|
|
|
----------
|
2020-04-26 13:20:06 +00:00
|
|
|
|
|
|
|
Argparser Best Practices
|
2020-03-03 15:52:16 +00:00
|
|
|
^^^^^^^^^^^^^^^^^^^^^^^^
|
2020-04-26 13:20:06 +00:00
|
|
|
It is best practice to layer your arguments in three sections.
|
2020-03-03 15:52:16 +00:00
|
|
|
|
2020-04-26 14:57:26 +00:00
|
|
|
1. Trainer args (gpus, num_nodes, etc...)
|
|
|
|
2. Model specific arguments (layer_dim, num_layers, learning_rate, etc...)
|
|
|
|
3. Program arguments (data_path, cluster_email, etc...)
|
2020-04-26 13:20:06 +00:00
|
|
|
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
|
|
2020-04-26 13:20:06 +00:00
|
|
|
We can do this as follows. First, in your LightningModule, define the arguments
|
|
|
|
specific to that module. Remember that data splits or data paths may also be specific to
|
|
|
|
a module (e.g. if your project has a model that trains on ImageNet and another on CIFAR-10).
|
2020-03-03 15:52:16 +00:00
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
.. testcode::
|
2020-03-03 15:52:16 +00:00
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
class LitModel(LightningModule):
|
2020-04-26 13:20:06 +00:00
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def add_model_specific_args(parent_parser):
|
|
|
|
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
|
|
|
parser.add_argument('--encoder_layers', type=int, default=12)
|
|
|
|
parser.add_argument('--data_path', type=str, default='/some/path')
|
|
|
|
return parser
|
|
|
|
|
|
|
|
Now in your main trainer file, add the Trainer args, the program args, and add the model args
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
.. testcode::
|
2020-04-26 13:20:06 +00:00
|
|
|
|
|
|
|
# ----------------
|
|
|
|
# trainer_main.py
|
|
|
|
# ----------------
|
2020-03-03 15:52:16 +00:00
|
|
|
from argparse import ArgumentParser
|
|
|
|
parser = ArgumentParser()
|
|
|
|
|
2020-04-26 13:20:06 +00:00
|
|
|
# add PROGRAM level args
|
|
|
|
parser.add_argument('--conda_env', type=str, default='some_name')
|
|
|
|
parser.add_argument('--notification_email', type=str, default='will@email.com')
|
|
|
|
|
|
|
|
# add model specific args
|
|
|
|
parser = LitModel.add_model_specific_args(parser)
|
2020-03-06 19:43:17 +00:00
|
|
|
|
2020-04-26 13:20:06 +00:00
|
|
|
# add all the available trainer options to argparse
|
|
|
|
# ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli
|
2020-05-05 02:16:54 +00:00
|
|
|
parser = Trainer.add_argparse_args(parser)
|
2020-03-06 19:43:17 +00:00
|
|
|
|
2020-05-24 22:59:08 +00:00
|
|
|
args = parser.parse_args()
|
2020-03-03 15:52:16 +00:00
|
|
|
|
2020-04-26 13:20:06 +00:00
|
|
|
Now you can run your program like so:
|
|
|
|
|
|
|
|
.. code-block:: bash
|
|
|
|
|
|
|
|
python trainer_main.py --gpus 2 --num_nodes 2 --conda_env 'my_env' --encoder_layers 12
|
|
|
|
|
|
|
|
Finally, make sure to start the training like so:
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
.. code-block:: python
|
2020-04-26 13:20:06 +00:00
|
|
|
|
2020-05-24 22:59:08 +00:00
|
|
|
# init the trainer like this
|
|
|
|
trainer = Trainer.from_argparse_args(args, early_stopping_callback=...)
|
|
|
|
|
|
|
|
# NOT like this
|
|
|
|
trainer = Trainer(gpus=hparams.gpus, ...)
|
|
|
|
|
|
|
|
# init the model with Namespace directly
|
|
|
|
model = LitModel(args)
|
|
|
|
|
|
|
|
# or init the model with all the key-value pairs
|
|
|
|
dict_args = vars(args)
|
|
|
|
model = LitModel(**dict_args)
|
2020-04-26 13:20:06 +00:00
|
|
|
|
2020-06-19 06:38:10 +00:00
|
|
|
----------
|
|
|
|
|
2020-05-24 22:59:08 +00:00
|
|
|
LightningModule hyperparameters
|
|
|
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
2020-06-08 11:19:34 +00:00
|
|
|
Oftentimes we train many versions of a model. You might share that model or come back to it a few months later
|
|
|
|
at which point it is very useful to know how that model was trained (ie: what learning_rate, neural network, etc...).
|
2020-04-26 13:20:06 +00:00
|
|
|
|
2020-06-08 11:19:34 +00:00
|
|
|
Lightning has a few ways of saving that information for you in checkpoints and yaml files. The goal here is to
|
|
|
|
improve readability and reproducibility.
|
2020-04-26 13:20:06 +00:00
|
|
|
|
2020-06-08 11:19:34 +00:00
|
|
|
1. The first way is to ask Lightning to save the values of anything in the ``__init__`` for you to the checkpoint. This also
|
|
|
|
makes those values available via ``self.hparams``.
|
2020-05-24 22:59:08 +00:00
|
|
|
|
2020-06-08 11:19:34 +00:00
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
class LitMNIST(LightningModule):
|
|
|
|
|
|
|
|
def __init__(self, layer_1_dim=128, learning_rate=1e-2, **kwargs):
|
|
|
|
super().__init__()
|
|
|
|
# call this to save (layer_1_dim=128, learning_rate=1e-2) to the checkpoint
|
|
|
|
self.save_hyperparameters()
|
|
|
|
|
|
|
|
# equivalent
|
2020-06-25 02:28:38 +00:00
|
|
|
self.save_hyperparameters('layer_1_dim', 'learning_rate')
|
2020-06-08 11:19:34 +00:00
|
|
|
|
|
|
|
# this now works
|
|
|
|
self.hparams.layer_1_dim
|
2020-03-03 15:52:16 +00:00
|
|
|
|
2020-06-08 11:19:34 +00:00
|
|
|
|
|
|
|
2. Sometimes your init might have objects or other parameters you might not want to save.
|
|
|
|
In that case, choose only a few:
|
|
|
|
|
|
|
|
.. code-block:: python
|
2020-03-03 15:52:16 +00:00
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
class LitMNIST(LightningModule):
|
2020-04-26 13:20:06 +00:00
|
|
|
|
2020-06-08 11:19:34 +00:00
|
|
|
def __init__(self, loss_fx, generator_network, layer_1_dim=128, **kwargs):
|
2020-05-05 02:16:54 +00:00
|
|
|
super().__init__()
|
2020-05-24 22:59:08 +00:00
|
|
|
self.layer_1_dim = layer_1_dim
|
2020-06-08 11:19:34 +00:00
|
|
|
self.loss_fx = loss_fx
|
2020-03-03 15:52:16 +00:00
|
|
|
|
2020-06-08 11:19:34 +00:00
|
|
|
# call this to save (layer_1_dim=128) to the checkpoint
|
2020-06-25 02:28:38 +00:00
|
|
|
self.save_hyperparameters('layer_1_dim')
|
2020-03-03 15:52:16 +00:00
|
|
|
|
2020-06-08 11:19:34 +00:00
|
|
|
# to load specify the other args
|
|
|
|
model = LitMNIST.load_from_checkpoint(PATH, loss_fx=torch.nn.SomeOtherLoss, generator_network=MyGenerator())
|
2020-03-03 15:52:16 +00:00
|
|
|
|
|
|
|
|
2020-06-08 11:19:34 +00:00
|
|
|
3. Assign to ``self.hparams``. Anything assigned to ``self.hparams`` will also be saved automatically.
|
2020-05-24 22:59:08 +00:00
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
# using an argparse.Namespace
|
|
|
|
class LitMNIST(LightningModule):
|
|
|
|
|
|
|
|
def __init__(self, hparams, *args, **kwargs):
|
|
|
|
super().__init__()
|
|
|
|
self.hparams = hparams
|
|
|
|
|
|
|
|
self.layer_1 = torch.nn.Linear(28 * 28, self.hparams.layer_1_dim)
|
|
|
|
self.layer_2 = torch.nn.Linear(self.hparams.layer_1_dim, self.hparams.layer_2_dim)
|
|
|
|
self.layer_3 = torch.nn.Linear(self.hparams.layer_2_dim, 10)
|
|
|
|
|
|
|
|
def train_dataloader(self):
|
|
|
|
return DataLoader(mnist_train, batch_size=self.hparams.batch_size)
|
|
|
|
|
2020-06-08 11:19:34 +00:00
|
|
|
4. You can also save full objects such as ``dict`` or ``Namespace`` to the checkpoint.
|
2020-05-24 22:59:08 +00:00
|
|
|
|
|
|
|
.. code-block:: python
|
2020-03-06 19:53:27 +00:00
|
|
|
|
2020-06-08 11:19:34 +00:00
|
|
|
# using an argparse.Namespace
|
2020-05-24 22:59:08 +00:00
|
|
|
class LitMNIST(LightningModule):
|
|
|
|
|
2020-06-08 11:19:34 +00:00
|
|
|
def __init__(self, conf, *args, **kwargs):
|
2020-05-24 22:59:08 +00:00
|
|
|
super().__init__()
|
2020-06-08 11:19:34 +00:00
|
|
|
self.hparams = conf
|
2020-05-24 22:59:08 +00:00
|
|
|
|
2020-06-08 11:19:34 +00:00
|
|
|
# equivalent
|
|
|
|
self.save_hyperparameters(conf)
|
2020-05-24 22:59:08 +00:00
|
|
|
|
2020-06-08 11:19:34 +00:00
|
|
|
self.layer_1 = torch.nn.Linear(28 * 28, self.hparams.layer_1_dim)
|
|
|
|
self.layer_2 = torch.nn.Linear(self.hparams.layer_1_dim, self.hparams.layer_2_dim)
|
|
|
|
self.layer_3 = torch.nn.Linear(self.hparams.layer_2_dim, 10)
|
2020-05-24 22:59:08 +00:00
|
|
|
|
2020-06-08 11:19:34 +00:00
|
|
|
conf = OmegaConf.create(...)
|
|
|
|
model = LitMNIST(conf)
|
2020-05-24 22:59:08 +00:00
|
|
|
|
2020-06-08 11:19:34 +00:00
|
|
|
# this works
|
|
|
|
model.hparams.anything
|
2020-03-06 19:53:27 +00:00
|
|
|
|
2020-06-19 06:38:10 +00:00
|
|
|
----------
|
2020-03-06 19:53:27 +00:00
|
|
|
|
2020-03-03 15:52:16 +00:00
|
|
|
Trainer args
|
|
|
|
^^^^^^^^^^^^
|
2020-04-26 13:20:06 +00:00
|
|
|
To recap, add ALL possible trainer flags to the argparser and init the Trainer this way
|
2020-03-03 15:52:16 +00:00
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
parser = ArgumentParser()
|
|
|
|
parser = Trainer.add_argparse_args(parser)
|
2020-04-26 13:20:06 +00:00
|
|
|
hparams = parser.parse_args()
|
2020-03-03 15:52:16 +00:00
|
|
|
|
2020-04-26 13:20:06 +00:00
|
|
|
trainer = Trainer.from_argparse_args(hparams)
|
2020-03-03 15:52:16 +00:00
|
|
|
|
2020-04-26 13:20:06 +00:00
|
|
|
# or if you need to pass in callbacks
|
|
|
|
trainer = Trainer.from_argparse_args(hparams, checkpoint_callback=..., callbacks=[...])
|
2020-03-03 15:52:16 +00:00
|
|
|
|
2020-06-19 06:38:10 +00:00
|
|
|
----------
|
2020-03-03 15:52:16 +00:00
|
|
|
|
|
|
|
Multiple Lightning Modules
|
|
|
|
^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
|
|
|
|
|
|
We often have multiple Lightning Modules where each one has different arguments. Instead of
|
|
|
|
polluting the main.py file, the LightningModule lets you define arguments for each one.
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
.. testcode::
|
|
|
|
|
|
|
|
class LitMNIST(LightningModule):
|
2020-03-03 15:52:16 +00:00
|
|
|
|
2020-05-24 22:59:08 +00:00
|
|
|
def __init__(self, layer_1_dim, **kwargs):
|
2020-04-17 18:45:23 +00:00
|
|
|
super().__init__()
|
2020-05-24 22:59:08 +00:00
|
|
|
self.layer_1 = torch.nn.Linear(28 * 28, layer_1_dim)
|
2020-03-03 15:52:16 +00:00
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def add_model_specific_args(parent_parser):
|
2020-05-24 22:59:08 +00:00
|
|
|
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
2020-03-03 15:52:16 +00:00
|
|
|
parser.add_argument('--layer_1_dim', type=int, default=128)
|
|
|
|
return parser
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
.. testcode::
|
|
|
|
|
|
|
|
class GoodGAN(LightningModule):
|
|
|
|
|
2020-05-24 22:59:08 +00:00
|
|
|
def __init__(self, encoder_layers, **kwargs):
|
2020-04-17 18:45:23 +00:00
|
|
|
super().__init__()
|
2020-05-24 22:59:08 +00:00
|
|
|
self.encoder = Encoder(layers=encoder_layers)
|
2020-03-03 15:52:16 +00:00
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def add_model_specific_args(parent_parser):
|
2020-05-24 22:59:08 +00:00
|
|
|
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
2020-03-03 15:52:16 +00:00
|
|
|
parser.add_argument('--encoder_layers', type=int, default=12)
|
|
|
|
return parser
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
|
|
|
|
Now we can allow each model to inject the arguments it needs in the ``main.py``
|
2020-03-03 15:52:16 +00:00
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
def main(args):
|
2020-05-24 22:59:08 +00:00
|
|
|
dict_args = vars(args)
|
2020-03-03 15:52:16 +00:00
|
|
|
|
|
|
|
# pick model
|
|
|
|
if args.model_name == 'gan':
|
2020-05-24 22:59:08 +00:00
|
|
|
model = GoodGAN(**dict_args)
|
2020-03-03 15:52:16 +00:00
|
|
|
elif args.model_name == 'mnist':
|
2020-05-24 22:59:08 +00:00
|
|
|
model = LitMNIST(**dict_args)
|
2020-03-03 15:52:16 +00:00
|
|
|
|
2020-04-26 13:20:06 +00:00
|
|
|
trainer = Trainer.from_argparse_args(args)
|
2020-03-03 15:52:16 +00:00
|
|
|
trainer.fit(model)
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
parser = ArgumentParser()
|
|
|
|
parser = Trainer.add_argparse_args(parser)
|
|
|
|
|
|
|
|
# figure out which model to use
|
|
|
|
parser.add_argument('--model_name', type=str, default='gan', help='gan or mnist')
|
2020-04-26 13:20:06 +00:00
|
|
|
|
|
|
|
# THIS LINE IS KEY TO PULL THE MODEL NAME
|
2020-05-04 15:40:50 +00:00
|
|
|
temp_args, _ = parser.parse_known_args()
|
2020-03-03 15:52:16 +00:00
|
|
|
|
|
|
|
# let the model add what it wants
|
|
|
|
if temp_args.model_name == 'gan':
|
|
|
|
parser = GoodGAN.add_model_specific_args(parser)
|
|
|
|
elif temp_args.model_name == 'mnist':
|
2020-03-06 11:25:24 +00:00
|
|
|
parser = LitMNIST.add_model_specific_args(parser)
|
2020-03-03 15:52:16 +00:00
|
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
# train
|
|
|
|
main(args)
|
|
|
|
|
2020-05-05 02:16:54 +00:00
|
|
|
And now we can train MNIST or the GAN using the command line interface!
|
2020-03-03 15:52:16 +00:00
|
|
|
|
|
|
|
.. code-block:: bash
|
|
|
|
|
|
|
|
$ python main.py --model_name gan --encoder_layers 24
|
|
|
|
$ python main.py --model_name mnist --layer_1_dim 128
|
|
|
|
|
2020-06-19 06:38:10 +00:00
|
|
|
----------
|
|
|
|
|
2020-03-03 15:52:16 +00:00
|
|
|
Hyperparameter Optimization
|
|
|
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
|
|
Lightning is fully compatible with the hyperparameter optimization libraries!
|
|
|
|
Here are some useful ones:
|
|
|
|
|
|
|
|
- `Hydra <https://medium.com/pytorch/hydra-a-fresh-look-at-configuration-for-machine-learning-projects-50583186b710>`_
|
|
|
|
- `Optuna <https://github.com/optuna/optuna/blob/master/examples/pytorch_lightning_simple.py>`_
|
2020-07-20 17:27:16 +00:00
|
|
|
- `Ray Tune <https://docs.ray.io/en/master/tune/tutorials/tune-pytorch-lightning.html>`_
|