Rename `ParallelPlugin` to `ParallelStrategy` (#11123)
Commit 17ad1a4c00 (parent 4bfe5bda0f)
@@ -137,6 +137,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Renamed the `TrainingTypePlugin` to `Strategy` ([#11120](https://github.com/PyTorchLightning/pytorch-lightning/pull/11120))
+    * Renamed the `ParallelPlugin` to `ParallelStrategy` ([#11123](https://github.com/PyTorchLightning/pytorch-lightning/pull/11123))
     * Renamed the `DataParallelPlugin` to `DataParallelStrategy` ([#11183](https://github.com/PyTorchLightning/pytorch-lightning/pull/11183))
     * Renamed the `DDPPlugin` to `DDPStrategy` ([#11142](https://github.com/PyTorchLightning/pytorch-lightning/pull/11142))
     * Renamed the `DeepSpeedPlugin` to `DeepSpeedStrategy` ([#11194](https://github.com/PyTorchLightning/pytorch-lightning/pull/11194))
@@ -151,7 +152,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Renamed the `DDPSpawnShardedPlugin` to `DDPSpawnShardedStrategy` ([#11210](https://github.com/PyTorchLightning/pytorch-lightning/pull/11210))
 - Marked the `ResultCollection`, `ResultMetric`, and `ResultMetricCollection` classes as protected ([#11130](https://github.com/PyTorchLightning/pytorch-lightning/pull/11130))
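For downstream code, the changelog entry above amounts to a drop-in rename. A minimal before/after sketch, assuming only that the class is re-exported from the top-level `pytorch_lightning.plugins` namespace, as the hunks later in this commit show:

```python
# Before this commit:
#   from pytorch_lightning.plugins import ParallelPlugin
# After this commit -- same class, new name:
from pytorch_lightning.plugins import ParallelStrategy


def is_parallel(training_type) -> bool:
    # All multi-process/multi-device training types (DDP, DDP spawn, DP,
    # Horovod, IPU) derive from the renamed base class.
    return isinstance(training_type, ParallelStrategy)
```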
@@ -148,7 +148,7 @@ Training Type Plugins
     Strategy
     SingleDeviceStrategy
-    ParallelPlugin
+    ParallelStrategy
     DataParallelStrategy
     DDPStrategy
     DDP2Strategy
@@ -107,7 +107,7 @@ Training Type Plugins
     Strategy
     SingleDeviceStrategy
-    ParallelPlugin
+    ParallelStrategy
     DataParallelStrategy
     DDPStrategy
     DDP2Strategy
@@ -22,7 +22,7 @@ import torch
 from torch.optim import Optimizer

 import pytorch_lightning as pl
-from pytorch_lightning.plugins import ParallelPlugin
+from pytorch_lightning.plugins import ParallelStrategy
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
@@ -161,8 +161,8 @@ def _update_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int) -
 @contextmanager
 def _block_parallel_sync_behavior(trainer: "pl.Trainer", block: bool = True) -> Generator[None, None, None]:
-    """Blocks synchronization in :class:`~pytorch_lightning.plugins.training_type.parallel.ParallelPlugin`. This is
-    useful for example when when accumulating gradients to reduce communication when it is not needed.
+    """Blocks synchronization in :class:`~pytorch_lightning.plugins.training_type.parallel.ParallelStrategy`. This
+    is useful for example when when accumulating gradients to reduce communication when it is not needed.

     Args:
         trainer: the trainer instance with a reference to a training type plugin
@@ -171,7 +171,7 @@ def _block_parallel_sync_behavior(trainer: "pl.Trainer", block: bool = True) ->
     Returns:
         context manager with sync behaviour off
     """
-    if isinstance(trainer.training_type_plugin, ParallelPlugin) and block:
+    if isinstance(trainer.training_type_plugin, ParallelStrategy) and block:
         with trainer.training_type_plugin.block_backward_sync():
             yield None
     else:
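The context manager above only suppresses gradient synchronization when the trainer's training type plugin is a `ParallelStrategy`; otherwise it is a plain pass-through. A hedged sketch of the calling pattern during gradient accumulation, written as if it lived in the same module as the helper; the `trainer` and `loss` objects are assumed from the surrounding loop, and this is illustrative rather than the loop's actual code:

```python
import torch


def _backward_with_accumulation(trainer: "pl.Trainer", loss: torch.Tensor, is_last_micro_batch: bool) -> None:
    # For intermediate micro-batches, block the inter-process all-reduce via
    # ParallelStrategy.block_backward_sync(); on the last micro-batch (or with
    # a non-parallel strategy) the context manager does nothing special.
    with _block_parallel_sync_behavior(trainer, block=not is_last_micro_batch):
        loss.backward()
```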
@@ -27,7 +27,7 @@ from pytorch_lightning.plugins.training_type.dp import DataParallelStrategy
 from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedStrategy
 from pytorch_lightning.plugins.training_type.horovod import HorovodStrategy
 from pytorch_lightning.plugins.training_type.ipu import IPUStrategy
-from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
+from pytorch_lightning.plugins.training_type.parallel import ParallelStrategy
 from pytorch_lightning.plugins.training_type.sharded import DDPShardedStrategy
 from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedStrategy
 from pytorch_lightning.plugins.training_type.single_device import SingleDeviceStrategy
@@ -64,7 +64,7 @@ __all__ = [
     "TPUBf16PrecisionPlugin",
     "TPUSpawnStrategy",
     "Strategy",
-    "ParallelPlugin",
+    "ParallelStrategy",
     "DDPShardedStrategy",
     "DDPSpawnShardedStrategy",
 ]
@@ -5,7 +5,7 @@ from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedStrategy
 from pytorch_lightning.plugins.training_type.dp import DataParallelStrategy  # noqa: F401
 from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedStrategy  # noqa: F401
 from pytorch_lightning.plugins.training_type.horovod import HorovodStrategy  # noqa: F401
-from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.parallel import ParallelStrategy  # noqa: F401
 from pytorch_lightning.plugins.training_type.sharded import DDPShardedStrategy  # noqa: F401
 from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedStrategy  # noqa: F401
 from pytorch_lightning.plugins.training_type.single_device import SingleDeviceStrategy  # noqa: F401
@@ -37,7 +37,7 @@ from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
-from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
+from pytorch_lightning.plugins.training_type.parallel import ParallelStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import (
     _FAIRSCALE_AVAILABLE,
@@ -73,7 +73,7 @@ if _TORCH_GREATER_EQUAL_1_8:
 log = logging.getLogger(__name__)


-class DDPStrategy(ParallelPlugin):
+class DDPStrategy(ParallelStrategy):
     """Plugin for multi-process single-device training on one or multiple nodes.

     The main process in each node spawns N-1 child processes via :func:`subprocess.Popen`, where N is the number of
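For context, selecting this strategy from the `Trainer` is untouched by the rename. A hedged sketch of both selection styles; treat the exact `Trainer` arguments (`strategy="ddp"`, `gpus=2`) and the `find_unused_parameters` pass-through as assumptions about this release's API rather than part of the commit:

```python
import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPStrategy

# String shorthand and an explicit instance both resolve to DDPStrategy;
# the second call simply replaces the first trainer in this sketch.
trainer = pl.Trainer(strategy="ddp", gpus=2)
trainer = pl.Trainer(strategy=DDPStrategy(find_unused_parameters=False), gpus=2)
```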
@@ -30,7 +30,7 @@ from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
-from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
+from pytorch_lightning.plugins.training_type.parallel import ParallelStrategy
 from pytorch_lightning.trainer.states import TrainerFn, TrainerState
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
@@ -54,7 +54,7 @@ if _TORCH_GREATER_EQUAL_1_8:
 log = logging.getLogger(__name__)


-class DDPSpawnStrategy(ParallelPlugin):
+class DDPSpawnStrategy(ParallelStrategy):
     """Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training
     finishes."""
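As a usage note (again hedged; the `strategy="ddp_spawn"` shorthand is an assumption about the Trainer API, not part of this diff), the spawn variant differs from `DDPStrategy` only in how worker processes are launched:

```python
import pytorch_lightning as pl

# Workers are created with torch.multiprocessing.spawn (see the docstring above)
# rather than subprocess.Popen as in DDPStrategy.
trainer = pl.Trainer(strategy="ddp_spawn", gpus=2)
```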
@@ -20,14 +20,14 @@ import pytorch_lightning as pl
 from pytorch_lightning.overrides.data_parallel import LightningParallelModule
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
-from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
+from pytorch_lightning.plugins.training_type.parallel import ParallelStrategy
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.enums import _StrategyType
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import _METRIC_COLLECTION, STEP_OUTPUT


-class DataParallelStrategy(ParallelPlugin):
+class DataParallelStrategy(ParallelStrategy):
     """Implements data-parallel training in a single process, i.e., the model gets replicated to each device and
     each gets a split of the data."""
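A corresponding hedged sketch for the data-parallel case; the `strategy="dp"` shorthand is assumed from the Trainer API of this release:

```python
import pytorch_lightning as pl

# Single-process data parallelism: the model is replicated onto each GPU and
# every replica receives a slice of each batch (see the docstring above).
trainer = pl.Trainer(strategy="dp", gpus=2)
```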
@@ -23,7 +23,7 @@ import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
-from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
+from pytorch_lightning.plugins.training_type.parallel import ParallelStrategy
 from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.distributed import group as dist_group
@@ -34,7 +34,7 @@ if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd


-class HorovodStrategy(ParallelPlugin):
+class HorovodStrategy(ParallelStrategy):
     """Plugin for Horovod distributed training integration."""

     distributed_backend = _StrategyType.HOROVOD
@@ -23,7 +23,7 @@ from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
-from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
+from pytorch_lightning.plugins.training_type.parallel import ParallelStrategy
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn
 from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -57,7 +57,7 @@ class LightningIPUModule(_LightningModuleWrapperBase):
         return batch


-class IPUStrategy(ParallelPlugin):
+class IPUStrategy(ParallelStrategy):
     """Plugin for training on IPU devices."""

     def __init__(
@@ -29,7 +29,7 @@ from pytorch_lightning.utilities import _XLA_AVAILABLE
 from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp


-class ParallelPlugin(Strategy, ABC):
+class ParallelStrategy(Strategy, ABC):
     """Plugin for training with multiple processes in parallel."""

     def __init__(
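The rename leaves the class hierarchy itself untouched: every concrete parallel training type touched in this commit still derives from the abstract base, now called `ParallelStrategy`. A small sketch to illustrate, assuming all of these classes remain importable from `pytorch_lightning.plugins` as the import hunks above show:

```python
from pytorch_lightning.plugins import (
    DataParallelStrategy,
    DDPSpawnStrategy,
    DDPStrategy,
    HorovodStrategy,
    IPUStrategy,
    ParallelStrategy,
)

# Each concrete parallel training type subclasses the renamed abstract base.
for cls in (DDPStrategy, DDPSpawnStrategy, DataParallelStrategy, HorovodStrategy, IPUStrategy):
    assert issubclass(cls, ParallelStrategy)
```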
@@ -45,7 +45,7 @@ from pytorch_lightning.plugins import (
     ApexMixedPrecisionPlugin,
     DDPSpawnStrategy,
     NativeMixedPrecisionPlugin,
-    ParallelPlugin,
+    ParallelStrategy,
     PLUGIN_INPUT,
     PrecisionPlugin,
     Strategy,
@@ -1927,7 +1927,7 @@ class Trainer(

     @property
     def distributed_sampler_kwargs(self) -> Optional[dict]:
-        if isinstance(self.training_type_plugin, ParallelPlugin):
+        if isinstance(self.training_type_plugin, ParallelStrategy):
             return self.training_type_plugin.distributed_sampler_kwargs

     @property
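The property above only exposes sampler arguments when the active training type is a `ParallelStrategy` (and falls through to `None` otherwise). A hedged sketch of how such kwargs are typically consumed, assuming they carry the `num_replicas`/`rank` fields that `torch.utils.data.DistributedSampler` accepts:

```python
from torch.utils.data import DataLoader, Dataset, DistributedSampler


def build_train_dataloader(trainer, dataset: Dataset, batch_size: int = 32) -> DataLoader:
    kwargs = trainer.distributed_sampler_kwargs  # None unless a ParallelStrategy is active
    sampler = DistributedSampler(dataset, **kwargs) if kwargs is not None else None
    # Shuffle only when no distributed sampler is attached; the sampler shuffles otherwise.
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, shuffle=sampler is None)
```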
@@ -32,7 +32,7 @@ from pytorch_lightning.plugins import (
     DDPSpawnStrategy,
     DDPStrategy,
     DeepSpeedStrategy,
-    ParallelPlugin,
+    ParallelStrategy,
     PrecisionPlugin,
     SingleDeviceStrategy,
 )
@@ -427,7 +427,7 @@ def test_plugin_accelerator_choice(accelerator: Optional[str], plugin: str):
 @mock.patch("torch.cuda.device_count", return_value=2)
 @pytest.mark.parametrize("gpus", [1, 2])
 def test_accelerator_choice_multi_node_gpu(
-    mock_is_available, mock_device_count, tmpdir, accelerator: str, plugin: ParallelPlugin, gpus: int
+    mock_is_available, mock_device_count, tmpdir, accelerator: str, plugin: ParallelStrategy, gpus: int
 ):
     with pytest.deprecated_call(match=r"accelerator=.*\)` has been deprecated"):
         trainer = Trainer(accelerator=accelerator, default_root_dir=tmpdir, num_nodes=2, gpus=gpus)
@@ -23,7 +23,7 @@ from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf


-class CustomParallelPlugin(DDPStrategy):
+class CustomParallelStrategy(DDPStrategy):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         # Set to None so it will be overwritten by the accelerator connector.
@@ -34,7 +34,7 @@ class CustomParallelPlugin(DDPStrategy):
 def test_sync_batchnorm_set(tmpdir):
     """Tests if sync_batchnorm is automatically set for custom plugin."""
     model = BoringModel()
-    plugin = CustomParallelPlugin()
+    plugin = CustomParallelStrategy()
     assert plugin.sync_batchnorm is None
     trainer = Trainer(max_epochs=1, strategy=plugin, default_root_dir=tmpdir, sync_batchnorm=True)
     trainer.fit(model)