3/n Move Accelerator into strategy - remove model_sharded_context() (#10886)
* 3/n Move Accelerator into strategy - remove model_sharded_context() * update ttp function * update changelog * update changelog Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
This commit is contained in:
parent
44cd412e91
commit
45dd8066e7
|
@ -188,6 +188,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Removed argument `return_result` from the `DDPSpawnPlugin.spawn()` method ([#10867](https://github.com/PyTorchLightning/pytorch-lightning/pull/10867))
|
||||
|
||||
|
||||
- Removed `model_sharded_context` method from `Accelerator` ([#10886](https://github.com/PyTorchLightning/pytorch-lightning/pull/10886))
|
||||
|
||||
|
||||
- Removed method `pre_dispatch` from the `PrecisionPlugin` method ([#10887](https://github.com/PyTorchLightning/pytorch-lightning/pull/10887))
|
||||
|
||||
|
||||
|
|
|
@ -11,9 +11,8 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, Generator, Optional, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
|
@ -154,18 +153,6 @@ class Accelerator:
|
|||
with self.training_type_plugin.precision_plugin.predict_step_context():
|
||||
return self.training_type_plugin.predict_step(*step_kwargs.values())
|
||||
|
||||
@contextlib.contextmanager
|
||||
def model_sharded_context(self) -> Generator[None, None, None]:
|
||||
"""Provide hook to create modules in a distributed aware context. This is useful for when we'd like to.
|
||||
|
||||
shard the model instantly - useful for extremely large models. Can save memory and
|
||||
initialization time.
|
||||
Returns:
|
||||
Model parallel context.
|
||||
"""
|
||||
with self.training_type_plugin.model_sharded_context():
|
||||
yield
|
||||
|
||||
def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
|
||||
"""Gets stats for a given device.
|
||||
|
||||
|
|
|
@ -1406,7 +1406,7 @@ class Trainer(
|
|||
self.training_type_plugin.barrier("post_setup")
|
||||
|
||||
def _call_configure_sharded_model(self) -> None:
|
||||
with self.accelerator.model_sharded_context():
|
||||
with self.training_type_plugin.model_sharded_context():
|
||||
self._handle_meta_model()
|
||||
self.call_hook("configure_sharded_model")
|
||||
self.call_hook("on_configure_sharded_model")
|
||||
|
|
Loading…
Reference in New Issue