Prepare for ShardedTensor deprecation (#16892)

parent 24c0cd738c
commit a00e061417
@@ -391,6 +391,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Removed the `lightning.pytorch.strategies.DDPSpawnStrategy` in favor of `DDPStrategy(start_method='spawn')` (merged both classes) ([#16809](https://github.com/Lightning-AI/lightning/pull/16809))

+- Removed registration of `ShardedTensor` state dict hooks in `LightningModule.__init__` with `torch>=2.1` ([#16892](https://github.com/Lightning-AI/lightning/pull/16892))

 ### Fixed

 - Fixed an issue where `DistributedSampler.set_epoch` wasn't getting called during `trainer.predict` ([#16785](https://github.com/Lightning-AI/lightning/pull/16785), [#16826](https://github.com/Lightning-AI/lightning/pull/16826))
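For context on the `DistributedSampler.set_epoch` entry above: `set_epoch` seeds the sampler's shuffling for a given epoch, so each epoch yields a different yet reproducible ordering across ranks. A minimal sketch outside of Lightning, with illustrative `num_replicas`/`rank` values so no process group is needed:

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8).float())
# passing num_replicas/rank explicitly avoids needing an initialized process group
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

sampler.set_epoch(0)
first = list(sampler)
sampler.set_epoch(1)
second = list(sampler)
print(first, second)  # typically differ; without set_epoch every pass repeats the same order
```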
@@ -35,7 +35,7 @@ from lightning.fabric.utilities.apply_func import convert_to_tensors
 from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.fabric.utilities.distributed import _distributed_available, _sync_ddp
-from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_0
+from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1
 from lightning.fabric.wrappers import _FabricOptimizer
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.core.hooks import CheckpointHooks, DataHooks, ModelHooks
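The newly imported `_TORCH_GREATER_EQUAL_2_1` flag gates the legacy ShardedTensor path on the installed PyTorch version. A hedged sketch of how such a flag can be derived; Lightning's actual helper in `lightning.fabric.utilities.imports` may be implemented differently:

```python
# Illustrative only: compare the installed torch version against 2.1.0,
# dropping any local build suffix such as "+cu118".
import torch
from packaging.version import Version

_TORCH_GREATER_EQUAL_2_1 = Version(torch.__version__.split("+")[0]) >= Version("2.1.0")
```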
@@ -1440,6 +1440,9 @@ class LightningModule(

         These hooks ensure that ShardedTensors are included when saving, and are loaded into the LightningModule correctly.
         """
+        if _TORCH_GREATER_EQUAL_2_1:
+            # ShardedTensor is deprecated in favor of DistributedTensor
+            return
         if _IS_WINDOWS or not torch.distributed.is_available():
             rank_zero_debug("Could not register sharded tensor state dict hooks")
             return
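The hunk above only shows the top of the hook-registration helper. For context, a hedged sketch of how the full method inside `LightningModule` plausibly reads after this change; the method name and the two hook-registration calls are assumptions inferred from the surrounding imports, and only the early returns are taken from the diff itself:

```python
def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:  # hypothetical name
    """Adds ShardedTensor state dict hooks if ShardedTensors are supported.

    These hooks ensure that ShardedTensors are included when saving, and are loaded into the
    LightningModule correctly.
    """
    if _TORCH_GREATER_EQUAL_2_1:
        # ShardedTensor is deprecated in favor of DistributedTensor
        return
    if _IS_WINDOWS or not torch.distributed.is_available():
        rank_zero_debug("Could not register sharded tensor state dict hooks")
        return

    # Legacy path (torch < 2.1): wire up the ShardedTensor state dict hooks so sharded
    # tensors round-trip through state_dict()/load_state_dict().
    from torch.distributed._shard.sharded_tensor import pre_load_state_dict_hook, state_dict_hook

    self._register_state_dict_hook(state_dict_hook)
    self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
```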
@@ -310,7 +310,7 @@ def test_device_placement(tmpdir, accelerator, device):
     assert_device(torch.device("cpu"))


-@RunIf(skip_windows=True)
+@RunIf(skip_windows=True, max_torch="2.1.0")
 def test_sharded_tensor_state_dict(single_process_pg):
     from torch.distributed._shard.sharded_tensor import empty as sharded_tensor_empty
     from torch.distributed._sharding_spec import ChunkShardingSpec
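The in-test imports refer to the legacy sharded-tensor API that the new `max_torch="2.1.0"` cap keeps this test pinned to. A minimal hedged sketch of how those two imports fit together; it assumes an already-initialized default process group, which the `single_process_pg` fixture presumably provides:

```python
from torch.distributed._shard.sharded_tensor import empty as sharded_tensor_empty
from torch.distributed._sharding_spec import ChunkShardingSpec

# Chunk along dim 0, with the single shard placed on rank 0's CPU.
# Requires an initialized default process group (e.g. a single-process gloo group).
spec = ChunkShardingSpec(dim=0, placements=["rank:0/cpu"])
st = sharded_tensor_empty(spec, 10, 20)  # an uninitialized 10x20 ShardedTensor
# On torch < 2.1, the LightningModule hooks are what let a tensor like `st`
# survive a state_dict() / load_state_dict() round trip.
```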