From 3070a9ea6edb9648152218a79620f035d9bf719c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Thu, 9 Sep 2021 09:45:52 +0200
Subject: [PATCH] Fix hiddens type annotation (#9377)

---
 pytorch_lightning/accelerators/accelerator.py | 37 ++----------------
 pytorch_lightning/core/lightning.py           | 39 ++++++++-----------
 pytorch_lightning/loops/utilities.py          |  3 +-
 3 files changed, 22 insertions(+), 57 deletions(-)

diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index f40dc9e157..93915ac946 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -173,15 +173,7 @@ class Accelerator:
     def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
         """The actual training step.
 
-        Args:
-            step_kwargs: the arguments for the models training step. Can consist of the following:
-
-                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                - batch_idx (int): Integer displaying index of this batch
-                - optimizer_idx (int): When using multiple optimizers, this argument will also be present.
-                - hiddens(:class:`~torch.Tensor`): Passed in if
-                  :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
+        See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details
         """
         with self.precision_plugin.train_step_context():
             return self.training_type_plugin.training_step(*step_kwargs.values())
@@ -192,14 +184,7 @@ class Accelerator:
     def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:
         """The actual validation step.
 
-        Args:
-            step_kwargs: the arguments for the models validation step. Can consist of the following:
-
-                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                - batch_idx (int): The index of this batch
-                - dataloader_idx (int): The index of the dataloader that produced this batch
-                  (only if multiple val dataloaders used)
+        See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details
         """
         with self.precision_plugin.val_step_context():
             return self.training_type_plugin.validation_step(*step_kwargs.values())
@@ -207,14 +192,7 @@ class Accelerator:
     def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:
         """The actual test step.
 
-        Args:
-            step_kwargs: the arguments for the models test step. Can consist of the following:
-
-                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                - batch_idx (int): The index of this batch.
-                - dataloader_idx (int): The index of the dataloader that produced this batch
-                  (only if multiple test dataloaders used).
+        See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details
         """
         with self.precision_plugin.test_step_context():
             return self.training_type_plugin.test_step(*step_kwargs.values())
@@ -222,14 +200,7 @@ class Accelerator:
     def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
         """The actual predict step.
 
-        Args:
-            step_kwargs: the arguments for the models predict step. Can consist of the following:
-
-                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                - batch_idx (int): The index of this batch.
-                - dataloader_idx (int): The index of the dataloader that produced this batch
-                  (only if multiple predict dataloaders used).
+        See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details
         """
         with self.precision_plugin.predict_step_context():
             return self.training_type_plugin.predict_step(*step_kwargs.values())
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index e3c7402242..1b805eefdb 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -620,9 +620,9 @@ class LightningModule(
         Args:
             batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
                 The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-            batch_idx (int): Integer displaying index of this batch
-            optimizer_idx (int): When using multiple optimizers, this argument will also be present.
-            hiddens(:class:`~torch.Tensor`): Passed in if
+            batch_idx (``int``): Integer displaying index of this batch
+            optimizer_idx (``int``): When using multiple optimizers, this argument will also be present.
+            hiddens (``Any``): Passed in if
                 :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
 
         Return:
@@ -667,9 +667,8 @@ class LightningModule(
             # Truncated back-propagation through time
             def training_step(self, batch, batch_idx, hiddens):
                 # hiddens are the hidden states from the previous truncated backprop step
-                ...
                 out, hiddens = self.lstm(data, hiddens)
-                ...
+                loss = ...
                 return {"loss": loss, "hiddens": hiddens}
 
         Note:
@@ -1585,7 +1584,7 @@ class LightningModule(
         """
         optimizer.zero_grad()
 
-    def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list:
+    def tbptt_split_batch(self, batch: Any, split_size: int) -> List[Any]:
         r"""
         When using truncated backpropagation through time, each batch must be split along the
         time dimension. Lightning handles this by default, but for custom behavior override
@@ -1603,29 +1602,25 @@ class LightningModule(
         Examples::
 
             def tbptt_split_batch(self, batch, split_size):
-              splits = []
-              for t in range(0, time_dims[0], split_size):
-                  batch_split = []
-                  for i, x in enumerate(batch):
-                      if isinstance(x, torch.Tensor):
-                          split_x = x[:, t:t + split_size]
-                      elif isinstance(x, collections.Sequence):
-                          split_x = [None] * len(x)
-                          for batch_idx in range(len(x)):
+                splits = []
+                for t in range(0, time_dims[0], split_size):
+                    batch_split = []
+                    for i, x in enumerate(batch):
+                        if isinstance(x, torch.Tensor):
+                            split_x = x[:, t:t + split_size]
+                        elif isinstance(x, collections.Sequence):
+                            split_x = [None] * len(x)
+                            for batch_idx in range(len(x)):
                                 split_x[batch_idx] = x[batch_idx][t:t + split_size]
-
-                  batch_split.append(split_x)
-
-              splits.append(batch_split)
-
-              return splits
+                        batch_split.append(split_x)
+                    splits.append(batch_split)
+                return splits
 
         Note:
             Called in the training loop after
             :meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start`
             if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
             Each returned batch split is passed separately to :meth:`training_step`.
-
         """
         time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, collections.Sequence))]
         assert len(time_dims) >= 1, "Unable to determine batch time dimension"
diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py
index 154680535e..f74f973d6b 100644
--- a/pytorch_lightning/loops/utilities.py
+++ b/pytorch_lightning/loops/utilities.py
@@ -16,7 +16,6 @@ from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterator, Mapping, Optional, Sequence, Tuple
 
 import torch
-from torch import Tensor
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
@@ -118,7 +117,7 @@ def _build_training_step_kwargs(
     batch: Any,
     batch_idx: int,
     opt_idx: Optional[int],
-    hiddens: Optional[Tensor],
+    hiddens: Optional[Any],
 ) -> Dict[str, Any]:
     """Builds the keyword arguments for training_step.
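
For reference, a minimal sketch (not part of the patch) of the pattern this change accommodates: with truncated back-propagation through time, ``training_step`` receives a ``hiddens`` argument that need not be a single ``Tensor``. An LSTM, for example, carries an ``(h, c)`` tuple between splits, which is why the annotation is loosened from ``Optional[Tensor]`` to ``Optional[Any]``. The module below is hypothetical; the layer sizes and loss are made up for illustration.

    import torch
    from torch import nn
    import pytorch_lightning as pl


    class TBPTTExample(pl.LightningModule):  # hypothetical module, not from the Lightning codebase
        def __init__(self):
            super().__init__()
            self.lstm = nn.LSTM(input_size=32, hidden_size=64, batch_first=True)
            self.head = nn.Linear(64, 1)
            # Ask Lightning to split each batch along the time dimension into chunks of 10 steps.
            self.truncated_bptt_steps = 10

        def training_step(self, batch, batch_idx, hiddens):
            # ``hiddens`` is None for the first split of a batch; afterwards it is the
            # ``(h, c)`` tuple returned below, i.e. a structure rather than a plain Tensor.
            x, y = batch
            out, hiddens = self.lstm(x, hiddens)
            loss = nn.functional.mse_loss(self.head(out), y)
            # Returning "hiddens" in the step output carries the state to the next split,
            # as described in the training_step docstring touched by this patch.
            return {"loss": loss, "hiddens": hiddens}

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)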