[sharded plugin] Fix check for fp16 precision (#7825)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Author: shuyingsunshine21, 2021-06-03 23:34:39 -07:00 (committed by GitHub)
parent f34584001c
commit ca89a7f344
2 changed files with 5 additions and 1 deletion


@@ -188,6 +188,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed support for `torch.nn.Module` type hints in `LightningCLI` ([#7807](https://github.com/PyTorchLightning/pytorch-lightning/pull/7807))
+- Fixed a bug where checking `trainer.precision` changed to `'mixed'` when specifying 16 in trainer ([#7825](https://github.com/PyTorchLightning/pytorch-lightning/pull/7825))

 ## [1.3.2] - 2021-05-18

 ### Changed
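
The entry added above describes the root cause: when 16-bit precision is requested with native AMP, `trainer.precision` reports the string `'mixed'` rather than the integer `16`, so a strict `== 16` comparison never matches and the fp16 code path is silently skipped. A minimal, self-contained illustration of the two checks (plain Python; the `precision` value is hard-coded here to stand in for what the trainer reports):

```python
# Stand-in for the value reported by `trainer.precision` when 16-bit
# native AMP is enabled: the string "mixed", not the integer 16.
precision = "mixed"

print(precision == 16)               # False -> the old check misses fp16 runs
print(precision in ("mixed", 16))    # True  -> the corrected check matches both
```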


@@ -54,7 +54,8 @@ class DDPShardedPlugin(DDPPlugin):
         optim_class = type(optimizer)
         zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
         if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
-            is_fp16 = self.lightning_module.trainer.precision == 16
+            precision = self.lightning_module.trainer.precision
+            is_fp16 = precision in ("mixed", 16)
             # For multi-node training, compressing the model shards in fp16 before broadcasting
             # improves performance. When using PyTorch AMP, it will not degrade
             # the model performance.
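
For context, here is a hedged sketch of how the corrected check can feed into wrapping an optimizer with fairscale's `OSS`, in the spirit of the diff above. The standalone function, its signature, and the `broadcast_fp16` assignment are illustrative assumptions (the actual plugin reads `self.lightning_module.trainer` and is gated by `_FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE`), not the plugin's exact code:

```python
import torch
from fairscale.optim import OSS  # assumes fairscale is installed


def wrap_optimizer_with_oss(optimizer: torch.optim.Optimizer, precision, num_nodes: int) -> OSS:
    """Illustrative sketch: re-wrap a torch optimizer in fairscale's OSS and
    decide whether shards should be broadcast in fp16.

    Assumes torch.distributed has already been initialized, since OSS shards
    optimizer state across the current process group.
    """
    zero_optimizer = OSS(
        params=optimizer.param_groups,
        optim=type(optimizer),
        **optimizer.defaults,
    )
    # Corrected check: `precision` may be the integer 16 or the string
    # "mixed" (native AMP); both should count as fp16 training.
    is_fp16 = precision in ("mixed", 16)
    # Compressing shards to fp16 before broadcasting pays off for multi-node
    # training. `broadcast_fp16` is assumed to exist on OSS in the fairscale
    # versions where _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE is True.
    zero_optimizer.broadcast_fp16 = is_fp16 and num_nodes > 1
    return zero_optimizer
```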