Fix BF16 teardown for TPU precision plugin (#10990)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
Adrian Wälchli 2021-12-22 04:47:14 +01:00 committed by GitHub
parent 235efb37d7
commit ba8e7cd787
10 changed files with 54 additions and 14 deletions


@@ -55,6 +55,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `console_kwargs` for `RichProgressBar` to initialize inner Console ([#10875](https://github.com/PyTorchLightning/pytorch-lightning/pull/10875))
+- Added a `PrecisionPlugin.teardown` method ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/issues/10990))
### Changed
- Raised exception in `init_dist_connection()` when torch distributed is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))
@@ -140,16 +144,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Renamed the `ParallelPlugin` to `ParallelStrategy` ([#11123](https://github.com/PyTorchLightning/pytorch-lightning/pull/11123))
* Renamed the `DataParallelPlugin` to `DataParallelStrategy` ([#11183](https://github.com/PyTorchLightning/pytorch-lightning/pull/11183))
* Renamed the `DDPPlugin` to `DDPStrategy` ([#11142](https://github.com/PyTorchLightning/pytorch-lightning/pull/11142))
-* Renamed the `DeepSpeedPlugin` to `DeepSpeedStrategy` ([#11194](https://github.com/PyTorchLightning/pytorch-lightning/pull/11194))
-* Renamed the `IPUPlugin` to `IPUStrategy` ([#11193](https://github.com/PyTorchLightning/pytorch-lightning/pull/11193))
-* Renamed the `TPUSpawnPlugin` to `TPUSpawnStrategy` ([#11190](https://github.com/PyTorchLightning/pytorch-lightning/pull/11190))
-* Renamed the `DDPShardedPlugin` to `DDPShardedStrategy` ([#11186](https://github.com/PyTorchLightning/pytorch-lightning/pull/11186))
* Renamed the `DDP2Plugin` to `DDP2Strategy` ([#11185](https://github.com/PyTorchLightning/pytorch-lightning/pull/11185))
-* Renamed the `SingleTPUPlugin` to `SingleTPUStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
-* Renamed the `DDPSpawnPlugin` to `DDPSpawnStrategy` ([#11145](https://github.com/PyTorchLightning/pytorch-lightning/pull/11145))
+* Renamed the `DDPShardedPlugin` to `DDPShardedStrategy` ([#11186](https://github.com/PyTorchLightning/pytorch-lightning/pull/11186))
* Renamed the `DDPFullyShardedPlugin` to `DDPFullyShardedStrategy` ([#11143](https://github.com/PyTorchLightning/pytorch-lightning/pull/11143))
-* Renamed the `SingleDevicePlugin` to `SingleDeviceStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
+* Renamed the `DDPSpawnPlugin` to `DDPSpawnStrategy` ([#11145](https://github.com/PyTorchLightning/pytorch-lightning/pull/11145))
* Renamed the `DDPSpawnShardedPlugin` to `DDPSpawnShardedStrategy` ([#11210](https://github.com/PyTorchLightning/pytorch-lightning/pull/11210))
+* Renamed the `DeepSpeedPlugin` to `DeepSpeedStrategy` ([#11194](https://github.com/PyTorchLightning/pytorch-lightning/pull/11194))
+* Renamed the `HorovodPlugin` to `HorovodStrategy` ([#11195](https://github.com/PyTorchLightning/pytorch-lightning/pull/11195))
+* Renamed the `TPUSpawnPlugin` to `TPUSpawnStrategy` ([#11190](https://github.com/PyTorchLightning/pytorch-lightning/pull/11190))
+* Renamed the `IPUPlugin` to `IPUStrategy` ([#11193](https://github.com/PyTorchLightning/pytorch-lightning/pull/11193))
+* Renamed the `SingleDevicePlugin` to `SingleDeviceStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
+* Renamed the `SingleTPUPlugin` to `SingleTPUStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
- Marked the `ResultCollection`, `ResultMetric`, and `ResultMetricCollection` classes as protected ([#11130](https://github.com/PyTorchLightning/pytorch-lightning/pull/11130))
@@ -337,6 +342,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed double evaluation bug with fault-tolerance enabled where the second call was completely skipped ([#11119](https://github.com/PyTorchLightning/pytorch-lightning/pull/11119))
+- Fixed an issue with the `TPUSpawnPlugin` handling the `XLA_USE_BF16` environment variable incorrectly ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990))
## [1.5.7] - 2021-12-21


@@ -236,3 +236,9 @@ class PrecisionPlugin(CheckpointHooks):
        """A contextmanager for the predict step."""
        with self.forward_context():
            yield
+
+    def teardown(self) -> None:
+        """This method is called to teardown the training process.
+
+        It is the right place to release memory and free other resources.
+        """


@@ -28,5 +28,8 @@ class TPUBf16PrecisionPlugin(TPUPrecisionPlugin):
    def connect(
        self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
    ) -> Tuple[nn.Module, List[Optimizer], List[Any]]:
-        os.environ["XLA_USE_BF16"] = str(1)
+        os.environ["XLA_USE_BF16"] = "1"
        return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
+
+    def teardown(self) -> None:
+        os.environ.pop("XLA_USE_BF16", None)


@@ -86,6 +86,7 @@ class SingleDeviceStrategy(Strategy):
        return obj

    def teardown(self) -> None:
+        super().teardown()
        if self.on_gpu:
            # GPU teardown
            self.lightning_module.cpu()


@@ -74,6 +74,7 @@ class SingleTPUStrategy(SingleDeviceStrategy):
        self.model.to(self.root_device)

    def teardown(self) -> None:
+        super().teardown()
        # TPU teardown
        os.environ.pop("PT_XLA_DEBUG", None)


@@ -244,9 +244,6 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
        }

    def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
-        # todo: precision pluging is call in accelerator setup and should be moved
-        if "XLA_USE_BF16" in os.environ:
-            del os.environ["XLA_USE_BF16"]
        context = mp.get_context(self.start_method or "fork")
        return_queue = context.SimpleQueue()
        xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs())
@@ -340,6 +337,7 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
        return xm.all_gather(tensor)

    def teardown(self) -> None:
+        super().teardown()
        os.environ.pop("PT_XLA_DEBUG", None)

    @classmethod


@@ -437,13 +437,13 @@ class Strategy(ABC):
        """
        yield

-    @abstractmethod
    def teardown(self) -> None:
        """This method is called to teardown the training process.

        It is the right place to release memory and free other resources.
        """
        self._move_optimizer_state(torch.device("cpu"))
+        self.precision_plugin.teardown()

    @classmethod
    def register_plugins(cls, plugin_registry) -> None:
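Taken together, `Strategy.teardown` is no longer abstract and now also tears down the precision plugin, so any strategy that calls `super().teardown()` gets precision cleanup for free. A framework-free sketch of that contract (the stub classes below are illustrative stand-ins, not Lightning classes):

```python
class _PrecisionStub:
    def teardown(self) -> None:
        print("precision cleanup, e.g. pop XLA_USE_BF16")


class _StrategyStub:
    def __init__(self) -> None:
        self.precision_plugin = _PrecisionStub()

    def teardown(self) -> None:
        # mirrors Strategy.teardown above: shared cleanup ends with the precision plugin
        self.precision_plugin.teardown()


class _TPUStrategyStub(_StrategyStub):
    def teardown(self) -> None:
        super().teardown()  # keep the chain intact, as the strategies in this commit now do
        print("strategy cleanup, e.g. pop PT_XLA_DEBUG")


_TPUStrategyStub().teardown()  # precision cleanup runs first, then strategy-specific cleanup
```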


@@ -122,7 +122,6 @@ def test_model_16bit_tpu_cores_1(tmpdir):
    model = BoringModel()

    tpipes.run_model_test(trainer_options, model, on_gpu=False)
-    assert os.environ.get("XLA_USE_BF16") == str(1), "XLA_USE_BF16 was not set in environment variables"


@pytest.mark.parametrize("tpu_core", [1, 5])
@@ -144,7 +143,6 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
    model = BoringModel()

    tpipes.run_model_test(trainer_options, model, on_gpu=False)
    assert torch_xla._XLAC._xla_get_default_device() == f"xla:{tpu_core}"
-    assert os.environ.get("XLA_USE_BF16") == str(1), "XLA_USE_BF16 was not set in environment variables"


@RunIf(tpu=True)


@@ -0,0 +1,25 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from unittest.mock import Mock
+
+from pytorch_lightning.plugins import TPUBf16PrecisionPlugin
+
+
+def test_teardown():
+    plugin = TPUBf16PrecisionPlugin()
+    plugin.connect(Mock(), Mock(), Mock())
+    assert os.environ.get("XLA_USE_BF16") == "1"
+    plugin.teardown()
+    assert "XLA_USE_BF16" not in os.environ