Fix BF16 teardown for TPU precision plugin (#10990)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
This commit is contained in:
parent 235efb37d7
commit ba8e7cd787
@@ -55,6 +55,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `console_kwargs` for `RichProgressBar` to initialize inner Console ([#10875](https://github.com/PyTorchLightning/pytorch-lightning/pull/10875))


+- Added a `PrecisionPlugin.teardown` method ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/issues/10990))
+

 ### Changed

 - Raised exception in `init_dist_connection()` when torch distibuted is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))

@@ -140,16 +144,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Renamed the `ParallelPlugin` to `ParallelStrategy` ([#11123](https://github.com/PyTorchLightning/pytorch-lightning/pull/11123))
 * Renamed the `DataParallelPlugin` to `DataParallelStrategy` ([#11183](https://github.com/PyTorchLightning/pytorch-lightning/pull/11183))
 * Renamed the `DDPPlugin` to `DDPStrategy` ([#11142](https://github.com/PyTorchLightning/pytorch-lightning/pull/11142))
-* Renamed the `DeepSpeedPlugin` to `DeepSpeedStrategy` ([#11194](https://github.com/PyTorchLightning/pytorch-lightning/pull/11194))
-* Renamed the `IPUPlugin` to `IPUStrategy` ([#11193](https://github.com/PyTorchLightning/pytorch-lightning/pull/11193))
-* Renamed the `TPUSpawnPlugin` to `TPUSpawnStrategy` ([#11190](https://github.com/PyTorchLightning/pytorch-lightning/pull/11190))
-* Renamed the `DDPShardedPlugin` to `DDPShardedStrategy` ([#11186](https://github.com/PyTorchLightning/pytorch-lightning/pull/11186))
 * Renamed the `DDP2Plugin` to `DDP2Strategy` ([#11185](https://github.com/PyTorchLightning/pytorch-lightning/pull/11185))
-* Renamed the `SingleTPUPlugin` to `SingleTPUStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
+* Renamed the `DDPShardedPlugin` to `DDPShardedStrategy` ([#11186](https://github.com/PyTorchLightning/pytorch-lightning/pull/11186))
-* Renamed the `DDPSpawnPlugin` to `DDPSpawnStrategy` ([#11145](https://github.com/PyTorchLightning/pytorch-lightning/pull/11145))
 * Renamed the `DDPFullyShardedPlugin` to `DDPFullyShardedStrategy` ([#11143](https://github.com/PyTorchLightning/pytorch-lightning/pull/11143))
-* Renamed the `SingleDevicePlugin` to `SingleDeviceStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
+* Renamed the `DDPSpawnPlugin` to `DDPSpawnStrategy` ([#11145](https://github.com/PyTorchLightning/pytorch-lightning/pull/11145))
 * Renamed the `DDPSpawnShardedPlugin` to `DDPSpawnShardedStrategy` ([#11210](https://github.com/PyTorchLightning/pytorch-lightning/pull/11210))
+* Renamed the `DeepSpeedPlugin` to `DeepSpeedStrategy` ([#11194](https://github.com/PyTorchLightning/pytorch-lightning/pull/11194))
+* Renamed the `HorovodPlugin` to `HorovodStrategy` ([#11195](https://github.com/PyTorchLightning/pytorch-lightning/pull/11195))
+* Renamed the `TPUSpawnPlugin` to `TPUSpawnStrategy` ([#11190](https://github.com/PyTorchLightning/pytorch-lightning/pull/11190))
+* Renamed the `IPUPlugin` to `IPUStrategy` ([#11193](https://github.com/PyTorchLightning/pytorch-lightning/pull/11193))
+* Renamed the `SingleDevicePlugin` to `SingleDeviceStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
+* Renamed the `SingleTPUPlugin` to `SingleTPUStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))


 - Marked the `ResultCollection`, `ResultMetric`, and `ResultMetricCollection` classes as protected ([#11130](https://github.com/PyTorchLightning/pytorch-lightning/pull/11130))

@@ -337,6 +342,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed double evaluation bug with fault-tolerance enabled where the second call was completely skipped ([#11119](https://github.com/PyTorchLightning/pytorch-lightning/pull/11119))


+- Fixed an issue with the `TPUSpawnPlugin` handling the `XLA_USE_BF16` environment variable incorrectly ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990))
+

 ## [1.5.7] - 2021-12-21

@@ -236,3 +236,9 @@ class PrecisionPlugin(CheckpointHooks):
         """A contextmanager for the predict step."""
         with self.forward_context():
             yield
+
+    def teardown(self) -> None:
+        """This method is called to teardown the training process.
+
+        It is the right place to release memory and free other resources.
+        """
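
The new `PrecisionPlugin.teardown` hook gives precision plugins a counterpart to `connect`: whatever process-wide state `connect` sets up can now be released when the strategy shuts down. As a rough, hypothetical sketch (not part of this diff; the class name `MyEnvPrecisionPlugin` and the `MY_PRECISION_FLAG` variable are made up for illustration), a custom plugin could pair the two hooks like this:

    import os

    from pytorch_lightning.plugins import PrecisionPlugin


    class MyEnvPrecisionPlugin(PrecisionPlugin):
        def connect(self, model, optimizers, lr_schedulers):
            # Set process-wide state needed while this plugin is active.
            os.environ["MY_PRECISION_FLAG"] = "1"
            return super().connect(model, optimizers, lr_schedulers)

        def teardown(self) -> None:
            # Undo it again; pop() with a default avoids a KeyError when the
            # variable was never set.
            os.environ.pop("MY_PRECISION_FLAG", None)
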
@@ -28,5 +28,8 @@ class TPUBf16PrecisionPlugin(TPUPrecisionPlugin):
     def connect(
         self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
     ) -> Tuple[nn.Module, List[Optimizer], List[Any]]:
-        os.environ["XLA_USE_BF16"] = str(1)
+        os.environ["XLA_USE_BF16"] = "1"
         return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
+
+    def teardown(self) -> None:
+        os.environ.pop("XLA_USE_BF16", None)
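
For context, users normally do not instantiate `TPUBf16PrecisionPlugin` themselves; a hedged usage sketch (assuming a TPU machine and the `precision="bf16"` flag, which Lightning maps to this plugin on TPUs):

    import pytorch_lightning as pl

    # On TPUs, bf16 precision routes through TPUBf16PrecisionPlugin, which sets
    # XLA_USE_BF16=1 in connect() and, with this change, removes it again in
    # teardown() so the flag no longer leaks into later runs in the same process.
    trainer = pl.Trainer(tpu_cores=8, precision="bf16", max_epochs=1)
    # trainer.fit(model)  # `model` would be a LightningModule instance
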
@@ -86,6 +86,7 @@ class SingleDeviceStrategy(Strategy):
         return obj

     def teardown(self) -> None:
+        super().teardown()
         if self.on_gpu:
             # GPU teardown
             self.lightning_module.cpu()
@@ -74,6 +74,7 @@ class SingleTPUStrategy(SingleDeviceStrategy):
         self.model.to(self.root_device)

     def teardown(self) -> None:
+        super().teardown()
         # TPU teardown
         os.environ.pop("PT_XLA_DEBUG", None)

@@ -244,9 +244,6 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
         }

     def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
-        # todo: precision pluging is call in accelerator setup and should be moved
-        if "XLA_USE_BF16" in os.environ:
-            del os.environ["XLA_USE_BF16"]
         context = mp.get_context(self.start_method or "fork")
         return_queue = context.SimpleQueue()
         xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs())
@@ -340,6 +337,7 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
         return xm.all_gather(tensor)

     def teardown(self) -> None:
+        super().teardown()
         os.environ.pop("PT_XLA_DEBUG", None)

     @classmethod
@@ -437,13 +437,13 @@ class Strategy(ABC):
         """
         yield

-    @abstractmethod
     def teardown(self) -> None:
         """This method is called to teardown the training process.

         It is the right place to release memory and free other resources.
         """
         self._move_optimizer_state(torch.device("cpu"))
+        self.precision_plugin.teardown()

     @classmethod
     def register_plugins(cls, plugin_registry) -> None:
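
Since `Strategy.teardown` is no longer abstract and now calls `self.precision_plugin.teardown()`, subclasses that override it (as in the `SingleDeviceStrategy`, `SingleTPUStrategy`, and `TPUSpawnStrategy` hunks above) pick up the precision cleanup by calling `super().teardown()` first. A self-contained, hypothetical mirror of that call order (plain classes, no Lightning imports):

    class PrecisionPluginSketch:
        def teardown(self) -> None:
            print("precision plugin cleanup (e.g. pop XLA_USE_BF16)")


    class StrategySketch:
        precision_plugin = PrecisionPluginSketch()

        def teardown(self) -> None:
            # Mirrors Strategy.teardown: release base resources, then let the
            # precision plugin clean up after itself.
            print("move optimizer state to CPU")
            self.precision_plugin.teardown()


    class SingleTPUStrategySketch(StrategySketch):
        def teardown(self) -> None:
            super().teardown()  # base cleanup first, including the precision plugin
            print("pop PT_XLA_DEBUG")


    SingleTPUStrategySketch().teardown()
    # move optimizer state to CPU
    # precision plugin cleanup (e.g. pop XLA_USE_BF16)
    # pop PT_XLA_DEBUG
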
@@ -122,7 +122,6 @@ def test_model_16bit_tpu_cores_1(tmpdir):

     model = BoringModel()
     tpipes.run_model_test(trainer_options, model, on_gpu=False)
-    assert os.environ.get("XLA_USE_BF16") == str(1), "XLA_USE_BF16 was not set in environment variables"


 @pytest.mark.parametrize("tpu_core", [1, 5])
@@ -144,7 +143,6 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
     model = BoringModel()
     tpipes.run_model_test(trainer_options, model, on_gpu=False)
     assert torch_xla._XLAC._xla_get_default_device() == f"xla:{tpu_core}"
-    assert os.environ.get("XLA_USE_BF16") == str(1), "XLA_USE_BF16 was not set in environment variables"


 @RunIf(tpu=True)
@@ -0,0 +1,25 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from unittest.mock import Mock
+
+from pytorch_lightning.plugins import TPUBf16PrecisionPlugin
+
+
+def test_teardown():
+    plugin = TPUBf16PrecisionPlugin()
+    plugin.connect(Mock(), Mock(), Mock())
+    assert os.environ.get("XLA_USE_BF16") == "1"
+    plugin.teardown()
+    assert "XLA_USE_BF16" not in os.environ