Fix double precision support in Lite (#14827)
This commit is contained in:
parent
b0bd949d58
commit
d572a7e2ec
|
@ -153,6 +153,8 @@ class LightningLite(ABC):
|
|||
self._validate_setup(model, optimizers)
|
||||
original_model = model
|
||||
|
||||
model = self._precision_plugin.convert_module(model)
|
||||
|
||||
if move_to_device:
|
||||
model = self._move_model_to_device(model=model, optimizers=list(optimizers))
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ from contextlib import contextmanager
|
|||
from typing import Generator
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
|
||||
from lightning_lite.plugins.precision.precision import Precision
|
||||
|
||||
|
@ -24,6 +25,9 @@ class DoublePrecision(Precision):
|
|||
|
||||
precision: int = 64
|
||||
|
||||
def convert_module(self, module: Module) -> Module:
|
||||
return module.double()
|
||||
|
||||
@contextmanager
|
||||
def forward_context(self) -> Generator[None, None, None]:
|
||||
"""A context manager to change the default tensor type.
|
||||
|
|
|
@ -29,6 +29,13 @@ class Precision:
|
|||
|
||||
precision: Union[str, int] = 32
|
||||
|
||||
def convert_module(self, module: Module) -> Module:
|
||||
"""Convert the module parameters to the precision type this plugin handles.
|
||||
|
||||
This is optional and depends on the precision limitations during optimization.
|
||||
"""
|
||||
return module
|
||||
|
||||
@contextlib.contextmanager
|
||||
def forward_context(self) -> Generator[None, None, None]:
|
||||
"""A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
|
||||
|
|
|
@ -94,6 +94,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- The `MLFlowLogger.finalize()` now sets the status to `FAILED` when an exception occurred in `Trainer`, and sets the status to `FINISHED` on successful completion ([#12292](https://github.com/Lightning-AI/lightning/pull/12292))
|
||||
|
||||
|
||||
- It is no longer needed to call `model.double()` when using `precision=64` in Lightning Lite ([#14827](https://github.com/Lightning-AI/lightning/pull/14827))
|
||||
|
||||
|
||||
### Deprecated
|
||||
|
||||
- Deprecated `LightningDeepSpeedModule` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000))
|
||||
|
|
|
@ -36,7 +36,6 @@ class DoublePrecisionBoringLite(BoringLite):
|
|||
return BoringDoubleModule()
|
||||
|
||||
def step(self, model, batch):
|
||||
model.double() # TODO(lite): this needs to be done automatically in Lite.setup()
|
||||
assert model.layer.weight.dtype == model.layer.bias.dtype == torch.float64
|
||||
assert model.complex_buffer.dtype == torch.complex64
|
||||
|
||||
|
|
Loading…
Reference in New Issue