From 8b30981b1068eda731f8789e8e39c26893d02348 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 8 Dec 2021 15:02:26 +0100
Subject: [PATCH] Remove `start_{training,evaluating,predicting}` from
 `HorovodPlugin` (#10989)

---
 CHANGELOG.md                                  |  4 ++
 .../plugins/training_type/horovod.py          | 38 ++++++-------------
 2 files changed, 15 insertions(+), 27 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8be044ba41..3da34016f9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -236,6 +236,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed method `training_step`, `test_step`, `validation_step` and `predict_step` from the `Accelerator` ([#10890](https://github.com/PyTorchLightning/pytorch-lightning/pull/10890))
 
+- Removed `HorovodPlugin.start_{training,evaluating,predicting}` hooks ([#10989](https://github.com/PyTorchLightning/pytorch-lightning/pull/10989))
+
+
+
 ### Fixed
 
 -
diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py
index 84a41d5a5f..184183f577 100644
--- a/pytorch_lightning/plugins/training_type/horovod.py
+++ b/pytorch_lightning/plugins/training_type/horovod.py
@@ -52,6 +52,7 @@ class HorovodPlugin(ParallelPlugin):
             precision_plugin=precision_plugin,
         )
         rank_zero_only.rank = self.global_rank
+        self._exit_stack: Optional[ExitStack] = None
 
     @property
     def global_rank(self) -> int:
@@ -80,6 +81,9 @@ class HorovodPlugin(ParallelPlugin):
     def pre_dispatch(self, trainer: "pl.Trainer") -> None:
         super().pre_dispatch(trainer)
 
+        self._exit_stack = ExitStack()
+        self._exit_stack.__enter__()
+
         if not self.lightning_module.trainer.training:
             # no need to setup optimizers
             return
@@ -109,33 +113,9 @@ class HorovodPlugin(ParallelPlugin):
             hvd.broadcast_optimizer_state(optimizer, root_rank=0)
 
         self.optimizers = self._wrap_optimizers(optimizers)
-
-    def start_training(self, trainer):
-        with ExitStack() as stack:
-            for optimizer in trainer.optimizers:
-                # Synchronization will be performed explicitly following backward()
-                stack.enter_context(optimizer.skip_synchronize())
-
-            # set up training routine
-            self._results = trainer.run_stage()
-
-        # Make sure all workers have finished training before returning to the user
-        self.join()
-
-    def start_evaluating(self, trainer):
-        with ExitStack():
-            self._results = trainer.run_stage()
-
-        # Make sure all workers have finished training before returning to the user
-        self.join()
-
-    def start_predicting(self, trainer):
-        with ExitStack():
-            # set up training routine
-            self._results = trainer.run_stage()
-
-        # Make sure all workers have finished training before returning to the user
-        self.join()
+        for optimizer in self.optimizers:
+            # Synchronization will be performed explicitly following backward()
+            self._exit_stack.enter_context(optimizer.skip_synchronize())
 
     def barrier(self, *args, **kwargs):
         if distributed_available():
@@ -218,6 +198,10 @@ class HorovodPlugin(ParallelPlugin):
 
     def teardown(self) -> None:
         super().teardown()
+        self._exit_stack.__exit__(None, None, None)
+        self._exit_stack = None
+        # Make sure all workers have finished training before returning to the user
+        self.join()
        if self.on_gpu:
             # GPU teardown
             self.lightning_module.cpu()
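
The diff above replaces the per-stage `with ExitStack()` blocks with a single long-lived stack that is entered in `pre_dispatch` and closed in `teardown`. What follows is a minimal standalone sketch of that `contextlib.ExitStack` pattern, assuming nothing beyond the Python standard library; `PluginSketch` and the `skip_synchronize` stand-in are hypothetical illustrations, not the Lightning or Horovod API.

    from contextlib import ExitStack, contextmanager
    from typing import Optional


    @contextmanager
    def skip_synchronize():
        # Hypothetical stand-in for Horovod's optimizer.skip_synchronize() context manager.
        print("gradient synchronization disabled")
        yield
        print("gradient synchronization re-enabled")


    class PluginSketch:
        """Holds one ExitStack open across dispatch and closes it during teardown."""

        def __init__(self) -> None:
            self._exit_stack: Optional[ExitStack] = None

        def pre_dispatch(self) -> None:
            # Enter the stack once; it stays active for the whole run.
            self._exit_stack = ExitStack()
            self._exit_stack.__enter__()
            self._exit_stack.enter_context(skip_synchronize())

        def teardown(self) -> None:
            # Unwind every context registered on the stack, in reverse order.
            assert self._exit_stack is not None
            self._exit_stack.__exit__(None, None, None)
            self._exit_stack = None


    plugin = PluginSketch()
    plugin.pre_dispatch()   # contexts are now active
    # ... the training/evaluation/prediction loop would run here ...
    plugin.teardown()       # contexts are released

Keeping a single stack on the plugin instance means the skip-synchronize contexts no longer have to be re-registered in each stage entry point, which is what allowed the three `start_*` hooks to be removed.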