Fix hiddens type annotation (#9377)
This commit is contained in:
parent 41ba639859
commit 3070a9ea6e
@@ -173,15 +173,7 @@ class Accelerator:
     def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
         """The actual training step.
 
-        Args:
-            step_kwargs: the arguments for the models training step. Can consist of the following:
-
-                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                - batch_idx (int): Integer displaying index of this batch
-                - optimizer_idx (int): When using multiple optimizers, this argument will also be present.
-                - hiddens(:class:`~torch.Tensor`): Passed in if
-                  :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
+        See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details
         """
         with self.precision_plugin.train_step_context():
             return self.training_type_plugin.training_step(*step_kwargs.values())
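The dispatch here passes the step arguments as one ordered dict and unpacks them positionally with ``*step_kwargs.values()``, which relies on dict insertion order (guaranteed since Python 3.7). A minimal sketch of that pattern, with illustrative names that are not Lightning's actual internals:

    from typing import Any, Dict


    def training_step(batch: Any, batch_idx: int, hiddens: Any = None) -> Any:
        # Stand-in for a LightningModule.training_step with a TBPTT signature.
        return {"batch_idx": batch_idx, "hiddens": hiddens}


    # The loop builds the dict in parameter order, so positional unpacking works.
    step_kwargs: Dict[str, Any] = {"batch": [1, 2, 3], "batch_idx": 0, "hiddens": None}
    out = training_step(*step_kwargs.values())
    assert out == {"batch_idx": 0, "hiddens": None}

The same unpacking idiom is used by the validation, test, and predict steps below.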
@@ -192,14 +184,7 @@ class Accelerator:
     def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:
         """The actual validation step.
 
-        Args:
-            step_kwargs: the arguments for the models validation step. Can consist of the following:
-
-                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                - batch_idx (int): The index of this batch
-                - dataloader_idx (int): The index of the dataloader that produced this batch
-                  (only if multiple val dataloaders used)
+        See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details
         """
         with self.precision_plugin.val_step_context():
             return self.training_type_plugin.validation_step(*step_kwargs.values())
@@ -207,14 +192,7 @@ class Accelerator:
     def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:
         """The actual test step.
 
-        Args:
-            step_kwargs: the arguments for the models test step. Can consist of the following:
-
-                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                - batch_idx (int): The index of this batch.
-                - dataloader_idx (int): The index of the dataloader that produced this batch
-                  (only if multiple test dataloaders used).
+        See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details
         """
         with self.precision_plugin.test_step_context():
             return self.training_type_plugin.test_step(*step_kwargs.values())
@@ -222,14 +200,7 @@ class Accelerator:
     def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
         """The actual predict step.
 
-        Args:
-            step_kwargs: the arguments for the models predict step. Can consist of the following:
-
-                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                - batch_idx (int): The index of this batch.
-                - dataloader_idx (int): The index of the dataloader that produced this batch
-                  (only if multiple predict dataloaders used).
+        See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details
         """
         with self.precision_plugin.predict_step_context():
             return self.training_type_plugin.predict_step(*step_kwargs.values())
@@ -620,9 +620,9 @@ class LightningModule(
         Args:
             batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
                 The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-            batch_idx (int): Integer displaying index of this batch
-            optimizer_idx (int): When using multiple optimizers, this argument will also be present.
-            hiddens(:class:`~torch.Tensor`): Passed in if
+            batch_idx (``int``): Integer displaying index of this batch
+            optimizer_idx (``int``): When using multiple optimizers, this argument will also be present.
+            hiddens (``Any``): Passed in if
                 :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
 
         Return:
@@ -667,9 +667,8 @@ class LightningModule(
             # Truncated back-propagation through time
             def training_step(self, batch, batch_idx, hiddens):
                 # hiddens are the hidden states from the previous truncated backprop step
-                ...
                 out, hiddens = self.lstm(data, hiddens)
-                ...
+                loss = ...
                 return {"loss": loss, "hiddens": hiddens}
 
         Note:
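The widened ``hiddens (Any)`` annotation matters because recurrent state is often not a single tensor: an LSTM carries a ``(h, c)`` tuple across truncated-backprop splits. A hedged, runnable sketch of the pattern the docstring example describes (the module, loss, and shapes here are illustrative, not Lightning code):

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)


    def training_step(batch, batch_idx, hiddens=None):
        # hiddens is the (h, c) tuple from the previous truncated-backprop
        # split, or None on the first split -- a tuple, not a single Tensor.
        out, hiddens = lstm(batch, hiddens)
        loss = out.pow(2).mean()  # placeholder loss
        # Detach so gradients do not flow across split boundaries.
        hiddens = tuple(h.detach() for h in hiddens)
        return {"loss": loss, "hiddens": hiddens}


    split = torch.randn(4, 5, 8)  # (batch, time, features)
    first = training_step(split, 0)
    second = training_step(split, 1, hiddens=first["hiddens"])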
@@ -1585,7 +1584,7 @@ class LightningModule(
         """
         optimizer.zero_grad()
 
-    def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list:
+    def tbptt_split_batch(self, batch: Any, split_size: int) -> List[Any]:
         r"""
         When using truncated backpropagation through time, each batch must be split along the
         time dimension. Lightning handles this by default, but for custom behavior override
@@ -1603,29 +1602,25 @@ class LightningModule(
         Examples::
 
             def tbptt_split_batch(self, batch, split_size):
-              splits = []
-              for t in range(0, time_dims[0], split_size):
-                batch_split = []
-                for i, x in enumerate(batch):
-                    if isinstance(x, torch.Tensor):
-                        split_x = x[:, t:t + split_size]
-                    elif isinstance(x, collections.Sequence):
-                        split_x = [None] * len(x)
-                        for batch_idx in range(len(x)):
+                splits = []
+                for t in range(0, time_dims[0], split_size):
+                    batch_split = []
+                    for i, x in enumerate(batch):
+                        if isinstance(x, torch.Tensor):
+                            split_x = x[:, t:t + split_size]
+                        elif isinstance(x, collections.Sequence):
+                            split_x = [None] * len(x)
+                            for batch_idx in range(len(x)):
                                 split_x[batch_idx] = x[batch_idx][t:t + split_size]
-
-                    batch_split.append(split_x)
-
-                splits.append(batch_split)
-
-              return splits
+                        batch_split.append(split_x)
+                    splits.append(batch_split)
+                return splits
 
         Note:
             Called in the training loop after
             :meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start`
             if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
             Each returned batch split is passed separately to :meth:`training_step`.
 
         """
         time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, collections.Sequence))]
         assert len(time_dims) >= 1, "Unable to determine batch time dimension"
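The splitting logic in that docstring can be exercised standalone. A self-contained version of the same algorithm, with a toy batch; note the docstring's ``collections.Sequence`` alias is deprecated (removed in Python 3.10), so this sketch uses ``collections.abc.Sequence``:

    from collections.abc import Sequence

    import torch


    def tbptt_split_batch(batch, split_size):
        # Infer the time dimension from the first element of each batch item.
        time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, Sequence))]
        assert len(time_dims) >= 1, "Unable to determine batch time dimension"
        splits = []
        for t in range(0, time_dims[0], split_size):
            batch_split = []
            for x in batch:
                if isinstance(x, torch.Tensor):
                    split_x = x[:, t:t + split_size]  # slice the time dimension
                elif isinstance(x, Sequence):
                    split_x = [item[t:t + split_size] for item in x]
                batch_split.append(split_x)
            splits.append(batch_split)
        return splits


    batch = [torch.randn(2, 10, 4)]  # (batch, time, features)
    splits = tbptt_split_batch(batch, split_size=5)
    assert len(splits) == 2 and splits[0][0].shape == (2, 5, 4)

Each returned split is then fed to ``training_step`` as a separate truncated-backprop step.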
@@ -16,7 +16,6 @@ from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterator, Mapping, Optional, Sequence, Tuple
 
 import torch
-from torch import Tensor
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
@@ -118,7 +117,7 @@ def _build_training_step_kwargs(
     batch: Any,
     batch_idx: int,
     opt_idx: Optional[int],
-    hiddens: Optional[Tensor],
+    hiddens: Optional[Any],
 ) -> Dict[str, Any]:
     """Builds the keyword arguments for training_step.
 
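For context, a kwargs builder of this shape only adds the optional arguments when the run actually uses them, so the dict's insertion order always matches the ``training_step`` signature that will receive it. A simplified sketch, not the real ``_build_training_step_kwargs`` implementation:

    from typing import Any, Dict, Optional


    def build_training_step_kwargs(
        batch: Any,
        batch_idx: int,
        opt_idx: Optional[int],
        hiddens: Optional[Any],
    ) -> Dict[str, Any]:
        step_kwargs: Dict[str, Any] = {"batch": batch, "batch_idx": batch_idx}
        if opt_idx is not None:
            # Present only when multiple optimizers are configured.
            step_kwargs["optimizer_idx"] = opt_idx
        if hiddens is not None:
            # Present only when truncated_bptt_steps > 0; after this commit,
            # any object (not just a Tensor) is a valid hiddens value.
            step_kwargs["hiddens"] = hiddens
        return step_kwargs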