Remove deprecated `task_idx` (#10441)
parent ebab4be3e4
commit c413b69240
@@ -90,6 +90,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed deprecated `LightningModule.log(tbptt_reduce_fx, tbptt_reduce_token, sync_dist_op)` ([#10423](https://github.com/PyTorchLightning/pytorch-lightning/pull/10423))
+- Removed deprecated `Plugin.task_idx` ([#10441](https://github.com/PyTorchLightning/pytorch-lightning/pull/10441))
 - Removed PyTorch 1.6 support ([#10367](https://github.com/PyTorchLightning/pytorch-lightning/pull/10367))
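For downstream code that still reads the removed attribute, the migration named in the old deprecation message applies. A minimal sketch, assuming the import path used by the test file in this diff; `plugin` and `rank` are illustrative names:

    from pytorch_lightning.plugins.training_type import DDPPlugin

    plugin = DDPPlugin()

    # Before (deprecated in v1.4, removed by this commit):
    # rank = plugin.task_idx

    # After: the replacement the deprecation message pointed to
    rank = plugin.local_rank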
@@ -47,7 +47,6 @@ from pytorch_lightning.utilities import (
     _TORCH_GREATER_EQUAL_1_8,
     _TORCH_GREATER_EQUAL_1_9,
     _TORCH_GREATER_EQUAL_1_10,
-    rank_zero_deprecation,
     rank_zero_warn,
 )
 from pytorch_lightning.utilities.distributed import distributed_available
@@ -105,7 +104,6 @@ class DDPPlugin(ParallelPlugin):
         self.sync_batchnorm = False
         self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
         self._ddp_kwargs = kwargs
-        self._task_idx = None
         self._ddp_comm_state = ddp_comm_state
         self._ddp_comm_hook = ddp_comm_hook
         self._ddp_comm_wrapper = ddp_comm_wrapper
@@ -133,18 +131,6 @@ class DDPPlugin(ParallelPlugin):
         self._num_nodes = num_nodes
         self.set_world_ranks()
 
-    @property
-    def task_idx(self) -> Optional[int]:
-        rank_zero_deprecation(
-            f"`{self.__class__.__name__}.task_idx` is deprecated in v1.4 and will be removed in v1.6. Use "
-            f"`{self.__class__.__name__}.local_rank` instead."
-        )
-        return self._task_idx
-
-    @task_idx.setter
-    def task_idx(self, task_idx: int) -> None:
-        self._task_idx = task_idx
-
     @property
     def distributed_sampler_kwargs(self):
         distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
@@ -159,9 +145,6 @@
         if not self.cluster_environment.creates_processes_externally:
             self._call_children_scripts()
 
-        # set the task idx
-        self.task_idx = self.cluster_environment.local_rank()
-
         self.setup_distributed()
 
     def _setup_model(self, model: Module) -> DistributedDataParallel:
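The deleted assignment only cached a value the cluster environment already provides, so nothing is lost. A sketch of reading it directly, assuming `LightningEnvironment` as the concrete environment (any cluster environment exposes the same call):

    from pytorch_lightning.plugins.environments import LightningEnvironment

    env = LightningEnvironment()
    # Same value the removed `self.task_idx = ...` line used to cache:
    print(env.local_rank())  # 0 in a plain, non-spawned process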
@@ -32,11 +32,6 @@ class DDP2Plugin(DDPPlugin):
     def world_size(self) -> int:
         return self.num_nodes
 
-    def setup(self) -> None:
-        # set the task idx
-        self.task_idx = self.cluster_environment.local_rank()
-        # the difference to DDP is that we don't call children processes here
-
     def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION:
         """Reduces a collection of tensors from all processes. It can be applied to just a single tensor. In DDP2,
         the reduction here is only across local devices within the node.
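With its override deleted wholesale, `DDP2Plugin` now resolves `setup` through ordinary inheritance. A sketch of the effect, assuming no other `setup` override sits between the two classes:

    from pytorch_lightning.plugins.training_type import DDP2Plugin, DDPPlugin

    # Attribute lookup falls through to the parent class.
    assert DDP2Plugin.setup is DDPPlugin.setup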
@@ -17,7 +17,6 @@ from unittest.mock import call, Mock
 import pytest
 
 from pytorch_lightning import Trainer
-from pytorch_lightning.plugins.training_type import DDPPlugin
 from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_warn
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.model_summary import ModelSummary
@@ -86,12 +85,6 @@ def test_v1_6_0_rank_zero_warnings_moved():
         rank_zero_deprecation("test")
 
 
-def test_v1_6_0_ddp_plugin_task_idx():
-    plugin = DDPPlugin()
-    with pytest.deprecated_call(match="Use `DDPPlugin.local_rank` instead"):
-        _ = plugin.task_idx
-
-
 def test_v1_6_0_deprecated_model_summary_mode(tmpdir):
     model = BoringModel()
     with pytest.deprecated_call(match="Argument `mode` in `ModelSummary` is deprecated in v1.4"):
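The dedicated deprecation test is removed together with the property. Were a guard against accidental reintroduction wanted, a hypothetical sketch (not part of this commit):

    import pytest
    from pytorch_lightning.plugins.training_type import DDPPlugin

    def test_task_idx_is_gone():
        plugin = DDPPlugin()
        # With both the property and `_task_idx` removed, plain attribute
        # lookup now fails.
        with pytest.raises(AttributeError):
            _ = plugin.task_idx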