diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c6420aa90a..f0413d4de8 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -74,17 +74,18 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): def print(self, *args, **kwargs) -> None: r""" - Prints only from process 0. Use this in any distributed mode to log only once + Prints only from process 0. Use this in any distributed mode to log only once. Args: - x (object): The thing to print + *args: The thing to print. Will be passed to Python's built-in print function. + **kwargs: Will be passed to Python's built-in print function. + + Example: - Examples: .. code-block:: python - # example if we were using this model as a feature extractor def forward(self, x): - self.print(x, 'in loader') + self.print(x, 'in forward') """ if self.trainer.proc_rank == 0: @@ -93,15 +94,16 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): @abstractmethod def forward(self, *args, **kwargs): r""" - Same as torch.nn.Module.forward(), however in Lightning you want this to define - the operations you want to use for prediction (ie: on a server or as a feature extractor). + Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define + the operations you want to use for prediction (i.e.: on a server or as a feature extractor). - Normally you'd call self() from your training_step() method. + Normally you'd call ``self()`` from your :meth:`training_step` method. This makes it easy to write a complex system for training with the outputs you'd want in a prediction setting. Args: - x (tensor): Whatever you decide to define in the forward method + *args: Whatever you decide to pass into the forward method. + **kwargs: Keyword arguments are also possible. Return: Predicted output @@ -142,23 +144,25 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): def training_step(self, *args, **kwargs) -> Union[ int, Dict[str, Union[Tensor, Dict[str, Tensor]]] ]: - r"""return loss, dict with metrics for tqdm + r""" + Here you compute and return the training loss and some additional metrics for e.g. + the progress bar or logger. Args: - batch (torch.nn.Tensor | (Tensor, Tensor) | [Tensor, Tensor]): The output of your - dataloader. A tensor, tuple or list + batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): + The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. batch_idx (int): Integer displaying index of this batch - optimizer_idx (int): If using multiple optimizers, this argument will also be present. - hiddens(:`Tensor `_): - Passed in if truncated_bptt_steps > 0. + optimizer_idx (int): When using multiple optimizers, this argument will also be present. + hiddens(:class:`~torch.Tensor`): Passed in if + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0. Return: - dict with loss key and optional log, progress keys - if implementing training_step, return whatever you need in that step: + Dict with loss key and optional log or progress bar keys. + When implementing :meth:`training_step`, return whatever you need in that step: - - loss -> tensor scalar [REQUIRED] - - progress_bar -> Dict for progress bar display. Must have only tensors - - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc) + - loss -> tensor scalar **REQUIRED** + - progress_bar -> Dict for progress bar display. Must have only tensors + - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc) In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific. @@ -188,11 +192,11 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): return output If you define multiple optimizers, this step will be called with an additional - `optimizer_idx` param. + ``optimizer_idx`` parameter. .. code-block:: python - # Multiple optimizers (ie: GANs) + # Multiple optimizers (e.g.: GANs) def training_step(self, batch, batch_idx, optimizer_idx): if optimizer_idx == 0: # do training_step with encoder @@ -201,13 +205,13 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): If you add truncated back propagation through time you will also get an additional - argument with the hidden states of the previous step. + argument with the hidden states of the previous step. .. code-block:: python # Truncated back-propagation through time def training_step(self, batch, batch_idx, hiddens): - # hiddens are the hiddens from the previous truncated backprop step + # hiddens are the hidden states from the previous truncated backprop step ... out, hiddens = self.lstm(data, hiddens) ... @@ -218,30 +222,29 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): } You can also return a -1 instead of a dict to stop the current loop. This is useful - if you want to break out of the current training epoch early. + if you want to break out of the current training epoch early. Notes: - The presented loss value in progress bar is smooth (average) over last values, - so it differs from values set in train/validation step. + The loss value shown in the progress bar is smoothed (averaged) over the last values, + so it differs from the actual loss returned in train/validation step. """ warnings.warn('`training_step` must be implemented to be used with the Lightning Trainer') def training_end(self, *args, **kwargs): """ Warnings: - Deprecated in v0.7.0. use training_step_end instead + Deprecated in v0.7.0. Use :meth:`training_step_end` instead. """ def training_epoch_end( self, outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]] ) -> Dict[str, Dict[str, Tensor]]: - """Called at the end of training epoch with the outputs of all training_steps + """Called at the end of the training epoch with the outputs of all training steps. .. code-block:: python # the pseudocode for these calls - train_outs = [] for train_batch in train_data: out = training_step(train_batch) @@ -249,24 +252,25 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): training_epoch_end(train_outs) Args: - outputs: List of outputs you defined in training_step, or if there are multiple - dataloaders, a list containing a list of outputs for each dataloader + outputs: List of outputs you defined in :meth:`training_step`, or if there are + multiple dataloaders, a list containing a list of outputs for each dataloader. Return: - Dict or OrderedDict + Dict or OrderedDict. May contain the following optional keys: - - log (metrics to be added to the logger ; only tensors) - - any metric used in a callback (e.g. early stopping). + - log (metrics to be added to the logger; only tensors) + - any metric used in a callback (e.g. early stopping). - .. note:: If this method is not overridden, this won't be called. + Note: + If this method is not overridden, this won't be called. - The outputs here are strictly for logging or progress bar. - If you don't need to display anything, don't return anything. - - If you want to manually set current step, you can specify the 'step' key in the 'log' Dict + - If you want to manually set current step, you can specify the 'step' key in the 'log' dict. Examples: - With a single dataloader + With a single dataloader: .. code-block:: python @@ -283,7 +287,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): } return results - With multiple dataloaders, `outputs` will be a list of lists. The outer list contains + With multiple dataloaders, ``outputs`` will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of each training step for that dataloader. @@ -310,11 +314,12 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): str, Union[Tensor, Dict[str, Tensor]] ]: """ - Use this when training with dp or ddp2 because training_step will operate - on only part of the batch. However, this is still optional + Use this when training with dp or ddp2 because :meth:`training_step` + will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss. - .. note:: If you later switch to ddp or some other mode, this will still be called + Note: + If you later switch to ddp or some other mode, this will still be called so that you don't have to change your code .. code-block:: python @@ -328,12 +333,11 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): batch_parts_outputs: What you return in `training_step` for each batch part. Return: - dictionary with loss key and optional log, progress keys: - - loss -> tensor scalar [REQUIRED] - - progress_bar -> Dict for progress bar display. Must have only tensors - - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc) + Dict with loss key and optional log or progress bar keys. - In this case you should define training_step_end to perform those calculations. + - loss -> tensor scalar **REQUIRED** + - progress_bar -> Dict for progress bar display. Must have only tensors + - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc) Examples: .. code-block:: python @@ -366,13 +370,13 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): loss = nce_loss(loss) return {'loss': loss} - .. seealso:: - see the :ref:`multi-gpu-training` guide for more details. + See Also: + See the :ref:`multi-gpu-training` guide for more details. """ def validation_step(self, *args, **kwargs) -> Dict[str, Tensor]: r""" - Operate on a single batch of data from the validation set + Operates on a single batch of data from the validation set. In this step you'd might generate examples or calculate anything of interest like accuracy. .. code-block:: python @@ -385,15 +389,15 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): validation_epoch_end(val_outs) Args: - batch (torch.nn.Tensor | (Tensor, Tensor) | [Tensor, Tensor]): The output of your - dataloader. A tensor, tuple or list + batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): + The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. batch_idx (int): The index of this batch dataloader_idx (int): The index of the dataloader that produced this batch (only if multiple val datasets used) Return: - Dict or OrderedDict - passed to validation_epoch_end. - If you defined validation_step_end it will go to that first. + Dict or OrderedDict - passed to :meth:`validation_epoch_end`. + If you defined :meth:`validation_step_end` it will go to that first. .. code-block:: python @@ -434,7 +438,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) # all optional... - # return whatever you need for the collation function validation_end + # return whatever you need for the collation function validation_epoch_end output = OrderedDict({ 'val_loss': loss_val, 'val_acc': torch.tensor(val_acc), # everything must be a tensor @@ -451,21 +455,24 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): def validation_step(self, batch, batch_idx, dataset_idx): # dataset_idx tells you which dataset this is. - .. note:: If you don't need to validate you don't need to implement this method. + Note: + If you don't need to validate you don't need to implement this method. - .. note:: When the validation_step is called, the model has been put in eval mode + Note: + When the :meth:`validation_step` is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled. """ def validation_step_end(self, *args, **kwargs) -> Dict[str, Tensor]: """ - Use this when validating with dp or ddp2 because validation_step will operate - on only part of the batch. However, this is still optional + Use this when validating with dp or ddp2 because :meth:`validation_step` + will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss. - .. note:: If you later switch to ddp or some other mode, this will still be called - so that you don't have to change your code + Note: + If you later switch to ddp or some other mode, this will still be called + so that you don't have to change your code. .. code-block:: python @@ -475,12 +482,11 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): validation_step_end(batch_parts_outputs) Args: - batch_parts_outputs: What you return in `validation_step` for each batch part. + batch_parts_outputs: What you return in :meth:`validation_step` + for each batch part. Return: - Dict or OrderedDict - passed to the validation_epoch_end - - In this case you should define validation_step_end to perform those calculations. + Dict or OrderedDict - passed to the :meth:`validation_epoch_end` method. Examples: .. code-block:: python @@ -513,14 +519,15 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): loss = nce_loss(loss) return {'loss': loss} - .. seealso:: - see the :ref:`multi-gpu-training` guide for more details. + See Also: + See the :ref:`multi-gpu-training` guide for more details. """ def validation_end(self, outputs): """ Warnings: - Deprecated in v0.7.0. use validation_epoch_end instead. Will be removed 1.0.0 + Deprecated in v0.7.0. Use :meth:`validation_epoch_end` instead. + Will be removed in 1.0.0. """ def validation_epoch_end( @@ -528,12 +535,11 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]] ) -> Dict[str, Dict[str, Tensor]]: """ - Called at end of validation epoch with the outputs of all validation_steps + Called at the end of the validation epoch with the outputs of all validation steps. .. code-block:: python # the pseudocode for these calls - val_outs = [] for val_batch in val_data: out = validation_step(train_batch) @@ -541,24 +547,25 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): validation_epoch_end(val_outs) Args: - outputs: List of outputs you defined in validation_step, or if there are multiple - dataloaders, a list containing a list of outputs for each dataloader + outputs: List of outputs you defined in :meth:`validation_step`, or if there + are multiple dataloaders, a list containing a list of outputs for each dataloader. Return: - Dict or OrderedDict + Dict or OrderedDict. May have the following optional keys: - - progress_bar (dict for progress bar display ; only tensors) - - log (dict of metrics to add to logger ; only tensors). + - progress_bar (dict for progress bar display; only tensors) + - log (dict of metrics to add to logger; only tensors). - .. note:: If you didn't define a validation_step, this won't be called. + Note: + If you didn't define a :meth:`validation_step`, this won't be called. - The outputs here are strictly for logging or progress bar. - If you don't need to display anything, don't return anything. - - If you want to manually set current step, you can specify the 'step' key in the 'log' Dict + - If you want to manually set current step, you can specify the 'step' key in the 'log' dict. Examples: - With a single dataloader + With a single dataloader: .. code-block:: python @@ -604,29 +611,29 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): def test_step(self, *args, **kwargs) -> Dict[str, Tensor]: r""" - Operate on a single batch of data from the test set + Operates on a single batch of data from the test set. In this step you'd normally generate examples or calculate anything of interest such as accuracy. .. code-block:: python # the pseudocode for these calls - test_outs = [] for test_batch in test_data: - out = test_step(train_batch) + out = test_step(test_batch) test_outs.append(out) test_epoch_end(test_outs) Args: - batch (torch.nn.Tensor | (Tensor, Tensor) | [Tensor, Tensor]): The output of your - dataloader. A tensor, tuple or list - batch_idx (int): The index of this batch + batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): + The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. + batch_idx (int): The index of this batch. dataloader_idx (int): The index of the dataloader that produced this batch - (only if multiple test datasets used) + (only if multiple test datasets used). Return: - Dict or OrderedDict - passed to the test_step_end + Dict or OrderedDict - passed to the :meth:`test_epoch_end` method. + If you defined :meth:`test_step_end` it will go to that first. .. code-block:: python @@ -658,7 +665,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) # all optional... - # return whatever you need for the collation function validation_end + # return whatever you need for the collation function test_epoch_end output = OrderedDict({ 'val_loss': loss_val, 'val_acc': torch.tensor(val_acc), # everything must be a tensor @@ -667,30 +674,33 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): # return an optional dict return output - If you pass in multiple validation datasets, validation_step will have an additional + If you pass in multiple validation datasets, :meth:`test_step` will have an additional argument. - .. code-block:: python + .. code-block:: python - # CASE 2: multiple validation datasets - def test_step(self, batch, batch_idx, dataset_idx): - # dataset_idx tells you which dataset this is. + # CASE 2: multiple test datasets + def test_step(self, batch, batch_idx, dataset_idx): + # dataset_idx tells you which dataset this is. - .. note:: If you don't need to validate you don't need to implement this method. + Note: + If you don't need to validate you don't need to implement this method. - .. note:: When the test_step is called, the model has been put in eval mode and + Note: + When the :meth:`test_step` is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled. """ def test_step_end(self, *args, **kwargs) -> Dict[str, Tensor]: """ - Use this when testing with dp or ddp2 because test_step will operate + Use this when testing with dp or ddp2 because :meth:`test_step` will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss. - .. note:: If you later switch to ddp or some other mode, this will still be called - so that you don't have to change your code + Note: + If you later switch to ddp or some other mode, this will still be called + so that you don't have to change your code. .. code-block:: python @@ -700,12 +710,10 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): test_step_end(batch_parts_outputs) Args: - batch_parts_outputs: What you return in `training_step` for each batch part. + batch_parts_outputs: What you return in :meth:`test_step` for each batch part. Return: - Dict or OrderedDict - passed to the test_epoch_end - - In this case you should define test_step_end to perform those calculations. + Dict or OrderedDict - passed to the :meth:`test_epoch_end`. Examples: .. code-block:: python @@ -738,14 +746,15 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): loss = nce_loss(loss) return {'loss': loss} - .. seealso:: - see the :ref:`multi-gpu-training` guide for more details. + See Also: + See the :ref:`multi-gpu-training` guide for more details. """ def test_end(self, outputs): """ Warnings: - Deprecated in v0.7.0. use test_epoch_end instead. Will be removed 1.0.0 + Deprecated in v0.7.0. Use :meth:`test_epoch_end` instead. + Will be removed in 1.0.0. """ def test_epoch_end( @@ -753,12 +762,11 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]] ) -> Dict[str, Dict[str, Tensor]]: """ - Called at end of test epoch with the output of all test_steps. + Called at the end of a test epoch with the output of all test steps. .. code-block:: python # the pseudocode for these calls - test_outs = [] for test_batch in test_data: out = test_step(test_batch) @@ -766,22 +774,24 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): test_epoch_end(test_outs) Args: - outputs: List of outputs you defined in test_step, or if there are multiple - dataloaders, a list containing a list of outputs for each dataloader + outputs: List of outputs you defined in :meth:`test_step_end`, or if there + are multiple dataloaders, a list containing a list of outputs for each dataloader Return: - Dict or OrderedDict (dict): Dict has the following optional keys: - progress_bar -> Dict for progress bar display. Must have only tensors - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc) + Dict or OrderedDict: Dict has the following optional keys: - .. note:: If you didn't define a test_step, this won't be called. + - progress_bar -> Dict for progress bar display. Must have only tensors. + - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc). + + Note: + If you didn't define a :meth:`test_step`, this won't be called. - The outputs here are strictly for logging or progress bar. - If you don't need to display anything, don't return anything. - If you want to manually set current step, specify it with the 'step' key in the 'log' Dict Examples: - With a single dataloader + With a single dataloader: .. code-block:: python @@ -834,13 +844,13 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): Override to init DDP in your own way or with your own wrapper. The only requirements are that: - 1. On a validation batch the call goes to model.validation_step. - 2. On a training batch the call goes to model.training_step. - 3. On a testing batch, the call goes to model.test_step + 1. On a validation batch the call goes to ``model.validation_step``. + 2. On a training batch the call goes to ``model.training_step``. + 3. On a testing batch, the call goes to ``model.test_step``.+ Args: - model: the LightningModule currently being optimized - device_ids: the list of GPU ids + model: the :class:`LightningModule` currently being optimized. + device_ids: the list of GPU ids. Return: DDP wrapped model @@ -868,14 +878,13 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): def init_ddp_connection(self, proc_rank: int, world_size: int) -> None: r""" - Override to define your custom way of setting up a distributed environment. - Lightning's implementation uses env:// init by default and sets the first node as root. + Lightning's implementation uses ``env://`` init by default and sets the first node as root. Args: proc_rank: The current process rank within the node. - world_size: Number of GPUs being use across all nodes. (num_nodes * num_gpus). + world_size: Number of GPUs being use across all nodes (num_nodes * num_gpus). Examples: .. code-block:: python @@ -952,13 +961,13 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): amp_level: str ) -> Tuple['LightningModule', List[Optimizer]]: r""" - Override to init AMP your own way - Must return a model and list of optimizers + Override to init AMP your own way. + Must return a model and list of optimizers. Args: - amp: pointer to amp library object - model: pointer to current lightningModule - optimizers: list of optimizers passed in configure_optimizers() + amp: pointer to amp library object. + model: pointer to current :class:`LightningModule`. + optimizers: list of optimizers passed in :meth:`configure_optimizers`. amp_level: AMP mode chosen ('O1', 'O2', etc...) Return: @@ -988,18 +997,20 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. - Return: any of these 6 options: + Return: + Any of these 6 options. + - Single optimizer. - List or Tuple - List of optimizers. - Two lists - The first list has multiple optimizers, the second a list of LR schedulers. - - Dictionary, with an `optimizer` key and (optionally) a `lr_scheduler` key. - - Tuple of dictionaries as described, with an optional `frequency` key. + - Dictionary, with an 'optimizer' key and (optionally) a 'lr_scheduler' key. + - Tuple of dictionaries as described, with an optional 'frequency' key. - None - Fit will run without any optimizer. Note: - The `frequency` value is an int corresponding to the number of sequential batches + The 'frequency' value is an int corresponding to the number of sequential batches optimized with the specific optimizer. It should be given to none or to all of the optimizers. - There is difference between passing multiple optimizers in a list, + There is a difference between passing multiple optimizers in a list, and passing multiple optimizers in dictionaries with a frequency of 1: In the former case, all optimizers will operate on the given batch in each optimization step. In the latter, only one optimizer will operate on the given batch at every step. @@ -1012,20 +1023,20 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): opt = Adam(self.parameters(), lr=1e-3) return opt - # multiple optimizer case (eg: GAN) + # multiple optimizer case (e.g.: GAN) def configure_optimizers(self): generator_opt = Adam(self.model_gen.parameters(), lr=0.01) disriminator_opt = Adam(self.model_disc.parameters(), lr=0.02) return generator_opt, disriminator_opt - # example with learning_rate schedulers + # example with learning rate schedulers def configure_optimizers(self): generator_opt = Adam(self.model_gen.parameters(), lr=0.01) disriminator_opt = Adam(self.model_disc.parameters(), lr=0.02) discriminator_sched = CosineAnnealing(discriminator_opt, T_max=10) return [generator_opt, disriminator_opt], [discriminator_sched] - # example with step-based learning_rate schedulers + # example with step-based learning rate schedulers def configure_optimizers(self): gen_opt = Adam(self.model_gen.parameters(), lr=0.01) dis_opt = Adam(self.model_disc.parameters(), lr=0.02) @@ -1056,18 +1067,18 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): - If you use 16-bit precision (``precision=16``), Lightning will automatically handle the optimizers for you. - - If you use multiple optimizers, training_step will have an additional + - If you use multiple optimizers, :meth:`training_step` will have an additional ``optimizer_idx`` parameter. - - If you use LBFGS lightning handles the closure function automatically for you + - If you use LBFGS Lightning handles the closure function automatically for you. - If you use multiple optimizers, gradients will be calculated only for the parameters of current optimizer at each training step. - If you need to control how often those optimizers step or override the - default .step() schedule, override the `optimizer_step` hook. + default ``.step()`` schedule, override the :meth:`optimizer_step` hook. - - If you only want to call a learning rate scheduler every `x` step or epoch, + - If you only want to call a learning rate scheduler every ``x`` step or epoch, or want to monitor a custom metric, you can specify these in a dictionary: .. code-block:: python @@ -1091,16 +1102,16 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): second_order_closure: Optional[Callable] = None, ) -> None: r""" - - Override this method to adjust the default way the Trainer calls each optimizer. - By default, Lightning calls .step() and zero_grad() as shown in the example + Override this method to adjust the default way the + :class:`~pytorch_lightning.trainer.trainer.Trainer` calls each optimizer. + By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example once per optimizer. Args: epoch: Current epoch batch_idx: Index of current batch optimizer: A PyTorch optimizer - optimizer_idx: If you used multiple optimizers this indexes into that list + optimizer_idx: If you used multiple optimizers this indexes into that list. second_order_closure: closure for second order methods Examples: @@ -1112,7 +1123,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): optimizer.step() optimizer.zero_grad() - # Alternating schedule for optimizer steps (ie: GANs) + # Alternating schedule for optimizer steps (i.e.: GANs) def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None): # update generator opt every 2 steps @@ -1132,7 +1143,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): Here's another example showing how to use this for more advanced things such as - learning-rate warm-up: + learning rate warm-up: .. code-block:: python @@ -1162,17 +1173,16 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list: r""" - When using truncated backpropagation through time, each batch must be split along the - time dimension. Lightning handles this by default, but for custom behavior override + time dimension. Lightning handles this by default, but for custom behavior override this function. Args: batch: Current batch - split_size: How big the split is + split_size: The size of the split Return: - list of batch splits. Each split will be passed to forward_step to enable truncated + List of batch splits. Each split will be passed to :meth:`training_step` to enable truncated back propagation through time. The default implementation splits root level Tensors and Sequences at dim=1 (i.e. time dim). It assumes that each time dim is the same length. @@ -1197,8 +1207,11 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): return splits - .. note:: Called in the training loop after on_batch_start if ``truncated_bptt_steps > 0``. - Each returned batch split is passed separately to ``training_step(...)``. + Note: + Called in the training loop after + :meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start` + if :paramref:`~pytorch_lightning.trainer.Trainer.truncated_bptt_steps` > 0. + Each returned batch split is passed separately to :meth:`training_step`. """ time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, collections.Sequence))] @@ -1223,13 +1236,10 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): return splits def prepare_data(self) -> None: - """Use this to download and prepare data. - In distributed (GPU, TPU), this will only be called once - - Return: - PyTorch DataLoader - - This is called before requesting the dataloaders + """ + Use this to download and prepare data. + In distributed (GPU, TPU), this will only be called once. + This is called before requesting the dataloaders: .. code-block:: python @@ -1248,23 +1258,25 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): """ def train_dataloader(self) -> DataLoader: - """Implement a PyTorch DataLoader + """ + Implement a PyTorch DataLoader for training. Return: - PyTorch DataLoader + Single PyTorch :class:`~torch.utils.data.DataLoader`. - Return a dataloader. It will not be called every epoch unless you set - ```Trainer(reload_dataloaders_every_epoch=True)```. + The dataloader you return will not be called every epoch unless you set + :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``. - It's recommended that all data downloads and preparation happen in prepare_data(). + It's recommended that all data downloads and preparation happen in :meth:`prepare_data`. - .. note:: Lightning adds the correct sampler for distributed and arbitrary hardware. - No need to set yourself. + - :meth:`~pytorch_lightning.trainer.Trainer.fit` + - ... + - :meth:`prepare_data` + - :meth:`train_dataloader` - - .fit() - - ... - - prepare_data() - - train_dataloader + Note: + Lightning adds the correct sampler for distributed and arbitrary hardware. + There is no need to set it yourself. Example: .. code-block:: python @@ -1285,10 +1297,9 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): warnings.warn('`train_dataloader` must be implemented to be used with the Lightning Trainer') def tng_dataloader(self): # todo: remove in v1.0.0 - """Implement a PyTorch DataLoader. - + """ Warnings: - Deprecated in v0.5.0. use train_dataloader instead. Will be removed 1.0.0 + Deprecated in v0.5.0. Use :meth:`train_dataloader` instead. Will be removed in 1.0.0. """ output = self.train_dataloader() warnings.warn("`tng_dataloader` has been renamed to `train_dataloader` since v0.5.0." @@ -1297,24 +1308,26 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]: r""" + Implement one or multiple PyTorch DataLoaders for testing. - Return a dataloader. It will not be called every epoch unless you set - ```Trainer(reload_dataloaders_every_epoch=True)```. + The dataloader you return will not be called every epoch unless you set + :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``. - It's recommended that all data downloads and preparation happen in prepare_data(). + It's recommended that all data downloads and preparation happen in :meth:`prepare_data`. - - .fit() - - ... - - prepare_data() - - train_dataloader - - val_dataloader - - test_dataloader + - :meth:`~pytorch_lightning.trainer.Trainer.fit` + - ... + - :meth:`prepare_data` + - :meth:`train_dataloader` + - :meth:`val_dataloader` + - :meth:`test_dataloader` - .. note:: Lightning adds the correct sampler for distributed and arbitrary hardware. - No need to set yourself. + Note: + Lightning adds the correct sampler for distributed and arbitrary hardware. + There is no need to set it yourself. Return: - Single or multiple PyTorch DataLoader + Single or multiple PyTorch DataLoaders. Example: .. code-block:: python @@ -1332,30 +1345,34 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): return loader - .. note:: If you don't need a test dataset and a test_step, you don't need to implement + Note: + If you don't need a test dataset and a :meth:`test_step`, you don't need to implement this method. """ def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]: r""" + Implement one or multiple PyTorch DataLoaders for validation. - Return a dataloader. It will not be called every epoch unless you set - ```Trainer(reload_dataloaders_every_epoch=True)```. + The dataloader you return will not be called every epoch unless you set + :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``. - It's recommended that all data downloads and preparation happen in prepare_data(). + It's recommended that all data downloads and preparation happen in :meth:`prepare_data`. - - .fit() - - ... - - prepare_data() - - train_dataloader - - val_dataloader + - :meth:`~pytorch_lightning.trainer.Trainer.fit` + - ... + - :meth:`prepare_data` + - :meth:`train_dataloader` + - :meth:`val_dataloader` + - :meth:`test_dataloader` - .. note:: Lightning adds the correct sampler for distributed and arbitrary hardware - No need to set yourself. + Note: + Lightning adds the correct sampler for distributed and arbitrary hardware + There is no need to set it yourself. Return: - Single or multiple PyTorch DataLoader + Single or multiple PyTorch DataLoaders. Examples: .. code-block:: python @@ -1377,38 +1394,21 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): def val_dataloader(self): return [loader_a, loader_b, ..., loader_n] - .. code-block:: python - - def val_dataloader(self): - transform = transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.5,), (1.0,))]) - dataset = MNIST(root='/path/to/mnist/', train=False, - transform=transform, download=True) - loader = torch.utils.data.DataLoader( - dataset=dataset, - batch_size=self.hparams.batch_size, - shuffle=True - ) - - return loader - - # can also return multiple dataloaders - def val_dataloader(self): - return [loader_a, loader_b, ..., loader_n] - - .. note:: If you don't need a validation dataset and a validation_step, you don't need to + Note: + If you don't need a validation dataset and a :meth:`validation_step`, you don't need to implement this method. - .. note:: In the case where you return multiple `val_dataloaders`, the `validation_step` - will have an argument `dataset_idx` which matches the order here. + Note: + In the case where you return multiple validation dataloaders, the :meth:`validation_step` + will have an argument ``dataset_idx`` which matches the order here. """ @classmethod def load_from_metrics(cls, weights_path, tags_csv, map_location=None): r""" Warning: - Deprecated in version 0.7.0. You should use `load_from_checkpoint` instead. - Will be removed in v0.9.0. + Deprecated in version 0.7.0. You should use :meth:`load_from_checkpoint` instead. + Will be removed in v0.9.0. """ warnings.warn( "`load_from_metrics` method has been unified with `load_from_checkpoint` in v0.7.0." @@ -1425,11 +1425,10 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): *args, **kwargs ) -> 'LightningModule': r""" - - Primary way of loading model from a checkpoint. When Lightning saves a checkpoint - it stores the hyperparameters in the checkpoint if you initialized your LightningModule - with an argument called `hparams` which is a Namespace (output of using argparse - to parse command line arguments). + Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint + it stores the hyperparameters in the checkpoint if you initialized your :class:`LightningModule` + with an argument called ``hparams`` which is a :class:`~argparse.Namespace` + (output of :meth:`~argparse.ArgumentParser.parse_args` when parsing command line arguments). Example: .. code-block:: python @@ -1449,8 +1448,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): map_location: If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. - The behaviour is the same as in - `torch.load `_. + The behaviour is the same as in :func:`torch.load`. tags_csv: Optional path to a .csv file with two columns (key, value) as in this example:: @@ -1462,11 +1460,11 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): to the checkpoint. However, if your checkpoint weights don't have the hyperparameters saved, use this method to pass in a .csv file with the hparams you'd like to use. - These will be converted into a argparse.Namespace and passed into your - LightningModule for use. + These will be converted into a :class:`~argparse.Namespace` and passed into your + :class:`LightningModule` for use. Return: - LightningModule with loaded weights and hyperparameters (if available). + :class:`LightningModule` with loaded weights and hyperparameters (if available). Example: .. code-block:: python @@ -1558,7 +1556,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): def freeze(self) -> None: r""" - Freeze all params for inference + Freeze all params for inference. Example: .. code-block:: python @@ -1573,7 +1571,8 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): self.eval() def unfreeze(self) -> None: - """Unfreeze all params for training. + """ + Unfreeze all parameters for training. .. code-block:: python @@ -1588,8 +1587,8 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: r""" - Called by lightning to restore your model. - If you saved something with **on_save_checkpoint** this is your chance to restore this. + Called by Lightning to restore your model. + If you saved something with :meth:`on_save_checkpoint` this is your chance to restore this. Args: checkpoint: Loaded checkpoint @@ -1602,15 +1601,15 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): # 99% of the time you don't need to implement this method self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save'] - .. note:: Lighting auto-restores global step, epoch, and train state including amp scaling. - No need for you to restore anything regarding training. + Note: + Lightning auto-restores global step, epoch, and train state including amp scaling. + There is no need for you to restore anything regarding training. """ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: r""" - - Called by lightning when saving a checkpoint to give you a chance to store anything - else you might want to save + Called by Lightning when saving a checkpoint to give you a chance to store anything + else you might want to save. Args: checkpoint: Checkpoint to be saved @@ -1618,13 +1617,15 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): Example: .. code-block:: python + def on_save_checkpoint(self, checkpoint): # 99% of use cases you don't need to implement this method checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object - .. note:: Lighting saves all aspects of training (epoch, global step, etc...) - including amp scaling. No need - for you to store anything about training. + Note: + Lightning saves all aspects of training (epoch, global step, etc...) + including amp scaling. + There is no need for you to store anything about training. """