Fix hiddens type annotation (#9377)

Carlos Mocholí 2021-09-09 09:45:52 +02:00 committed by GitHub
parent 41ba639859
commit 3070a9ea6e
3 changed files with 22 additions and 57 deletions

View File

@@ -173,15 +173,7 @@ class Accelerator:
     def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
         """The actual training step.
 
-        Args:
-            step_kwargs: the arguments for the models training step. Can consist of the following:
-                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                - batch_idx (int): Integer displaying index of this batch
-                - optimizer_idx (int): When using multiple optimizers, this argument will also be present.
-                - hiddens(:class:`~torch.Tensor`): Passed in if
-                  :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
         See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details
         """
         with self.precision_plugin.train_step_context():
             return self.training_type_plugin.training_step(*step_kwargs.values())
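
The splat in `*step_kwargs.values()` works because Python dicts preserve insertion order, so the values line up with the positional parameters of the wrapped step. A minimal, self-contained sketch of the pattern (`build_step_kwargs` is a hypothetical stand-in for `_build_training_step_kwargs`, which appears in the last file of this diff):

    from typing import Any, Dict

    def build_step_kwargs(batch: Any, batch_idx: int) -> Dict[str, Any]:
        # Insertion order matters: the values are later splatted
        # positionally into the step signature.
        return {"batch": batch, "batch_idx": batch_idx}

    def training_step(batch, batch_idx):
        return {"sum": sum(batch), "batch_idx": batch_idx}

    step_kwargs = build_step_kwargs([1.0, 2.0, 3.0], batch_idx=0)
    print(training_step(*step_kwargs.values()))  # {'sum': 6.0, 'batch_idx': 0}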
@@ -192,14 +184,7 @@ class Accelerator:
     def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:
         """The actual validation step.
 
-        Args:
-            step_kwargs: the arguments for the models validation step. Can consist of the following:
-                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                - batch_idx (int): The index of this batch
-                - dataloader_idx (int): The index of the dataloader that produced this batch
-                  (only if multiple val dataloaders used)
         See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details
         """
         with self.precision_plugin.val_step_context():
             return self.training_type_plugin.validation_step(*step_kwargs.values())
@@ -207,14 +192,7 @@ class Accelerator:
     def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:
         """The actual test step.
 
-        Args:
-            step_kwargs: the arguments for the models test step. Can consist of the following:
-                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                - batch_idx (int): The index of this batch.
-                - dataloader_idx (int): The index of the dataloader that produced this batch
-                  (only if multiple test dataloaders used).
         See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details
         """
         with self.precision_plugin.test_step_context():
             return self.training_type_plugin.test_step(*step_kwargs.values())
@@ -222,14 +200,7 @@ class Accelerator:
     def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
         """The actual predict step.
 
-        Args:
-            step_kwargs: the arguments for the models predict step. Can consist of the following:
-                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                - batch_idx (int): The index of this batch.
-                - dataloader_idx (int): The index of the dataloader that produced this batch
-                  (only if multiple predict dataloaders used).
         See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details
         """
         with self.precision_plugin.predict_step_context():
             return self.training_type_plugin.predict_step(*step_kwargs.values())
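
Each of the four methods above runs the wrapped step inside a precision-plugin context (`train_step_context`, `val_step_context`, and so on). The following is a hypothetical sketch of what such a context can look like under an autocast-based mixed-precision assumption; it is not the actual `PrecisionPlugin` implementation:

    from contextlib import contextmanager
    from typing import Iterator

    import torch

    @contextmanager
    def train_step_context(use_amp: bool) -> Iterator[None]:
        # Hypothetical: enable autocast around the step so the forward
        # pass runs in mixed precision when AMP is requested.
        if use_amp and torch.cuda.is_available():
            with torch.cuda.amp.autocast():
                yield
        else:
            yield

    with train_step_context(use_amp=False):
        loss = torch.randn(2, 2).sum()  # the wrapped step body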

View File

@@ -620,9 +620,9 @@ class LightningModule(
         Args:
             batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
                 The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-            batch_idx (int): Integer displaying index of this batch
-            optimizer_idx (int): When using multiple optimizers, this argument will also be present.
-            hiddens(:class:`~torch.Tensor`): Passed in if
+            batch_idx (``int``): Integer displaying index of this batch
+            optimizer_idx (``int``): When using multiple optimizers, this argument will also be present.
+            hiddens (``Any``): Passed in if
                 :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
 
         Return:
@@ -667,9 +667,8 @@ class LightningModule(
             # Truncated back-propagation through time
             def training_step(self, batch, batch_idx, hiddens):
                 # hiddens are the hidden states from the previous truncated backprop step
                 ...
-                out, hiddens = self.lstm(data, hiddens)
-                ...
+                loss = ...
                 return {"loss": loss, "hiddens": hiddens}
 
         Note:
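
The example above is the motivation for loosening the annotation: `nn.LSTM` returns its state as an `(h, c)` tuple rather than a single `Tensor`, so `hiddens` has to be typed as `Any`. A runnable sketch under that assumption, using a plain `nn.Module` stand-in rather than the real `LightningModule` API:

    import torch
    from torch import nn

    class TBPTTExample(nn.Module):
        def __init__(self):
            super().__init__()
            self.lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
            self.head = nn.Linear(16, 1)

        def training_step(self, batch, batch_idx, hiddens=None):
            x, y = batch
            out, hiddens = self.lstm(x, hiddens)
            loss = nn.functional.mse_loss(self.head(out), y)
            # Detach so gradients stop at the truncation boundary.
            hiddens = tuple(h.detach() for h in hiddens)
            return {"loss": loss, "hiddens": hiddens}

    model = TBPTTExample()
    batch = (torch.randn(4, 5, 8), torch.randn(4, 5, 1))
    out = model.training_step(batch, batch_idx=0)
    print(type(out["hiddens"]))  # <class 'tuple'>, not a single Tensor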
@@ -1585,7 +1584,7 @@ class LightningModule(
"""
optimizer.zero_grad()
def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list:
def tbptt_split_batch(self, batch: Any, split_size: int) -> List[Any]:
r"""
When using truncated backpropagation through time, each batch must be split along the
time dimension. Lightning handles this by default, but for custom behavior override
@@ -1603,29 +1602,25 @@ class LightningModule(
         Examples::
 
             def tbptt_split_batch(self, batch, split_size):
-              splits = []
-              for t in range(0, time_dims[0], split_size):
-                  batch_split = []
-                  for i, x in enumerate(batch):
-                      if isinstance(x, torch.Tensor):
-                          split_x = x[:, t:t + split_size]
-                      elif isinstance(x, collections.Sequence):
-                          split_x = [None] * len(x)
-                          for batch_idx in range(len(x)):
+                splits = []
+                for t in range(0, time_dims[0], split_size):
+                    batch_split = []
+                    for i, x in enumerate(batch):
+                        if isinstance(x, torch.Tensor):
+                            split_x = x[:, t:t + split_size]
+                        elif isinstance(x, collections.Sequence):
+                            split_x = [None] * len(x)
+                            for batch_idx in range(len(x)):
                                 split_x[batch_idx] = x[batch_idx][t:t + split_size]
-                      batch_split.append(split_x)
-                  splits.append(batch_split)
-              return splits
+                        batch_split.append(split_x)
+                    splits.append(batch_split)
+                return splits
 
         Note:
             Called in the training loop after
             :meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start`
             if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
             Each returned batch split is passed separately to :meth:`training_step`.
         """
         time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, collections.Sequence))]
         assert len(time_dims) >= 1, "Unable to determine batch time dimension"
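
For a batch whose tensors are laid out as `(batch, time, features)`, the default behavior described above amounts to slicing every element along dimension 1. A standalone sketch of that slicing (illustrative only, not the actual implementation):

    import torch

    def split_along_time(batch, split_size):
        # Slice every tensor in the batch along dim 1 (the time
        # dimension) into consecutive chunks of split_size steps.
        time_dim = batch[0].size(1)
        return [
            [x[:, t:t + split_size] for x in batch]
            for t in range(0, time_dim, split_size)
        ]

    x = torch.randn(4, 10, 8)  # (batch, time, features)
    y = torch.randn(4, 10, 1)
    splits = split_along_time([x, y], split_size=5)
    print(len(splits), splits[0][0].shape)  # 2 torch.Size([4, 5, 8])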

View File

@@ -16,7 +16,6 @@ from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterator, Mapping, Optional, Sequence, Tuple
 
 import torch
-from torch import Tensor
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
@@ -118,7 +117,7 @@ def _build_training_step_kwargs(
     batch: Any,
     batch_idx: int,
     opt_idx: Optional[int],
-    hiddens: Optional[Tensor],
+    hiddens: Optional[Any],
 ) -> Dict[str, Any]:
     """Builds the keyword arguments for training_step.