diff --git a/src/lightning_lite/plugins/precision/utils.py b/src/lightning_lite/plugins/precision/utils.py index f9af7de5ba..f607755e5d 100644 --- a/src/lightning_lite/plugins/precision/utils.py +++ b/src/lightning_lite/plugins/precision/utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union import torch @@ -18,10 +19,12 @@ from lightning_lite.utilities.enums import PrecisionType def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor: - if torch.is_floating_point(tensor): - if precision == PrecisionType.HALF: - return tensor.half() - if precision == PrecisionType.BFLOAT: - return tensor.bfloat16() - + if precision == PrecisionType.HALF: + return _convert_fp_tensor(tensor, torch.half) + if precision == PrecisionType.BFLOAT: + return _convert_fp_tensor(tensor, torch.bfloat16) return tensor + + +def _convert_fp_tensor(tensor: torch.Tensor, dst_type: Union[str, torch.dtype]) -> torch.Tensor: + return tensor.to(dst_type) if torch.is_floating_point(tensor) else tensor diff --git a/src/lightning_lite/utilities/device_dtype_mixin.py b/src/lightning_lite/utilities/device_dtype_mixin.py index b889288ea5..1f5164f0cd 100644 --- a/src/lightning_lite/utilities/device_dtype_mixin.py +++ b/src/lightning_lite/utilities/device_dtype_mixin.py @@ -47,67 +47,10 @@ class _DeviceDtypeModuleMixin(Module): return device def to(self, *args: Any, **kwargs: Any) -> Self: # type: ignore[valid-type] - """Moves and/or casts the parameters and buffers. - - This can be called as - .. function:: to(device=None, dtype=None, non_blocking=False) - .. function:: to(dtype, non_blocking=False) - .. function:: to(tensor, non_blocking=False) - Its signature is similar to :meth:`torch.Tensor.to`, but only accepts - floating point desired :attr:`dtype` s. In addition, this method will - only cast the floating point parameters and buffers to :attr:`dtype` - (if given). The integral parameters and buffers will be moved - :attr:`device`, if that is given, but with dtypes unchanged. When - :attr:`non_blocking` is set, it tries to convert/move asynchronously - with respect to the host if possible, e.g., moving CPU Tensors with - pinned memory to CUDA devices. - See below for examples. - - Note: - This method modifies the module in-place. - - Args: - device: the desired device of the parameters - and buffers in this module - dtype: the desired floating point type of - the floating point parameters and buffers in this module - tensor: Tensor whose dtype and device are the desired - dtype and device for all parameters and buffers in this module - - Returns: - Module: self - - Example:: - >>> from torch import Tensor - >>> class ExampleModule(_DeviceDtypeModuleMixin): - ... def __init__(self, weight: Tensor): - ... super().__init__() - ... self.register_buffer('weight', weight) - >>> _ = torch.manual_seed(0) - >>> module = ExampleModule(torch.rand(3, 4)) - >>> module.weight #doctest: +ELLIPSIS - tensor([[...]]) - >>> module.to(torch.double) - ExampleModule() - >>> module.weight #doctest: +ELLIPSIS - tensor([[...]], dtype=torch.float64) - >>> cpu = torch.device('cpu') - >>> module.to(cpu, dtype=torch.half, non_blocking=True) - ExampleModule() - >>> module.weight #doctest: +ELLIPSIS - tensor([[...]], dtype=torch.float16) - >>> module.to(cpu) - ExampleModule() - >>> module.weight #doctest: +ELLIPSIS - tensor([[...]], dtype=torch.float16) - >>> module.device - device(type='cpu') - >>> module.dtype - torch.float16 - """ - # there is diff nb vars in PT 1.5 - out = torch._C._nn._parse_to(*args, **kwargs) - self.__update_properties(device=out[0], dtype=out[1]) + """See :meth:`torch.nn.Module.to`.""" + # this converts `str` device to `torch.device` + device, dtype = torch._C._nn._parse_to(*args, **kwargs)[:2] + self.__update_properties(device=device, dtype=dtype) return super().to(*args, **kwargs) def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: # type: ignore[valid-type] @@ -130,50 +73,27 @@ class _DeviceDtypeModuleMixin(Module): return super().cuda(device=device) def cpu(self) -> Self: # type: ignore[valid-type] - """Moves all model parameters and buffers to the CPU. - - Returns: - Module: self - """ + """See :meth:`torch.nn.Module.cpu`.""" self.__update_properties(device=torch.device("cpu")) return super().cpu() def type(self, dst_type: Union[str, torch.dtype]) -> Self: # type: ignore[valid-type] - """Casts all parameters and buffers to :attr:`dst_type`. - - Arguments: - dst_type (type or string): the desired type - - Returns: - Module: self - """ + """See :meth:`torch.nn.Module.type`.""" self.__update_properties(dtype=dst_type) return super().type(dst_type=dst_type) def float(self) -> Self: # type: ignore[valid-type] - """Casts all floating point parameters and buffers to ``float`` datatype. - - Returns: - Module: self - """ + """See :meth:`torch.nn.Module.float`.""" self.__update_properties(dtype=torch.float) return super().float() def double(self) -> Self: # type: ignore[valid-type] - """Casts all floating point parameters and buffers to ``double`` datatype. - - Returns: - Module: self - """ + """See :meth:`torch.nn.Module.double`.""" self.__update_properties(dtype=torch.double) return super().double() def half(self) -> Self: # type: ignore[valid-type] - """Casts all floating point parameters and buffers to ``half`` datatype. - - Returns: - Module: self - """ + """See :meth:`torch.nn.Module.half`.""" self.__update_properties(dtype=torch.half) return super().half() diff --git a/src/lightning_lite/wrappers.py b/src/lightning_lite/wrappers.py index 651d80810d..a976edb09a 100644 --- a/src/lightning_lite/wrappers.py +++ b/src/lightning_lite/wrappers.py @@ -22,6 +22,7 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader from lightning_lite.plugins import Precision +from lightning_lite.plugins.precision.utils import _convert_fp_tensor from lightning_lite.strategies import Strategy from lightning_lite.utilities import move_data_to_device from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin @@ -104,18 +105,16 @@ class _LiteModule(_DeviceDtypeModuleMixin): 64: torch.float64, } # TODO: let the precision plugin handle the conversion - to_type = precision_to_type[precision] - - def _convert_float_tensor(t: Tensor) -> Tensor: - return t.to(to_type) if torch.is_floating_point(t) else t - - args, kwargs = apply_to_collection([args, kwargs], function=_convert_float_tensor, dtype=Tensor) + args, kwargs = apply_to_collection( + [args, kwargs], dtype=Tensor, function=_convert_fp_tensor, dst_type=precision_to_type[precision] + ) with self._precision_plugin.forward_context(): output = self._forward_module(*args, **kwargs) - to_type = torch.get_default_dtype() - output = apply_to_collection(output, function=_convert_float_tensor, dtype=Tensor) + output = apply_to_collection( + output, dtype=Tensor, function=_convert_fp_tensor, dst_type=torch.get_default_dtype() + ) return output @overload diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py index 54a8fd64cf..4bf5694a2c 100644 --- a/src/pytorch_lightning/core/module.py +++ b/src/pytorch_lightning/core/module.py @@ -602,7 +602,7 @@ class LightningModule( def forward(self, *args: Any, **kwargs: Any) -> Any: r""" - Same as :meth:`torch.nn.Module.forward()`. + Same as :meth:`torch.nn.Module.forward`. Args: *args: Whatever you decide to pass into the forward method. diff --git a/src/pytorch_lightning/plugins/precision/double.py b/src/pytorch_lightning/plugins/precision/double.py index cff3b2619f..43c0610dfb 100644 --- a/src/pytorch_lightning/plugins/precision/double.py +++ b/src/pytorch_lightning/plugins/precision/double.py @@ -21,6 +21,7 @@ from torch import FloatTensor, Tensor from torch.optim import Optimizer import pytorch_lightning as pl +from lightning_lite.plugins.precision.utils import _convert_fp_tensor from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin @@ -33,15 +34,9 @@ class LightningDoublePrecisionModule(_LightningPrecisionModuleWrapperBase): pl_module: the model to wrap """ - @staticmethod - def _to_double_precision(data: Tensor) -> Tensor: - if data.is_floating_point(): - return data.double() - return data - @staticmethod def _move_float_tensors_to_double(collection: Any) -> Any: - return apply_to_collection(collection, Tensor, LightningDoublePrecisionModule._to_double_precision) + return apply_to_collection(collection, Tensor, function=_convert_fp_tensor, dst_type=torch.double) def training_step(self, *args: Any, **kwargs: Any) -> Any: return self.module.training_step( diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index 2596a6fa19..6e91cc3d54 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -85,17 +85,9 @@ class LightningDeepSpeedModule(_LightningModuleWrapperBase): self.precision = precision def forward(self, *inputs: Any, **kwargs: Any) -> Any: - inputs = apply_to_collection(inputs, Tensor, function=self._batch_to) + inputs = apply_to_collection(inputs, Tensor, function=_fp_to_half, precision=self.precision) return super().forward(*inputs, **kwargs) - def _batch_to(self, batch: Tensor) -> Tensor: - if torch.is_floating_point(batch): - if self.precision == PrecisionType.HALF: - return batch.half() - elif self.precision == PrecisionType.BFLOAT: - return batch.bfloat16() - return batch - class DeepSpeedStrategy(DDPStrategy): strategy_name = "deepspeed" diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index a363d143e5..b95cb4de57 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -17,14 +17,13 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import torch from lightning_utilities.core.apply_func import apply_to_collection -from torch import FloatTensor, Tensor +from torch import Tensor from torch.utils.data import DataLoader, Sampler import pytorch_lightning as pl from lightning_lite.plugins import CheckpointIO, ClusterEnvironment from lightning_lite.plugins.precision.utils import _fp_to_half from lightning_lite.utilities.cloud_io import get_filesystem -from lightning_lite.utilities.enums import PrecisionType from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.parallel import ParallelStrategy @@ -61,21 +60,9 @@ class LightningIPUModule(_LightningModuleWrapperBase): self.precision = precision def forward(self, *inputs: Any, **kwargs: Any) -> Any: - if self.precision == PrecisionType.HALF: - inputs = self._move_float_tensors_to_half(inputs) - + inputs = apply_to_collection(inputs, Tensor, function=_fp_to_half, precision=self.precision) return super().forward(*inputs, **kwargs) - @staticmethod - def batch_to(data: Tensor) -> Tensor: - if torch.is_floating_point(data): - return data.half() - return data - - def _move_float_tensors_to_half(self, batch: Any) -> Any: - batch = apply_to_collection(batch, (FloatTensor, torch.cuda.FloatTensor), function=self.batch_to) - return batch - class IPUStrategy(ParallelStrategy): """Plugin for training on IPU devices.""" diff --git a/tests/tests_lite/utilities/test_device_dtype_mixin.py b/tests/tests_lite/utilities/test_device_dtype_mixin.py index 35caf08148..daa81a26a5 100644 --- a/tests/tests_lite/utilities/test_device_dtype_mixin.py +++ b/tests/tests_lite/utilities/test_device_dtype_mixin.py @@ -23,7 +23,7 @@ class TopModule(_DeviceDtypeModuleMixin): @pytest.mark.parametrize( - "dst_device_str,dst_dtype", + "dst_device_str,dst_type", [ ("cpu", torch.half), ("cpu", torch.float), @@ -35,21 +35,19 @@ class TopModule(_DeviceDtypeModuleMixin): ], ) @RunIf(min_cuda_gpus=1) -def test_submodules_device_and_dtype(dst_device_str, dst_dtype): +def test_submodules_device_and_dtype(dst_device_str, dst_type): """Test that the device and dtype property updates propagate through mixed nesting of regular nn.Modules and the special modules of type DeviceDtypeModuleMixin (e.g. Metric or LightningModule).""" - dst_device = torch.device(dst_device_str) - model = TopModule() assert model.device == torch.device("cpu") - model = model.to(device=dst_device, dtype=dst_dtype) + model = model.to(device=dst_device, dtype=dst_type) # nn.Module does not have these attributes assert not hasattr(model.module, "_device") assert not hasattr(model.module, "_dtype") # device and dtype change should propagate down into all children assert model.device == model.module.module.device == dst_device - assert model.dtype == model.module.module.dtype == dst_dtype + assert model.dtype == model.module.module.dtype == dst_type @pytest.mark.parametrize( @@ -72,6 +70,16 @@ def test_cuda_device(device): assert device.index == torch.cuda.current_device() +@RunIf(min_cuda_gpus=1) +def test_cpu_device(): + model = SubSubModule().cuda() + assert model.device.type == "cuda" + assert model.device.index == 0 + model.cpu() + assert model.device.type == "cpu" + assert model.device.index is None + + @RunIf(min_cuda_gpus=2) def test_cuda_current_device(): """Test that calling .cuda() moves the model to the correct device and respects current cuda device setting.""" @@ -92,3 +100,46 @@ def test_cuda_current_device(): model.cuda() # model is already on device 1, and calling .cuda() without device index should not move model assert model.device == torch.device("cuda", 1) assert model.layer.weight.device == torch.device("cuda", 1) + + +class ExampleModule(_DeviceDtypeModuleMixin): + def __init__(self, weight): + super().__init__() + self.register_buffer("weight", weight) + + +def test_to_combinations(): + module = ExampleModule(torch.rand(3, 4)) + # sanity check + assert module.weight.shape == (3, 4) + assert module.weight.dtype is torch.float32 + # positional dtype + module.to(torch.double) + assert module.weight.dtype is torch.float64 + # positional device + module.to("cpu", dtype=torch.half, non_blocking=True) + assert module.weight.dtype is torch.float16 + assert module.device == torch.device("cpu") + assert module.dtype is torch.float16 + + +def test_dtype_conversions(): + module = ExampleModule(torch.tensor(1)) + # different dtypes + assert module.weight.dtype is torch.int64 + assert module.dtype is torch.float32 + # `.double()` skips non floating points + module.double() + assert module.weight.dtype is torch.int64 + assert module.dtype is torch.float64 + # but `type` doesn't + module.type(torch.float) + assert module.weight.dtype is torch.float32 + assert module.dtype is torch.float32 + # now, test the rest + module.float() + assert module.weight.dtype is torch.float32 + assert module.dtype is torch.float32 + module.half() + assert module.weight.dtype is torch.float16 + assert module.dtype is torch.float16 diff --git a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py index 2e42803d17..038058b446 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py +++ b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py @@ -83,38 +83,20 @@ class ModelParallelBoringModelManualOptim(BoringModel): return False -def test_deepspeed_lightning_module(): - """Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves types and device correctly.""" - - model = BoringModel() - with pytest.deprecated_call(match="`LightningDeepSpeedModule` has been deprecated in v1.7.1"): - module = LightningDeepSpeedModule(model, precision=16) - - module.half() - assert module.dtype == torch.half - assert model.dtype == torch.half - - module.to(torch.double) - assert module.dtype == torch.double - assert model.dtype == torch.double - - @RunIf(min_cuda_gpus=1) def test_deepspeed_lightning_module_precision(): """Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves tensors to half when precision 16.""" - model = BoringModel() with pytest.deprecated_call(match="`LightningDeepSpeedModule` has been deprecated in v1.7.1"): module = LightningDeepSpeedModule(model, precision=16) - module.cuda().half() + module.to(device="cuda", dtype=torch.half) assert module.dtype == torch.half assert model.dtype == torch.half - x = torch.randn((1, 32), dtype=torch.float).cuda() + x = torch.randn((1, 32), device="cuda", dtype=torch.float) out = module(x) - assert out.dtype == torch.half module.to(torch.double)