Remove `start_{training,evaluating,predicting}` from `HorovodPlugin` (#10989)

Adrian Wälchli 2021-12-08 15:02:26 +01:00 committed by GitHub
parent 01f5f99919
commit 8b30981b10
2 changed files with 15 additions and 27 deletions


@@ -236,6 +236,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed method `training_step`, `test_step`, `validation_step` and `predict_step` from the `Accelerator` ([#10890](https://github.com/PyTorchLightning/pytorch-lightning/pull/10890))
+- Removed `HorovodPlugin.start_{training,evaluating,predicting}` hooks ([#10989](https://github.com/PyTorchLightning/pytorch-lightning/pull/10989))

 ### Fixed

 -


@@ -52,6 +52,7 @@ class HorovodPlugin(ParallelPlugin):
             precision_plugin=precision_plugin,
         )
         rank_zero_only.rank = self.global_rank
+        self._exit_stack: Optional[ExitStack] = None

     @property
     def global_rank(self) -> int:
@@ -80,6 +81,9 @@ class HorovodPlugin(ParallelPlugin):
     def pre_dispatch(self, trainer: "pl.Trainer") -> None:
         super().pre_dispatch(trainer)

+        self._exit_stack = ExitStack()
+        self._exit_stack.__enter__()
+
         if not self.lightning_module.trainer.training:
             # no need to setup optimizers
             return
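The hunk above replaces the per-stage `with ExitStack()` blocks with a stack that is entered by hand in `pre_dispatch` and closed later in `teardown`. A minimal standalone sketch of that idiom (the class and context manager below are illustrative stand-ins, not part of the commit):

```python
# Sketch only: shows how ExitStack.__enter__/__exit__ keep contexts open across
# two separate method calls, mirroring the pre_dispatch/teardown split above.
from contextlib import ExitStack, contextmanager
from typing import Optional


@contextmanager
def skip_synchronize():
    # hypothetical stand-in for a Horovod optimizer's skip_synchronize()
    print("synchronization disabled")
    yield
    print("synchronization re-enabled")


class PluginSketch:
    def __init__(self) -> None:
        self._exit_stack: Optional[ExitStack] = None

    def pre_dispatch(self) -> None:
        self._exit_stack = ExitStack()
        self._exit_stack.__enter__()  # like entering `with ExitStack() as stack:`
        self._exit_stack.enter_context(skip_synchronize())

    def teardown(self) -> None:
        self._exit_stack.__exit__(None, None, None)  # closes every context entered above
        self._exit_stack = None


plugin = PluginSketch()
plugin.pre_dispatch()  # prints "synchronization disabled"
plugin.teardown()      # prints "synchronization re-enabled"
```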
@@ -109,33 +113,9 @@ class HorovodPlugin(ParallelPlugin):
             hvd.broadcast_optimizer_state(optimizer, root_rank=0)

         self.optimizers = self._wrap_optimizers(optimizers)

-    def start_training(self, trainer):
-        with ExitStack() as stack:
-            for optimizer in trainer.optimizers:
-                # Synchronization will be performed explicitly following backward()
-                stack.enter_context(optimizer.skip_synchronize())
-
-            # set up training routine
-            self._results = trainer.run_stage()
-
-        # Make sure all workers have finished training before returning to the user
-        self.join()
-
-    def start_evaluating(self, trainer):
-        with ExitStack():
-            self._results = trainer.run_stage()
-
-        # Make sure all workers have finished training before returning to the user
-        self.join()
-
-    def start_predicting(self, trainer):
-        with ExitStack():
-            # set up training routine
-            self._results = trainer.run_stage()
-
-        # Make sure all workers have finished training before returning to the user
-        self.join()
+        for optimizer in self.optimizers:
+            # Synchronization will be performed explicitly following backward()
+            self._exit_stack.enter_context(optimizer.skip_synchronize())

     def barrier(self, *args, **kwargs):
         if distributed_available():
@@ -218,6 +198,10 @@ class HorovodPlugin(ParallelPlugin):
     def teardown(self) -> None:
         super().teardown()
+        self._exit_stack.__exit__(None, None, None)
+        self._exit_stack = None
+        # Make sure all workers have finished training before returning to the user
+        self.join()
         if self.on_gpu:
             # GPU teardown
             self.lightning_module.cpu()
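With the `start_*` hooks gone, the `skip_synchronize()` contexts stay open for the whole run: `optimizer.step()` no longer triggers Horovod's implicit gradient allreduce, so the reduction is issued explicitly after `backward()`, as the retained comment says. A rough sketch of that pattern in plain Horovod (assumes Horovod with PyTorch support is installed and initialized; illustrative, not the Lightning code itself):

```python
# Sketch of the skip_synchronize()/synchronize() pattern the long-lived contexts enable.
import horovod.torch as hvd
import torch
import torch.nn.functional as F

hvd.init()
model = torch.nn.Linear(4, 2)
optimizer = hvd.DistributedOptimizer(
    torch.optim.SGD(model.parameters(), lr=0.1),
    named_parameters=model.named_parameters(),
)

data, target = torch.randn(8, 4), torch.randn(8, 2)

with optimizer.skip_synchronize():       # held open for the whole loop, as in pre_dispatch/teardown
    for _ in range(2):
        optimizer.zero_grad()
        loss = F.mse_loss(model(data), target)
        loss.backward()
        optimizer.synchronize()          # explicit gradient allreduce after backward()
        optimizer.step()                 # no implicit synchronization inside step()
```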