3/n Move Accelerator into strategy - remove model_sharded_context() (#10886)

* 3/n Move Accelerator into strategy - remove model_sharded_context()

* update ttp function

* update changelog

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
four4fish 2021-12-01 19:34:51 -08:00 committed by GitHub
parent 44cd412e91
commit 45dd8066e7
3 changed files with 5 additions and 15 deletions

CHANGELOG.md

@@ -188,6 +188,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed argument `return_result` from the `DDPSpawnPlugin.spawn()` method ([#10867](https://github.com/PyTorchLightning/pytorch-lightning/pull/10867))
+- Removed `model_sharded_context` method from `Accelerator` ([#10886](https://github.com/PyTorchLightning/pytorch-lightning/pull/10886))
 - Removed method `pre_dispatch` from the `PrecisionPlugin` ([#10887](https://github.com/PyTorchLightning/pytorch-lightning/pull/10887))

pytorch_lightning/accelerators/accelerator.py

@@ -11,9 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import contextlib
 from abc import abstractmethod
-from typing import Any, Dict, Generator, Optional, Union
+from typing import Any, Dict, Optional, Union
 
 import torch
 from torch.nn import Module
@@ -154,18 +153,6 @@ class Accelerator:
         with self.training_type_plugin.precision_plugin.predict_step_context():
             return self.training_type_plugin.predict_step(*step_kwargs.values())
 
-    @contextlib.contextmanager
-    def model_sharded_context(self) -> Generator[None, None, None]:
-        """Provide hook to create modules in a distributed aware context. This is useful for when we'd like to
-        shard the model instantly - useful for extremely large models. Can save memory and
-        initialization time.
-
-        Returns:
-            Model parallel context.
-        """
-        with self.training_type_plugin.model_sharded_context():
-            yield
-
     def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         """Gets stats for a given device.

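The removed `Accelerator` method was a pure pass-through: it only entered the training type plugin's own `model_sharded_context()`, which is where the hook actually lives and where sharding-aware strategies override it. Below is a minimal sketch of that override pattern; `MyShardedPlugin` and `_backend_init_scope` are hypothetical names for illustration, and the class is abridged (a real subclass must also implement the plugin's abstract interface):

import contextlib
from typing import Generator

from pytorch_lightning.plugins import TrainingTypePlugin


class MyShardedPlugin(TrainingTypePlugin):  # hypothetical, abridged
    """Creates modules inside a distributed-aware scope so very large
    models can be sharded at instantiation, saving memory and init time."""

    @contextlib.contextmanager
    def model_sharded_context(self) -> Generator[None, None, None]:
        # `_backend_init_scope` is a stand-in for whatever scope the
        # backend provides (e.g. deepspeed.zero.Init or fairscale's
        # enable_wrap); parameters created inside it are sharded eagerly.
        with self._backend_init_scope():
            yield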
pytorch_lightning/trainer/trainer.py

@@ -1406,7 +1406,7 @@ class Trainer(
         self.training_type_plugin.barrier("post_setup")
 
     def _call_configure_sharded_model(self) -> None:
-        with self.accelerator.model_sharded_context():
+        with self.training_type_plugin.model_sharded_context():
            self._handle_meta_model()
            self.call_hook("configure_sharded_model")
            self.call_hook("on_configure_sharded_model")
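Downstream code that previously entered this context through the accelerator moves the same way the trainer's call site did. A minimal usage sketch, assuming an already-constructed `Trainer` (the default configuration here is only illustrative):

import torch.nn as nn

from pytorch_lightning import Trainer

trainer = Trainer()  # any configuration works; defaults used for brevity

# Before #10886: with trainer.accelerator.model_sharded_context(): ...
# After #10886, enter the context on the training type plugin directly:
with trainer.training_type_plugin.model_sharded_context():
    # Instantiate large modules inside the context so a sharding-aware
    # plugin can shard their parameters as they are created.
    layer = nn.Linear(1024, 1024)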