diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index 4a1f4e8004..f148632d63 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -153,6 +153,8 @@ class LightningLite(ABC): self._validate_setup(model, optimizers) original_model = model + model = self._precision_plugin.convert_module(model) + if move_to_device: model = self._move_model_to_device(model=model, optimizers=list(optimizers)) diff --git a/src/lightning_lite/plugins/precision/double.py b/src/lightning_lite/plugins/precision/double.py index 3de2b422f8..9bbb033b74 100644 --- a/src/lightning_lite/plugins/precision/double.py +++ b/src/lightning_lite/plugins/precision/double.py @@ -15,6 +15,7 @@ from contextlib import contextmanager from typing import Generator import torch +from torch.nn import Module from lightning_lite.plugins.precision.precision import Precision @@ -24,6 +25,9 @@ class DoublePrecision(Precision): precision: int = 64 + def convert_module(self, module: Module) -> Module: + return module.double() + @contextmanager def forward_context(self) -> Generator[None, None, None]: """A context manager to change the default tensor type. diff --git a/src/lightning_lite/plugins/precision/precision.py b/src/lightning_lite/plugins/precision/precision.py index 0fd1a4c4e1..07ce77b889 100644 --- a/src/lightning_lite/plugins/precision/precision.py +++ b/src/lightning_lite/plugins/precision/precision.py @@ -29,6 +29,13 @@ class Precision: precision: Union[str, int] = 32 + def convert_module(self, module: Module) -> Module: + """Convert the module parameters to the precision type this plugin handles. + + This is optional and depends on the precision limitations during optimization. + """ + return module + @contextlib.contextmanager def forward_context(self) -> Generator[None, None, None]: """A contextmanager for managing model forward/training_step/evaluation_step/predict_step.""" diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 6ed9b7f648..fca1366b1c 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -94,6 +94,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - The `MLFlowLogger.finalize()` now sets the status to `FAILED` when an exception occurred in `Trainer`, and sets the status to `FINISHED` on successful completion ([#12292](https://github.com/Lightning-AI/lightning/pull/12292)) +- It is no longer needed to call `model.double()` when using `precision=64` in Lightning Lite ([#14827](https://github.com/Lightning-AI/lightning/pull/14827)) + + ### Deprecated - Deprecated `LightningDeepSpeedModule` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000)) diff --git a/tests/tests_lite/plugins/precision/test_double_integration.py b/tests/tests_lite/plugins/precision/test_double_integration.py index 589cfc90b8..e6ab910347 100644 --- a/tests/tests_lite/plugins/precision/test_double_integration.py +++ b/tests/tests_lite/plugins/precision/test_double_integration.py @@ -36,7 +36,6 @@ class DoublePrecisionBoringLite(BoringLite): return BoringDoubleModule() def step(self, model, batch): - model.double() # TODO(lite): this needs to be done automatically in Lite.setup() assert model.layer.weight.dtype == model.layer.bias.dtype == torch.float64 assert model.complex_buffer.dtype == torch.complex64