lightning/pytorch_lightning/lite/wrappers.py

141 lines
5.3 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Generator, Iterator, Optional, Union
import torch
from torch import nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.plugins import PrecisionPlugin
from pytorch_lightning.strategies import Strategy
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
def _do_nothing_closure() -> None:
    """No-op closure substituted by :meth:`_LiteOptimizer.step` when no closure is passed; returns ``None``."""
class _LiteOptimizer:
    def __init__(self, optimizer: Optimizer, strategy: Strategy) -> None:
        """Thin wrapper around a :class:`~torch.optim.Optimizer` that routes ``step`` and ``state_dict`` through
        the strategy instead of calling the optimizer directly.

        The original optimizer object stays reachable through the :attr:`optimizer` property.

        Args:
            optimizer: The optimizer to wrap
            strategy: Reference to the strategy for handling the optimizer step
        """
        # Copy the optimizer's instance attributes onto the wrapper, leaving out entries that would shadow the
        # methods defined on this class (``state_dict``, ``step``) should the optimizer carry them as instance
        # attributes. ``__del__`` is left out too, in case the optimizer has implemented custom destructor logic
        # which we would not want to call on destruction of the `_LiteOptimizer`.
        excluded = ("state_dict", "step", "__del__")
        self.__dict__ = {name: attr for name, attr in optimizer.__dict__.items() if name not in excluded}
        # Swap in a dynamically created subclass of both this wrapper and the optimizer's own class, so the wrapper
        # still passes ``isinstance`` checks against the original optimizer type.
        wrapped_cls = optimizer.__class__
        self.__class__ = type("Lite" + wrapped_cls.__name__, (self.__class__, wrapped_cls), {})
        # Set after the ``__dict__`` replacement above, which would otherwise discard them.
        self._optimizer = optimizer
        self._strategy = strategy

    @property
    def optimizer(self) -> Optimizer:
        """The wrapped optimizer instance."""
        return self._optimizer

    def state_dict(self) -> Dict[str, Tensor]:
        """Return the optimizer state as produced by the strategy's ``optimizer_state`` for the wrapped optimizer."""
        return self._strategy.optimizer_state(self.optimizer)

    def step(self, closure: Optional[Callable] = None) -> Any:
        """Perform an optimizer step through the strategy.

        Args:
            closure: Optional callable forwarded to the strategy. A missing (or falsy) closure is replaced by a
                no-op closure.

        Returns:
            Whatever the strategy's ``optimizer_step`` returns. The step is issued with ``opt_idx=0`` and the
            strategy's current ``model``.
        """
        if not closure:
            closure = _do_nothing_closure
        return self._strategy.optimizer_step(self.optimizer, opt_idx=0, closure=closure, model=self._strategy.model)
class _LiteModule(DeviceDtypeModuleMixin):
    # Maps the precision plugin's ``precision`` value to the dtype that floating point inputs are cast to before
    # the forward pass. Kept at class level so it is not rebuilt on every call.
    _PRECISION_TO_TYPE: Dict[Union[str, int], torch.dtype] = {
        "bf16": torch.bfloat16,
        16: torch.float16,
        32: torch.float32,
        64: torch.float64,
    }

    def __init__(self, module: nn.Module, precision_plugin: PrecisionPlugin) -> None:
        """The LiteModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast
        automatically for the forward pass.

        The underlying wrapped module can be accessed via the property :attr:`module`.

        Args:
            module: The module to wrap
            precision_plugin: Reference to the precision plugin for handling precision context
        """
        super().__init__()
        self._module = module
        self._precision_plugin = precision_plugin

    @property
    def module(self) -> nn.Module:
        """The wrapped :class:`torch.nn.Module`."""
        return self._module

    @staticmethod
    def _cast_floating_tensors(collection: Any, target_dtype: torch.dtype) -> Any:
        """Return ``collection`` with its floating point tensors cast to ``target_dtype``.

        Only ``Tensor`` elements (as selected by ``apply_to_collection(..., dtype=Tensor)``) reach the converter.
        Tensors that are not floating point, e.g. integer indices or bool masks, are returned as-is.
        """

        def _convert(t: Tensor) -> Tensor:
            return t.to(target_dtype) if torch.is_floating_point(t) else t

        return apply_to_collection(collection, function=_convert, dtype=Tensor)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Casts all inputs to the right precision and handles autocast for operations in the module forward
        method.

        Floating point tensors in ``args``/``kwargs`` are cast to the dtype matching the precision plugin's
        ``precision``. The wrapped module runs inside the plugin's ``forward_context()``. Floating point tensors
        in the output are then cast to ``torch.get_default_dtype()``.

        Raises:
            KeyError: If the plugin's ``precision`` is not one of ``"bf16"``, ``16``, ``32`` or ``64``.
        """
        # TODO (@awaelchli): let the precision plugin handle the conversion
        input_dtype = self._PRECISION_TO_TYPE[self._precision_plugin.precision]
        args, kwargs = self._cast_floating_tensors([args, kwargs], input_dtype)

        with self._precision_plugin.forward_context():
            output = self.module(*args, **kwargs)

        # The target dtype is passed explicitly rather than by rebinding a variable captured by a closure, so the
        # input and output conversions cannot accidentally interfere with each other.
        return self._cast_floating_tensors(output, torch.get_default_dtype())
class _LiteDataLoader:
    def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None:
        """Wrap a :class:`~torch.utils.data.DataLoader` and, when ``device`` is given, move every batch onto it.

        Args:
            dataloader: The dataloader to wrap
            device: Target device for the batches. With the default ``None`` no data transfers are made and
                iteration behaves exactly like iterating ``dataloader`` itself.
        """
        # Mirror the loader's instance attributes so they can be read directly off the wrapper.
        self.__dict__.update(dataloader.__dict__)
        self._dataloader = dataloader
        self._device = device

    @property
    def device(self) -> Optional[torch.device]:
        """The device batches are moved to, or ``None`` when batches are yielded unchanged."""
        return self._device

    def __len__(self) -> int:
        # Length is whatever the wrapped loader reports.
        return len(self._dataloader)

    def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
        """Yield the wrapped loader's batches, moved to :attr:`device` when one is set."""
        batches = iter(self._dataloader)
        if self._device is not None:
            for batch in batches:
                yield move_data_to_device(batch, self._device)
        else:
            # No target device: pass the batches straight through.
            yield from batches