Fix BF16 teardown for TPU precision plugin (#10990)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
Adrian Wälchli, 2021-12-22 04:47:14 +01:00 (committed by GitHub)
parent 235efb37d7
commit ba8e7cd787
10 changed files with 54 additions and 14 deletions

@@ -55,6 +55,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `console_kwargs` for `RichProgressBar` to initialize inner Console ([#10875](https://github.com/PyTorchLightning/pytorch-lightning/pull/10875))
- Added a `PrecisionPlugin.teardown` method ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/issues/10990))
### Changed
- Raised exception in `init_dist_connection()` when torch distributed is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))
@@ -140,16 +144,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Renamed the `ParallelPlugin` to `ParallelStrategy` ([#11123](https://github.com/PyTorchLightning/pytorch-lightning/pull/11123))
* Renamed the `DataParallelPlugin` to `DataParallelStrategy` ([#11183](https://github.com/PyTorchLightning/pytorch-lightning/pull/11183))
* Renamed the `DDPPlugin` to `DDPStrategy` ([#11142](https://github.com/PyTorchLightning/pytorch-lightning/pull/11142))
* Renamed the `DeepSpeedPlugin` to `DeepSpeedStrategy` ([#11194](https://github.com/PyTorchLightning/pytorch-lightning/pull/11194))
* Renamed the `IPUPlugin` to `IPUStrategy` ([#11193](https://github.com/PyTorchLightning/pytorch-lightning/pull/11193))
* Renamed the `TPUSpawnPlugin` to `TPUSpawnStrategy` ([#11190](https://github.com/PyTorchLightning/pytorch-lightning/pull/11190))
* Renamed the `DDPShardedPlugin` to `DDPShardedStrategy` ([#11186](https://github.com/PyTorchLightning/pytorch-lightning/pull/11186))
* Renamed the `DDP2Plugin` to `DDP2Strategy` ([#11185](https://github.com/PyTorchLightning/pytorch-lightning/pull/11185))
* Renamed the `SingleTPUPlugin` to `SingleTPUStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
* Renamed the `DDPSpawnPlugin` to `DDPSpawnStrategy` ([#11145](https://github.com/PyTorchLightning/pytorch-lightning/pull/11145))
* Renamed the `DDPFullyShardedPlugin` to `DDPFullyShardedStrategy` ([#11143](https://github.com/PyTorchLightning/pytorch-lightning/pull/11143))
* Renamed the `SingleDevicePlugin` to `SingleDeviceStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
* Renamed the `DDPSpawnShardedPlugin` to `DDPSpawnShardedStrategy` ([#11210](https://github.com/PyTorchLightning/pytorch-lightning/pull/11210))
* Renamed the `HorovodPlugin` to `HorovodStrategy` ([#11195](https://github.com/PyTorchLightning/pytorch-lightning/pull/11195))
- Marked the `ResultCollection`, `ResultMetric`, and `ResultMetricCollection` classes as protected ([#11130](https://github.com/PyTorchLightning/pytorch-lightning/pull/11130))
@@ -337,6 +342,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed double evaluation bug with fault-tolerance enabled where the second call was completely skipped ([#11119](https://github.com/PyTorchLightning/pytorch-lightning/pull/11119))
- Fixed an issue with the `TPUSpawnPlugin` handling the `XLA_USE_BF16` environment variable incorrectly ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990))
## [1.5.7] - 2021-12-21

@@ -236,3 +236,9 @@ class PrecisionPlugin(CheckpointHooks):
"""A contextmanager for the predict step."""
with self.forward_context():
yield
def teardown(self) -> None:
"""This method is called to teardown the training process.
It is the right place to release memory and free other resources.
"""

@@ -28,5 +28,8 @@ class TPUBf16PrecisionPlugin(TPUPrecisionPlugin):
def connect(
self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
) -> Tuple[nn.Module, List[Optimizer], List[Any]]:
os.environ["XLA_USE_BF16"] = str(1)
os.environ["XLA_USE_BF16"] = "1"
return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
def teardown(self) -> None:
os.environ.pop("XLA_USE_BF16", None)
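
For context, `TPUBf16PrecisionPlugin` is the plugin Lightning is expected to select when bf16 precision is requested on TPUs, so the `connect`/`teardown` pair above now bounds the lifetime of `XLA_USE_BF16`. A minimal, hypothetical trainer configuration that would exercise this path (assumes a TPU-enabled environment):

```python
import pytorch_lightning as pl

# Assumption: requesting bf16 precision together with TPU cores resolves to
# TPUBf16PrecisionPlugin, which exports XLA_USE_BF16=1 in connect() and now
# removes it again in teardown().
trainer = pl.Trainer(tpu_cores=8, precision="bf16", max_epochs=1)
# trainer.fit(model)
```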

@@ -86,6 +86,7 @@ class SingleDeviceStrategy(Strategy):
return obj
def teardown(self) -> None:
super().teardown()
if self.on_gpu:
# GPU teardown
self.lightning_module.cpu()

@@ -74,6 +74,7 @@ class SingleTPUStrategy(SingleDeviceStrategy):
self.model.to(self.root_device)
def teardown(self) -> None:
super().teardown()
# TPU teardown
os.environ.pop("PT_XLA_DEBUG", None)

@@ -244,9 +244,6 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
}
def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
# todo: the precision plugin is called in accelerator setup and should be moved
if "XLA_USE_BF16" in os.environ:
del os.environ["XLA_USE_BF16"]
context = mp.get_context(self.start_method or "fork")
return_queue = context.SimpleQueue()
xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs())
@@ -340,6 +337,7 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
return xm.all_gather(tensor)
def teardown(self) -> None:
super().teardown()
os.environ.pop("PT_XLA_DEBUG", None)
@classmethod

@@ -437,13 +437,13 @@ class Strategy(ABC):
"""
yield
@abstractmethod
def teardown(self) -> None:
"""This method is called to teardown the training process.
It is the right place to release memory and free other resources.
"""
self._move_optimizer_state(torch.device("cpu"))
self.precision_plugin.teardown()
@classmethod
def register_plugins(cls, plugin_registry) -> None:
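
With the base `Strategy.teardown` no longer abstract, it moves the optimizer state back to the CPU and then delegates to `precision_plugin.teardown()`; the strategy subclasses above only need to call `super().teardown()` to inherit both steps. A stripped-down illustration of the resulting call chain (plain stand-in classes, not Lightning's actual implementations):

```python
class PrecisionPlugin:
    def teardown(self) -> None:
        # e.g. TPUBf16PrecisionPlugin pops XLA_USE_BF16 here
        pass


class Strategy:
    def __init__(self, precision_plugin: PrecisionPlugin) -> None:
        self.precision_plugin = precision_plugin

    def teardown(self) -> None:
        # Base behaviour: move optimizer state to CPU (omitted here),
        # then hand off to the precision plugin.
        self.precision_plugin.teardown()


class SingleTPUStrategy(Strategy):
    def teardown(self) -> None:
        super().teardown()  # runs the base and precision-plugin teardown first
        # ...then device-specific cleanup, e.g. popping PT_XLA_DEBUG


strategy = SingleTPUStrategy(PrecisionPlugin())
strategy.teardown()  # PrecisionPlugin.teardown() is reached via the super() chain
```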

@@ -122,7 +122,6 @@ def test_model_16bit_tpu_cores_1(tmpdir):
model = BoringModel()
tpipes.run_model_test(trainer_options, model, on_gpu=False)
assert os.environ.get("XLA_USE_BF16") == str(1), "XLA_USE_BF16 was not set in environment variables"
@pytest.mark.parametrize("tpu_core", [1, 5])
@@ -144,7 +143,6 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
model = BoringModel()
tpipes.run_model_test(trainer_options, model, on_gpu=False)
assert torch_xla._XLAC._xla_get_default_device() == f"xla:{tpu_core}"
assert os.environ.get("XLA_USE_BF16") == str(1), "XLA_USE_BF16 was not set in environment variables"
@RunIf(tpu=True)

@@ -0,0 +1,25 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest.mock import Mock
from pytorch_lightning.plugins import TPUBf16PrecisionPlugin
def test_teardown():
plugin = TPUBf16PrecisionPlugin()
plugin.connect(Mock(), Mock(), Mock())
assert os.environ.get("XLA_USE_BF16") == "1"
plugin.teardown()
assert "XLA_USE_BF16" not in os.environ