Remove `start_{training,evaluating,predicting}` from `HorovodPlugin` (#10989)
parent 01f5f99919
commit 8b30981b10
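
Removes the `start_training`, `start_evaluating` and `start_predicting` hooks from `HorovodPlugin`. The `skip_synchronize()` contexts they managed now live on the plugin itself: `pre_dispatch` creates and enters an `ExitStack` and pushes each wrapped optimizer's `skip_synchronize()` context onto it, and `teardown` unwinds the stack and calls `join()` so that all workers have finished before control returns to the user.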
CHANGELOG.md
@@ -236,6 +236,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Removed method `training_step`, `test_step`, `validation_step` and `predict_step` from the `Accelerator` ([#10890](https://github.com/PyTorchLightning/pytorch-lightning/pull/10890))

+- Removed `HorovodPlugin.start_{training,evaluating,predicting}` hooks ([#10989](https://github.com/PyTorchLightning/pytorch-lightning/pull/10989))
+
 ### Fixed

 -
@@ -52,6 +52,7 @@ class HorovodPlugin(ParallelPlugin):
             precision_plugin=precision_plugin,
         )
         rank_zero_only.rank = self.global_rank
+        self._exit_stack: Optional[ExitStack] = None

     @property
     def global_rank(self) -> int:
@@ -80,6 +81,9 @@ class HorovodPlugin(ParallelPlugin):

     def pre_dispatch(self, trainer: "pl.Trainer") -> None:
         super().pre_dispatch(trainer)
+        self._exit_stack = ExitStack()
+        self._exit_stack.__enter__()
+
         if not self.lightning_module.trainer.training:
             # no need to setup optimizers
             return
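
For context, the new lifecycle relies on `contextlib.ExitStack` being entered in one method and exited in another, rather than being scoped to a single `with` block. A minimal sketch of that pattern (the `Plugin` class and `resource` context manager below are hypothetical stand-ins, not Lightning code):

from contextlib import ExitStack, contextmanager
from typing import Optional


@contextmanager
def resource(name: str):
    # Hypothetical stand-in for optimizer.skip_synchronize()
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")


class Plugin:
    # Hypothetical class mirroring the lifecycle used in this commit
    def __init__(self) -> None:
        self._exit_stack: Optional[ExitStack] = None

    def pre_dispatch(self) -> None:
        # Open the stack and keep it open after this method returns
        self._exit_stack = ExitStack()
        self._exit_stack.__enter__()
        self._exit_stack.enter_context(resource("skip_synchronize"))

    def teardown(self) -> None:
        # Unwind every context entered since pre_dispatch
        self._exit_stack.__exit__(None, None, None)
        self._exit_stack = None


plugin = Plugin()
plugin.pre_dispatch()  # prints: enter skip_synchronize
plugin.teardown()      # prints: exit skip_synchronize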
@@ -109,33 +113,9 @@ class HorovodPlugin(ParallelPlugin):
             hvd.broadcast_optimizer_state(optimizer, root_rank=0)

         self.optimizers = self._wrap_optimizers(optimizers)

-    def start_training(self, trainer):
-        with ExitStack() as stack:
-            for optimizer in trainer.optimizers:
-                # Synchronization will be performed explicitly following backward()
-                stack.enter_context(optimizer.skip_synchronize())
-
-            # set up training routine
-            self._results = trainer.run_stage()
-
-        # Make sure all workers have finished training before returning to the user
-        self.join()
-
-    def start_evaluating(self, trainer):
-        with ExitStack():
-            self._results = trainer.run_stage()
-
-        # Make sure all workers have finished training before returning to the user
-        self.join()
-
-    def start_predicting(self, trainer):
-        with ExitStack():
-            # set up training routine
-            self._results = trainer.run_stage()
-
-        # Make sure all workers have finished training before returning to the user
-        self.join()
+        for optimizer in self.optimizers:
+            # Synchronization will be performed explicitly following backward()
+            self._exit_stack.enter_context(optimizer.skip_synchronize())

     def barrier(self, *args, **kwargs):
         if distributed_available():
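
For reference, `skip_synchronize()` comes from Horovod's `DistributedOptimizer`: inside the context, `optimizer.step()` skips the implicit gradient allreduce, which is instead triggered explicitly after `backward()`. A minimal sketch of that documented pattern (the model, data, and hyperparameters are placeholders; run under `horovodrun`):

import torch
import torch.nn.functional as F
import horovod.torch as hvd

hvd.init()

model = torch.nn.Linear(10, 2)  # placeholder model
optimizer = hvd.DistributedOptimizer(
    torch.optim.SGD(model.parameters(), lr=0.01),
    named_parameters=model.named_parameters(),
)

data = torch.randn(4, 10)            # placeholder batch
target = torch.randint(0, 2, (4,))
loss = F.cross_entropy(model(data), target)
loss.backward()

# Allreduce the gradients across workers explicitly ...
optimizer.synchronize()
# ... then step without triggering the implicit allreduce a second time
with optimizer.skip_synchronize():
    optimizer.step()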
@@ -218,6 +198,10 @@ class HorovodPlugin(ParallelPlugin):

     def teardown(self) -> None:
         super().teardown()
+        self._exit_stack.__exit__(None, None, None)
+        self._exit_stack = None
+        # Make sure all workers have finished training before returning to the user
+        self.join()
         if self.on_gpu:
             # GPU teardown
             self.lightning_module.cpu()
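
The `join()` call at the end of `teardown` defers to Horovod's join primitive, which blocks each worker until every rank has reached it, so no worker returns to the user while others are still running. A minimal sketch of that behavior, assuming `horovod.torch` and a `horovodrun` launch:

import time
import horovod.torch as hvd

hvd.init()

# Ranks may finish their work at different times ...
time.sleep(hvd.rank())  # stand-in for an uneven workload
# ... so block here until every rank arrives before handing control back
hvd.join()
print(f"rank {hvd.rank()} finished")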