3/n Move Accelerator into strategy - remove model_sharded_context() (#10886)
* 3/n Move Accelerator into strategy - remove model_sharded_context()
* update ttp function
* update changelog
* update changelog

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
This commit is contained in:
parent 44cd412e91
commit 45dd8066e7
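
The net effect of the change: Accelerator.model_sharded_context() was a one-line pass-through to the training type plugin (the strategy), and it is deleted; the Trainer now calls the plugin directly. A minimal sketch with simplified stand-ins for the real classes (not the library's exact code):

import contextlib
from typing import Generator


class TrainingTypePlugin:
    """Simplified stand-in for the strategy base class; real signatures differ."""

    @contextlib.contextmanager
    def model_sharded_context(self) -> Generator[None, None, None]:
        # No-op by default; sharded strategies override this so that modules
        # created inside the `with` block can be sharded at construction time.
        yield


class Accelerator:
    def __init__(self, training_type_plugin: TrainingTypePlugin) -> None:
        self.training_type_plugin = training_type_plugin

    # Before this PR, Accelerator carried an identical context manager that
    # merely delegated:
    #
    #     with self.training_type_plugin.model_sharded_context():
    #         yield
    #
    # After this PR, callers use training_type_plugin.model_sharded_context()
    # directly and the pass-through is removed.
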
CHANGELOG.md
@@ -188,6 +188,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed argument `return_result` from the `DDPSpawnPlugin.spawn()` method ([#10867](https://github.com/PyTorchLightning/pytorch-lightning/pull/10867))
 
+- Removed `model_sharded_context` method from `Accelerator` ([#10886](https://github.com/PyTorchLightning/pytorch-lightning/pull/10886))
+
 - Removed method `pre_dispatch` from the `PrecisionPlugin` method ([#10887](https://github.com/PyTorchLightning/pytorch-lightning/pull/10887))

pytorch_lightning/accelerators/accelerator.py
@@ -11,9 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import contextlib
 from abc import abstractmethod
-from typing import Any, Dict, Generator, Optional, Union
+from typing import Any, Dict, Optional, Union
 
 import torch
 from torch.nn import Module
@@ -154,18 +153,6 @@ class Accelerator:
         with self.training_type_plugin.precision_plugin.predict_step_context():
             return self.training_type_plugin.predict_step(*step_kwargs.values())
 
-    @contextlib.contextmanager
-    def model_sharded_context(self) -> Generator[None, None, None]:
-        """Provide hook to create modules in a distributed aware context. This is useful for when we'd like to
-        shard the model instantly - useful for extremely large models. Can save memory and
-        initialization time.
-
-        Returns:
-            Model parallel context.
-        """
-        with self.training_type_plugin.model_sharded_context():
-            yield
-
     def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
         """Gets stats for a given device.
 
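For context, a sharding-aware strategy overrides this hook so parameters are partitioned as modules are constructed. A hypothetical override, sketched as a standalone class (a real strategy would enter its framework's construction context; DeepSpeed ZeRO stage 3, for example, uses deepspeed.zero.Init()):

import contextlib
from typing import Generator


class ShardedStrategySketch:
    """Hypothetical strategy showing where a sharded-construction context goes."""

    @contextlib.contextmanager
    def model_sharded_context(self) -> Generator[None, None, None]:
        # A real strategy would enter its framework's sharded-construction
        # context here so that nn.Module parameters created inside the
        # `with` block are partitioned across ranks at creation time.
        print("entering sharded construction context")
        try:
            yield
        finally:
            print("leaving sharded construction context")
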
pytorch_lightning/trainer/trainer.py
@@ -1406,7 +1406,7 @@ class Trainer(
         self.training_type_plugin.barrier("post_setup")
 
     def _call_configure_sharded_model(self) -> None:
-        with self.accelerator.model_sharded_context():
+        with self.training_type_plugin.model_sharded_context():
             self._handle_meta_model()
             self.call_hook("configure_sharded_model")
             self.call_hook("on_configure_sharded_model")
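
From the user's side nothing changes: layers built in the configure_sharded_model hook still run inside the strategy's context, now reached directly rather than through the Accelerator. A hypothetical usage example (the module and its sizes are illustrative):

import torch.nn as nn
import pytorch_lightning as pl


class GiantModel(pl.LightningModule):
    def configure_sharded_model(self) -> None:
        # Trainer._call_configure_sharded_model() wraps this hook in
        # training_type_plugin.model_sharded_context(), so a sharded strategy
        # can partition these layers as they are created instead of
        # materializing the full model on every rank first.
        self.blocks = nn.Sequential(*(nn.Linear(4096, 4096) for _ in range(8)))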