Fix double precision support in Lite (#14827)
This commit is contained in:
parent
b0bd949d58
commit
d572a7e2ec
|
@ -153,6 +153,8 @@ class LightningLite(ABC):
|
||||||
self._validate_setup(model, optimizers)
|
self._validate_setup(model, optimizers)
|
||||||
original_model = model
|
original_model = model
|
||||||
|
|
||||||
|
model = self._precision_plugin.convert_module(model)
|
||||||
|
|
||||||
if move_to_device:
|
if move_to_device:
|
||||||
model = self._move_model_to_device(model=model, optimizers=list(optimizers))
|
model = self._move_model_to_device(model=model, optimizers=list(optimizers))
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@ from contextlib import contextmanager
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch.nn import Module
|
||||||
|
|
||||||
from lightning_lite.plugins.precision.precision import Precision
|
from lightning_lite.plugins.precision.precision import Precision
|
||||||
|
|
||||||
|
@ -24,6 +25,9 @@ class DoublePrecision(Precision):
|
||||||
|
|
||||||
precision: int = 64
|
precision: int = 64
|
||||||
|
|
||||||
|
def convert_module(self, module: Module) -> Module:
|
||||||
|
return module.double()
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def forward_context(self) -> Generator[None, None, None]:
|
def forward_context(self) -> Generator[None, None, None]:
|
||||||
"""A context manager to change the default tensor type.
|
"""A context manager to change the default tensor type.
|
||||||
|
|
|
@ -29,6 +29,13 @@ class Precision:
|
||||||
|
|
||||||
precision: Union[str, int] = 32
|
precision: Union[str, int] = 32
|
||||||
|
|
||||||
|
def convert_module(self, module: Module) -> Module:
|
||||||
|
"""Convert the module parameters to the precision type this plugin handles.
|
||||||
|
|
||||||
|
This is optional and depends on the precision limitations during optimization.
|
||||||
|
"""
|
||||||
|
return module
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def forward_context(self) -> Generator[None, None, None]:
|
def forward_context(self) -> Generator[None, None, None]:
|
||||||
"""A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
|
"""A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
|
||||||
|
|
|
@ -94,6 +94,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
- The `MLFlowLogger.finalize()` now sets the status to `FAILED` when an exception occurred in `Trainer`, and sets the status to `FINISHED` on successful completion ([#12292](https://github.com/Lightning-AI/lightning/pull/12292))
|
- The `MLFlowLogger.finalize()` now sets the status to `FAILED` when an exception occurred in `Trainer`, and sets the status to `FINISHED` on successful completion ([#12292](https://github.com/Lightning-AI/lightning/pull/12292))
|
||||||
|
|
||||||
|
|
||||||
|
- It is no longer needed to call `model.double()` when using `precision=64` in Lightning Lite ([#14827](https://github.com/Lightning-AI/lightning/pull/14827))
|
||||||
|
|
||||||
|
|
||||||
### Deprecated
|
### Deprecated
|
||||||
|
|
||||||
- Deprecated `LightningDeepSpeedModule` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000))
|
- Deprecated `LightningDeepSpeedModule` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000))
|
||||||
|
|
|
@ -36,7 +36,6 @@ class DoublePrecisionBoringLite(BoringLite):
|
||||||
return BoringDoubleModule()
|
return BoringDoubleModule()
|
||||||
|
|
||||||
def step(self, model, batch):
|
def step(self, model, batch):
|
||||||
model.double() # TODO(lite): this needs to be done automatically in Lite.setup()
|
|
||||||
assert model.layer.weight.dtype == model.layer.bias.dtype == torch.float64
|
assert model.layer.weight.dtype == model.layer.bias.dtype == torch.float64
|
||||||
assert model.complex_buffer.dtype == torch.complex64
|
assert model.complex_buffer.dtype == torch.complex64
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue