From 73e7a5d0c214072b93a9adb4b3c8756c54ec14f9 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Thu, 23 Jun 2022 16:44:48 +0100 Subject: [PATCH] Rename `CollaborativeStrategy` to `HivemindStrategy` (#13388) --- CHANGELOG.md | 9 ++-- docs/source-pytorch/api_references.rst | 2 +- docs/source-pytorch/extensions/strategy.rst | 2 +- .../strategies/collaborative_training.rst | 2 +- .../collaborative_training_basic.rst | 10 ++-- .../collaborative_training_expert.rst | 12 ++--- .../collaborative_training_intermediate.rst | 20 ++++---- src/pytorch_lightning/strategies/__init__.py | 2 +- .../strategies/collaborative.py | 46 +++++++++---------- .../strategies/test_collaborative.py | 38 ++++++++------- 10 files changed, 72 insertions(+), 71 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 906603f55e..fd81189861 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,9 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added profiling to the loops' dataloader `__next__` calls ([#12124](https://github.com/PyTorchLightning/pytorch-lightning/pull/12124)) - -- Added `CollaborativeStrategy` ([#12842](https://github.com/PyTorchLightning/pytorch-lightning/pull/12842)) - +- Hivemind Strategy + * Added `CollaborativeStrategy` ([#12842](https://github.com/PyTorchLightning/pytorch-lightning/pull/12842)) + * Renamed `CollaborativeStrategy` to `HivemindStrategy` ([#13388](https://github.com/PyTorchLightning/pytorch-lightning/pull/13388)) - Include a version suffix for new "last" checkpoints of later runs in the same directory ([#12902](https://github.com/PyTorchLightning/pytorch-lightning/pull/12902)) @@ -120,6 +120,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - The `WandbLogger` will now use the run name in the logs folder if it is provided, and otherwise the project name ([#12604](https://github.com/PyTorchLightning/pytorch-lightning/pull/12604)) +- + + ### Deprecated - Deprecated `pytorch_lightning.loggers.base.LightningLoggerBase` in favor of `pytorch_lightning.loggers.logger.Logger`, and deprecated `pytorch_lightning.loggers.base` in favor of `pytorch_lightning.loggers.logger` ([#120148](https://github.com/PyTorchLightning/pytorch-lightning/pull/12014)) diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst index 15640bc3ca..401ba3928c 100644 --- a/docs/source-pytorch/api_references.rst +++ b/docs/source-pytorch/api_references.rst @@ -269,7 +269,7 @@ strategies :template: classtemplate.rst BaguaStrategy - CollaborativeStrategy + HivemindStrategy DDP2Strategy DDPFullyShardedStrategy DDPShardedStrategy diff --git a/docs/source-pytorch/extensions/strategy.rst b/docs/source-pytorch/extensions/strategy.rst index 987764ed95..95c48e0949 100644 --- a/docs/source-pytorch/extensions/strategy.rst +++ b/docs/source-pytorch/extensions/strategy.rst @@ -76,7 +76,7 @@ The below table lists all relevant strategies available in Lightning with their - :class:`~pytorch_lightning.strategies.BaguaStrategy` - Strategy for training using the Bagua library, with advanced distributed training algorithms and system optimizations. :ref:`Learn more. ` * - collaborative - - :class:`~pytorch_lightning.strategies.CollaborativeStrategy` + - :class:`~pytorch_lightning.strategies.HivemindStrategy` - Strategy for training collaboratively on local machines or unreliable GPUs across the internet. :ref:`Learn more. ` * - fsdp - :class:`~pytorch_lightning.strategies.DDPFullyShardedStrategy` diff --git a/docs/source-pytorch/strategies/collaborative_training.rst b/docs/source-pytorch/strategies/collaborative_training.rst index 1ffa8b0815..72e9d13f91 100644 --- a/docs/source-pytorch/strategies/collaborative_training.rst +++ b/docs/source-pytorch/strategies/collaborative_training.rst @@ -23,7 +23,7 @@ Training on unreliable mixed GPUs across the internet .. displayitem:: :header: 2: Speed up training by enabling under-the-hood optimizations - :description: Learn which flags to use with the CollaborativeStrategy to speed up training. + :description: Learn which flags to use with the HivemindStrategy to speed up training. :col_css: col-md-4 :button_link: collaborative_training_intermediate.html :height: 200 diff --git a/docs/source-pytorch/strategies/collaborative_training_basic.rst b/docs/source-pytorch/strategies/collaborative_training_basic.rst index 5fd64cd968..108f6197fd 100644 --- a/docs/source-pytorch/strategies/collaborative_training_basic.rst +++ b/docs/source-pytorch/strategies/collaborative_training_basic.rst @@ -16,20 +16,20 @@ To use Collaborative Training, you need to first install Hivemind. pip install hivemind -The ``CollaborativeStrategy`` accumulates gradients from all processes that are collaborating until they reach a ``target_batch_size``. By default, we use the batch size +The ``HivemindStrategy`` accumulates gradients from all processes that are collaborating until they reach a ``target_batch_size``. By default, we use the batch size of the first batch to determine what each local machine batch contributes towards the ``target_batch_size``. Once the ``target_batch_size`` is reached, an optimizer step is made on all processes. .. warning:: - When using ``CollaborativeStrategy`` note that you cannot use gradient accumulation (``accumulate_grad_batches``). This is because Hivemind manages accumulation internally. + When using ``HivemindStrategy`` note that you cannot use gradient accumulation (``accumulate_grad_batches``). This is because Hivemind manages accumulation internally. .. code-block:: python import pytorch_lightning as pl - from pytorch_lightning.strategies import CollaborativeStrategy + from pytorch_lightning.strategies import HivemindStrategy - trainer = pl.Trainer(strategy=CollaborativeStrategy(target_batch_size=8192), accelerator="gpu", devices=1) + trainer = pl.Trainer(strategy=HivemindStrategy(target_batch_size=8192), accelerator="gpu", devices=1) .. code-block:: bash @@ -37,7 +37,7 @@ is made on all processes. # Other machines can connect running the same command: # INITIAL_PEERS=... python train.py # or passing the peers to the strategy:" - # CollaborativeStrategy(initial_peers=...)" + # HivemindStrategy(initial_peers=...)" A helper message is printed once your training begins, which shows you how to start training on other machines using the same code. diff --git a/docs/source-pytorch/strategies/collaborative_training_expert.rst b/docs/source-pytorch/strategies/collaborative_training_expert.rst index 448720f619..5b8a5e8b4c 100644 --- a/docs/source-pytorch/strategies/collaborative_training_expert.rst +++ b/docs/source-pytorch/strategies/collaborative_training_expert.rst @@ -24,10 +24,10 @@ Below, we enable Float16 compression, which compresses gradients and states to F from hivemind import Float16Compression import pytorch_lightning as pl - from pytorch_lightning.strategies import CollaborativeStrategy + from pytorch_lightning.strategies import HivemindStrategy trainer = pl.Trainer( - strategy=CollaborativeStrategy( + strategy=HivemindStrategy( target_batch_size=target_batch_size, grad_compression=Float16Compression(), state_averaging_compression=Float16Compression(), @@ -44,14 +44,14 @@ Size Adaptive Compression has been used in a variety of Hivemind applications an from hivemind import Float16Compression, Uniform8BitQuantization import pytorch_lightning as pl - from pytorch_lightning.strategies import CollaborativeStrategy + from pytorch_lightning.strategies import HivemindStrategy # compresses values above threshold with 8bit Quantization, lower with Float16 compression = SizeAdaptiveCompression( threshold=2 ** 16 + 1, less=Float16Compression(), greater_equal=Uniform8BitQuantization() ) trainer = pl.Trainer( - strategy=CollaborativeStrategy( + strategy=HivemindStrategy( target_batch_size=target_batch_size, grad_compression=compression, state_averaging_compression=compression, @@ -73,12 +73,12 @@ In short, PowerSGD uses a low-rank approximation to compress gradients before ru .. code-block:: python import pytorch_lightning as pl - from pytorch_lightning.strategies import CollaborativeStrategy + from pytorch_lightning.strategies import HivemindStrategy from functools import partial from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager trainer = pl.Trainer( - strategy=CollaborativeStrategy( + strategy=HivemindStrategy( target_batch_size=8192, grad_averager_factory=partial(PowerSGDGradientAverager, averager_rank=32, min_compression_ratio=0.5), ), diff --git a/docs/source-pytorch/strategies/collaborative_training_intermediate.rst b/docs/source-pytorch/strategies/collaborative_training_intermediate.rst index 0536277aee..38d6c6a342 100644 --- a/docs/source-pytorch/strategies/collaborative_training_intermediate.rst +++ b/docs/source-pytorch/strategies/collaborative_training_intermediate.rst @@ -22,7 +22,7 @@ to overlap communication with computation. Enabling overlapping communication means convergence will slightly be affected. .. note:: - Enabling these flags means that you must pass in a ``scheduler_fn`` to the ``CollaborativeStrategy`` instead of relying on a scheduler from ``configure_optimizers``. + Enabling these flags means that you must pass in a ``scheduler_fn`` to the ``HivemindStrategy`` instead of relying on a scheduler from ``configure_optimizers``. The optimizer is re-created by Hivemind, and as a result, the scheduler has to be re-created. .. code-block:: python @@ -30,10 +30,10 @@ to overlap communication with computation. import torch from functools import partial import pytorch_lightning as pl - from pytorch_lightning.strategies import CollaborativeStrategy + from pytorch_lightning.strategies import HivemindStrategy trainer = pl.Trainer( - strategy=CollaborativeStrategy( + strategy=HivemindStrategy( target_batch_size=8192, delay_state_averaging=True, delay_grad_averaging=True, @@ -57,7 +57,7 @@ Offloading Optimizer State to the CPU Offloading the Optimizer state to the CPU works the same as :ref:`deepspeed-zero-stage-2-offload`, where we save GPU memory by keeping all optimizer states on the CPU. .. note:: - Enabling these flags means that you must pass in a ``scheduler_fn`` to the ``CollaborativeStrategy`` instead of relying on a scheduler from ``configure_optimizers``. + Enabling these flags means that you must pass in a ``scheduler_fn`` to the ``HivemindStrategy`` instead of relying on a scheduler from ``configure_optimizers``. The optimizer is re-created by Hivemind, and as a result, the scheduler has to be re-created. We suggest enabling offloading and overlapping communication to hide the additional overhead from having to communicate with the CPU. @@ -67,10 +67,10 @@ Offloading the Optimizer state to the CPU works the same as :ref:`deepspeed-zero import torch from functools import partial import pytorch_lightning as pl - from pytorch_lightning.strategies import CollaborativeStrategy + from pytorch_lightning.strategies import HivemindStrategy trainer = pl.Trainer( - strategy=CollaborativeStrategy( + strategy=HivemindStrategy( target_batch_size=8192, offload_optimizer=True, scheduler_fn=partial(torch.optim.lr_scheduler.ExponentialLR, gamma=...), @@ -83,17 +83,17 @@ Offloading the Optimizer state to the CPU works the same as :ref:`deepspeed-zero Re-using Gradient Buffers """"""""""""""""""""""""" -By default, Hivemind accumulates gradients in a separate buffer. This means additional GPU memory is required to store gradients. You can enable re-using the model parameter gradient buffers by passing ``reuse_grad_buffers=True`` to the ``CollaborativeStrategy``. +By default, Hivemind accumulates gradients in a separate buffer. This means additional GPU memory is required to store gradients. You can enable re-using the model parameter gradient buffers by passing ``reuse_grad_buffers=True`` to the ``HivemindStrategy``. .. warning:: - The ``CollaborativeStrategy`` will override ``zero_grad`` in your ``LightningModule`` to have no effect. This is because gradients are accumulated in the model + The ``HivemindStrategy`` will override ``zero_grad`` in your ``LightningModule`` to have no effect. This is because gradients are accumulated in the model and Hivemind manages when they need to be cleared. .. code-block:: python import pytorch_lightning as pl - from pytorch_lightning.strategies import CollaborativeStrategy + from pytorch_lightning.strategies import HivemindStrategy trainer = pl.Trainer( - strategy=CollaborativeStrategy(target_batch_size=8192, reuse_grad_buffers=True), accelerator="gpu", devices=1 + strategy=HivemindStrategy(target_batch_size=8192, reuse_grad_buffers=True), accelerator="gpu", devices=1 ) diff --git a/src/pytorch_lightning/strategies/__init__.py b/src/pytorch_lightning/strategies/__init__.py index 0de3e51a0f..f59d976edf 100644 --- a/src/pytorch_lightning/strategies/__init__.py +++ b/src/pytorch_lightning/strategies/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from pytorch_lightning.strategies.bagua import BaguaStrategy # noqa: F401 -from pytorch_lightning.strategies.collaborative import CollaborativeStrategy # noqa: F401 +from pytorch_lightning.strategies.collaborative import HivemindStrategy # noqa: F401 from pytorch_lightning.strategies.ddp import DDPStrategy # noqa: F401 from pytorch_lightning.strategies.ddp2 import DDP2Strategy # noqa: F401 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy # noqa: F401 diff --git a/src/pytorch_lightning/strategies/collaborative.py b/src/pytorch_lightning/strategies/collaborative.py index 715f0d58e9..b594704aba 100644 --- a/src/pytorch_lightning/strategies/collaborative.py +++ b/src/pytorch_lightning/strategies/collaborative.py @@ -32,7 +32,7 @@ else: log = logging.getLogger(__name__) -class CollaborativeStrategy(Strategy): +class HivemindStrategy(Strategy): def __init__( self, target_batch_size: int, @@ -63,7 +63,7 @@ class CollaborativeStrategy(Strategy): with unreliable machines. For more information, `refer to the docs `__. - .. warning:: ``CollaborativeStrategy`` is experimental and subject to change. + .. warning:: ``HivemindStrategy`` is experimental and subject to change. Arguments: @@ -81,11 +81,11 @@ class CollaborativeStrategy(Strategy): corresponding :meth:`hivemind.Optimizer.step` call. delay_optimizer_step: Run optimizer in background, apply results in future .step. requires - :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.offload_optimizer`. + :paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.offload_optimizer`. delay_grad_averaging: Average gradients in background; requires - :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.offload_optimizer` and - :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.delay_optimizer_step`. + :paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.offload_optimizer` and + :paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.delay_optimizer_step`. offload_optimizer: Offload the optimizer to host memory, saving GPU memory for parameters and gradients. @@ -95,7 +95,7 @@ class CollaborativeStrategy(Strategy): scheduler_fn: callable(optimizer) -> PyTorch LRScheduler or a pre-initialized PyTorch scheduler. When using `offload_optimizer`/`delay_optimizer_step`/`delay_state_averaging` ``scheduler_fn`` - is required to be passed to the ``CollaborativeStrategy``. This is because the optimizer + is required to be passed to the ``HivemindStrategy``. This is because the optimizer is re-created and the scheduler needs to be re-created as well. matchmaking_time: When looking for group, wait for peers to join for up to this many seconds. @@ -131,18 +131,18 @@ class CollaborativeStrategy(Strategy): port: When creating the endpoint, the host port to use. retry_endpoint_attempts: When connecting to the - :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.peer_endpoint`, + :paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.peer_endpoint`, how many time to retry before raising an exception. retry_endpoint_sleep_duration: When connecting to the - :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.peer_endpoint`, + :paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.peer_endpoint`, how long to wait between retries. **optimizer_kwargs: kwargs are passed to the :class:`hivemind.Optimizer` class. """ if not _HIVEMIND_AVAILABLE or platform.system() != "Linux": raise MisconfigurationException( - "To use the `CollaborativeStrategy`, you must have Hivemind installed and be running on Linux." + "To use the `HivemindStrategy`, you must have Hivemind installed and be running on Linux." " Install it by running `pip install -U hivemind`." ) @@ -271,7 +271,7 @@ class CollaborativeStrategy(Strategy): assert lightning_module is not None # `is_overridden` returns False otherwise rank_zero_warn( "You have overridden `optimizer_zero_grad` which will be disabled." - " When `CollaborativeStrategy(reuse_grad_buffers=True)`, the optimizer cannot call zero grad," + " When `HivemindStrategy(reuse_grad_buffers=True)`, the optimizer cannot call zero grad," " as this would delete the gradients before they are averaged." ) assert lightning_module is not None @@ -303,7 +303,7 @@ class CollaborativeStrategy(Strategy): raise MisconfigurationException( "We tried to infer the batch size from the first batch of data. " "Please provide the batch size to the Strategy by " - "``Trainer(strategy=CollaborativeStrategy(batch_size=x))``. " + "``Trainer(strategy=HivemindStrategy(batch_size=x))``. " ) from e self._initialize_hivemind() @@ -388,26 +388,26 @@ class DHTManager: Arguments: - host_maddrs: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.host_maddrs` + host_maddrs: :paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.host_maddrs` - initial_peers: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.initial_peers` + initial_peers: :paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.initial_peers` - persistent: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.persistent` + persistent: :paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.persistent` - endpoint: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.endpoint` + endpoint: :paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.endpoint` - peer_endpoint: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.peer_endpoint` + peer_endpoint: :paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.peer_endpoint` - host: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.host` + host: :paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.host` - port: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.port` + port: :paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.port` retry_endpoint_attempts: - :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.retry_endpoint_attempts` + :paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.retry_endpoint_attempts` retry_endpoint_sleep_duration: :paramref: - `~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.retry_endpoint_sleep_duration` + `~pytorch_lightning.strategies.collaborative.HivemindStrategy.retry_endpoint_sleep_duration` """ self._persistent = persistent self._endpoint = endpoint @@ -445,7 +445,7 @@ class DHTManager: "\nOther machines can connect running the same command:\n" f"INITIAL_PEERS={','.join(visible_addresses)} python ...\n" "or passing the peers to the strategy:\n" - f"CollaborativeStrategy(initial_peers='{','.join(visible_addresses)}')" + f"HivemindStrategy(initial_peers='{','.join(visible_addresses)}')" ) def _log_endpoint_helper_message(self, visible_addresses: List[str]) -> None: @@ -462,7 +462,7 @@ class DHTManager: "Other peers can connect via:\n" f"PEER_ENDPOINT={resolved_host}:{self._port} python ...\n" "or pass the peer endpoint address to the strategy:\n" - f"CollaborativeStrategy(peer_endpoint='{resolved_host}:{self._port}')" + f"HivemindStrategy(peer_endpoint='{resolved_host}:{self._port}')" ) def _start_server_process(self, host: str, port: int) -> None: @@ -499,7 +499,7 @@ class DHTManager: raise MisconfigurationException( f"Unable to get peers. Tried {retry_initial_peers} times waiting {retry_peer_sleep_duration}s." f"These parameters can be extended by passing " - "to the strategy (CollaborativeStrategy(retry_connection=x, retry_sleep_duration=y))." + "to the strategy (HivemindStrategy(retry_connection=x, retry_sleep_duration=y))." ) log.info(f"Received initial peers from collaborative server: {peers}") return peers diff --git a/tests/tests_pytorch/strategies/test_collaborative.py b/tests/tests_pytorch/strategies/test_collaborative.py index fefb5c13e0..8e4ef88f03 100644 --- a/tests/tests_pytorch/strategies/test_collaborative.py +++ b/tests/tests_pytorch/strategies/test_collaborative.py @@ -13,7 +13,7 @@ from torch.optim import Optimizer import pytorch_lightning as pl from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.plugins.environments.lightning_environment import find_free_network_port -from pytorch_lightning.strategies import CollaborativeStrategy +from pytorch_lightning.strategies import HivemindStrategy from pytorch_lightning.strategies.collaborative import HiveMindScheduler from pytorch_lightning.utilities import _HIVEMIND_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -28,13 +28,13 @@ if _HIVEMIND_AVAILABLE: def test_raise_exception_if_hivemind_unavailable(): """Test that we raise an exception when Hivemind is not available.""" with pytest.raises(MisconfigurationException, match="you must have Hivemind installed"): - CollaborativeStrategy(target_batch_size=1) + HivemindStrategy(target_batch_size=1) @RunIf(hivemind=True) @mock.patch("hivemind.DHT", autospec=True) def test_strategy(mock_dht): - strategy = CollaborativeStrategy(target_batch_size=1) + strategy = HivemindStrategy(target_batch_size=1) trainer = pl.Trainer(strategy=strategy) assert trainer.strategy == strategy @@ -50,7 +50,7 @@ def test_logging_disabled_when_second_peer(mock_dht, mock_http, initial_peers, p """Test when we are a second peer (passing initial peers or peer endpoint) we warn the user that logging/checkpointing will be disabled.""" with pytest.warns(UserWarning, match="This machine is not a persistent machine"): - CollaborativeStrategy(target_batch_size=1, initial_peers=initial_peers, peer_endpoint=peer_endpoint) + HivemindStrategy(target_batch_size=1, initial_peers=initial_peers, peer_endpoint=peer_endpoint) @RunIf(hivemind=True) @@ -65,7 +65,7 @@ def test_logging_disabled_when_second_peer(mock_dht, mock_http, initial_peers, p ) def test_initial_peer_message(caplog, endpoint, expected_message): model = BoringModel() - trainer = pl.Trainer(strategy=CollaborativeStrategy(target_batch_size=1, endpoint=endpoint), fast_dev_run=True) + trainer = pl.Trainer(strategy=HivemindStrategy(target_batch_size=1, endpoint=endpoint), fast_dev_run=True) trainer.fit(model) assert expected_message in caplog.text @@ -79,7 +79,7 @@ def test_optimizer_wrapped(): assert isinstance(optimizer, hivemind.Optimizer) model = TestModel() - trainer = pl.Trainer(strategy=CollaborativeStrategy(target_batch_size=1), fast_dev_run=True) + trainer = pl.Trainer(strategy=HivemindStrategy(target_batch_size=1), fast_dev_run=True) trainer.fit(model) @@ -97,7 +97,7 @@ def test_scheduler_wrapped(): model = TestModel() trainer = pl.Trainer( - strategy=CollaborativeStrategy(target_batch_size=1), + strategy=HivemindStrategy(target_batch_size=1), fast_dev_run=True, ) trainer.fit(model) @@ -121,7 +121,7 @@ def test_scheduler_wrapped(): @mock.patch("http.server.ThreadingHTTPServer", autospec=True) def test_env_variables_parsed(mock_dht, mock_peers, mock_server): """Test that env variables are parsed correctly.""" - strategy = CollaborativeStrategy(target_batch_size=1) + strategy = HivemindStrategy(target_batch_size=1) assert strategy.dht_manager._initial_peers == ["TEST_PEERS"] assert strategy.dht_manager._host == "TEST_HOST" assert strategy.dht_manager._port == 1300 @@ -143,9 +143,7 @@ def test_reuse_grad_buffers_warning(): pass model = TestModel() - trainer = pl.Trainer( - strategy=CollaborativeStrategy(target_batch_size=1, reuse_grad_buffers=True), fast_dev_run=True - ) + trainer = pl.Trainer(strategy=HivemindStrategy(target_batch_size=1, reuse_grad_buffers=True), fast_dev_run=True) with pytest.warns(UserWarning, match="You have overridden `optimizer_zero_grad` which will be disabled."): trainer.fit(model) @@ -162,7 +160,7 @@ def test_raise_exception_multiple_optimizers(): return [optimizer, optimizer], [lr_scheduler] model = TestModel() - trainer = pl.Trainer(strategy=CollaborativeStrategy(target_batch_size=1), fast_dev_run=True) + trainer = pl.Trainer(strategy=HivemindStrategy(target_batch_size=1), fast_dev_run=True) with pytest.raises(MisconfigurationException, match="Hivemind only supports training with one optimizer."): trainer.fit(model) @@ -174,7 +172,7 @@ def test_raise_exception_no_batch_size(mock_extract_batch_size): """Test that we raise an exception when no batch size is automatically found.""" model = BoringModel() - trainer = pl.Trainer(strategy=CollaborativeStrategy(target_batch_size=1), fast_dev_run=True) + trainer = pl.Trainer(strategy=HivemindStrategy(target_batch_size=1), fast_dev_run=True) with pytest.raises(MisconfigurationException, match="Please provide the batch size to the Strategy."): trainer.fit(model) @@ -191,7 +189,7 @@ def test_warn_if_argument_passed(delay_grad_averaging, delay_state_averaging, de function.""" model = BoringModel() trainer = pl.Trainer( - strategy=CollaborativeStrategy( + strategy=HivemindStrategy( target_batch_size=1, delay_grad_averaging=delay_grad_averaging, delay_state_averaging=delay_state_averaging, @@ -207,7 +205,7 @@ def test_warn_if_argument_passed(delay_grad_averaging, delay_state_averaging, de @RunIf(hivemind=True) @mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True) @mock.patch("http.server.ThreadingHTTPServer", autospec=True) -@mock.patch("pytorch_lightning.strategies.collaborative.CollaborativeStrategy.num_peers", new_callable=PropertyMock) +@mock.patch("pytorch_lightning.strategies.collaborative.HivemindStrategy.num_peers", new_callable=PropertyMock) def test_args_passed_to_optimizer(mock_peers, mock_server): """Test to ensure arguments are correctly passed to the hivemind optimizer wrapper.""" mock_peers.return_value = 1 @@ -234,7 +232,7 @@ def test_args_passed_to_optimizer(mock_peers, mock_server): model = TestModel() trainer = pl.Trainer( - strategy=CollaborativeStrategy( + strategy=HivemindStrategy( target_batch_size=1, reuse_grad_buffers=True, delay_state_averaging=True, @@ -258,7 +256,7 @@ def test_args_passed_to_optimizer(mock_peers, mock_server): ) def test_maddrs(host_maddrs, expected_maddrs): """Test that the multiple addresses are correctly assigned.""" - strategy = CollaborativeStrategy(target_batch_size=1, host_maddrs=host_maddrs) + strategy = HivemindStrategy(target_batch_size=1, host_maddrs=host_maddrs) assert strategy.dht.kwargs["host_maddrs"] == expected_maddrs @@ -281,7 +279,7 @@ def _run_collab_training_fn(initial_peers, wait_seconds, barrier, recorded_proce max_epochs=1, limit_train_batches=16, limit_val_batches=0, - strategy=CollaborativeStrategy( + strategy=HivemindStrategy( delay_state_averaging=True, offload_optimizer=True, delay_optimizer_step=True, @@ -347,7 +345,7 @@ def test_scaler_updated_precision_16(): model = TestModel() trainer = pl.Trainer( - strategy=CollaborativeStrategy(target_batch_size=1), + strategy=HivemindStrategy(target_batch_size=1), fast_dev_run=True, precision=16, accelerator="gpu", @@ -362,7 +360,7 @@ def test_raise_when_peer_endpoint_unsuccessful(caplog): port = find_free_network_port() with pytest.raises(MisconfigurationException, match="Unable to get peers"): with mock.patch("requests.get", wraps=requests.get) as requests_mock: - CollaborativeStrategy( + HivemindStrategy( target_batch_size=1, peer_endpoint=f"localhost:{port}", retry_endpoint_attempts=10,