[sharded plugin] Fix check for fp16 precision (#7825)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
parent f34584001c
commit ca89a7f344
@@ -188,6 +188,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed support for `torch.nn.Module` type hints in `LightningCLI` ([#7807](https://github.com/PyTorchLightning/pytorch-lightning/pull/7807))

+- Fixed a bug where checking `trainer.precision` changed to `'mixed'` when specifying 16 in trainer ([#7825](https://github.com/PyTorchLightning/pytorch-lightning/pull/7825))

 ## [1.3.2] - 2021-05-18

 ### Changed
@@ -54,7 +54,8 @@ class DDPShardedPlugin(DDPPlugin):
             optim_class = type(optimizer)
             zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
             if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
-                is_fp16 = self.lightning_module.trainer.precision == 16
+                precision = self.lightning_module.trainer.precision
+                is_fp16 = precision in ("mixed", 16)
                 # For multi-node training, compressing the model shards in fp16 before broadcasting
                 # improves performance. When using PyTorch AMP, it will not degrade
                 # the model performance.
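For context, here is a minimal sketch of why the check needs to accept both values: depending on how the AMP backend normalises the setting, `trainer.precision` may be reported as the string `'mixed'` rather than the integer `16`, so comparing against `16` alone silently skips the fp16 shard broadcast. The helper name `_wants_fp16_broadcast` below is made up for illustration and is not part of the plugin's actual code.

```python
# Hypothetical helper illustrating the corrected check; not pytorch-lightning API.
def _wants_fp16_broadcast(precision) -> bool:
    # Lightning may report 16-bit training either as the integer 16 or,
    # once an AMP backend normalises the setting, as the string "mixed".
    return precision in ("mixed", 16)


# The old `precision == 16` comparison missed the "mixed" case:
assert _wants_fp16_broadcast(16)
assert _wants_fp16_broadcast("mixed")
assert not _wants_fp16_broadcast(32)
```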