From 45dd8066e7257ffb24bebd1cbad356327fcecfae Mon Sep 17 00:00:00 2001
From: four4fish <88516121+four4fish@users.noreply.github.com>
Date: Wed, 1 Dec 2021 19:34:51 -0800
Subject: [PATCH] 3/n Move Accelerator into strategy - remove model_sharded_context() (#10886)

* 3/n Move Accelerator into strategy - remove model_sharded_context()

* update ttp function

* update changelog

* update changelog

Co-authored-by: ananthsub
---
 CHANGELOG.md                                  |  3 +++
 pytorch_lightning/accelerators/accelerator.py | 15 +--------------
 pytorch_lightning/trainer/trainer.py          |  2 +-
 3 files changed, 5 insertions(+), 15 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e4c03b70cc..7e29976cae 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -188,6 +188,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed argument `return_result` from the `DDPSpawnPlugin.spawn()` method ([#10867](https://github.com/PyTorchLightning/pytorch-lightning/pull/10867))
 
 
+- Removed `model_sharded_context` method from `Accelerator` ([#10886](https://github.com/PyTorchLightning/pytorch-lightning/pull/10886))
+
+
 - Removed method `pre_dispatch` from the `PrecisionPlugin` method ([#10887](https://github.com/PyTorchLightning/pytorch-lightning/pull/10887))
 
 
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 502bed9870..b4bc9e5130 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -11,9 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import contextlib
 from abc import abstractmethod
-from typing import Any, Dict, Generator, Optional, Union
+from typing import Any, Dict, Optional, Union
 
 import torch
 from torch.nn import Module
@@ -154,18 +153,6 @@ class Accelerator:
         with self.training_type_plugin.precision_plugin.predict_step_context():
             return self.training_type_plugin.predict_step(*step_kwargs.values())
 
-    @contextlib.contextmanager
-    def model_sharded_context(self) -> Generator[None, None, None]:
-        """Provide hook to create modules in a distributed aware context. This is useful for when we'd like to.
-
-        shard the model instantly - useful for extremely large models. Can save memory and
-        initialization time.
-        Returns:
-            Model parallel context.
-        """
-        with self.training_type_plugin.model_sharded_context():
-            yield
-
     def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         """Gets stats for a given device.
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index f1349cb86d..b88d8e4ff5 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1406,7 +1406,7 @@ class Trainer(
         self.training_type_plugin.barrier("post_setup")
 
     def _call_configure_sharded_model(self) -> None:
-        with self.accelerator.model_sharded_context():
+        with self.training_type_plugin.model_sharded_context():
            self._handle_meta_model()
            self.call_hook("configure_sharded_model")
            self.call_hook("on_configure_sharded_model")
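
Not part of the patch: a minimal sketch of the call-site migration for downstream code, assuming a constructed Trainer on this 1.5-era API (where the strategy is still exposed as `trainer.training_type_plugin`). The base plugin's `model_sharded_context()` is a no-op context manager; sharding strategies such as DeepSpeed override it.

    import torch
    from pytorch_lightning import Trainer

    trainer = Trainer()  # hypothetical setup; any strategy/accelerator combination works

    # Before this patch, the context manager was reachable through the Accelerator proxy:
    #     with trainer.accelerator.model_sharded_context():
    #         ...

    # After this patch, enter the strategy's context directly:
    with trainer.training_type_plugin.model_sharded_context():
        # Modules created here can be sharded at instantiation time by
        # strategies that override the hook (e.g. DeepSpeed), saving memory
        # and initialization time for very large models.
        layer = torch.nn.Linear(8, 8)

Routing callers straight to the training type plugin removes one layer of indirection and keeps the Accelerator free of strategy-specific hooks, consistent with the rest of this "Move Accelerator into strategy" series.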