Fix double precision support in Lite (#14827)

This commit is contained in:
Adrian Wälchli 2022-09-27 10:38:20 +02:00 committed by GitHub
parent b0bd949d58
commit d572a7e2ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 16 additions and 1 deletions

View File

@ -153,6 +153,8 @@ class LightningLite(ABC):
self._validate_setup(model, optimizers)
original_model = model
model = self._precision_plugin.convert_module(model)
if move_to_device:
model = self._move_model_to_device(model=model, optimizers=list(optimizers))

View File

@ -15,6 +15,7 @@ from contextlib import contextmanager
from typing import Generator
import torch
from torch.nn import Module
from lightning_lite.plugins.precision.precision import Precision
@ -24,6 +25,9 @@ class DoublePrecision(Precision):
precision: int = 64
def convert_module(self, module: Module) -> Module:
return module.double()
@contextmanager
def forward_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type.

View File

@ -29,6 +29,13 @@ class Precision:
precision: Union[str, int] = 32
def convert_module(self, module: Module) -> Module:
"""Convert the module parameters to the precision type this plugin handles.
This is optional and depends on the precision limitations during optimization.
"""
return module
@contextlib.contextmanager
def forward_context(self) -> Generator[None, None, None]:
"""A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""

View File

@ -94,6 +94,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The `MLFlowLogger.finalize()` now sets the status to `FAILED` when an exception occurred in `Trainer`, and sets the status to `FINISHED` on successful completion ([#12292](https://github.com/Lightning-AI/lightning/pull/12292))
- It is no longer needed to call `model.double()` when using `precision=64` in Lightning Lite ([#14827](https://github.com/Lightning-AI/lightning/pull/14827))
### Deprecated
- Deprecated `LightningDeepSpeedModule` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000))

View File

@ -36,7 +36,6 @@ class DoublePrecisionBoringLite(BoringLite):
return BoringDoubleModule()
def step(self, model, batch):
model.double() # TODO(lite): this needs to be done automatically in Lite.setup()
assert model.layer.weight.dtype == model.layer.bias.dtype == torch.float64
assert model.complex_buffer.dtype == torch.complex64