diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst index 487628c3f6..a5863427ef 100644 --- a/docs/source-pytorch/api_references.rst +++ b/docs/source-pytorch/api_references.rst @@ -283,7 +283,6 @@ strategies DDPStrategy DataParallelStrategy DeepSpeedStrategy - HivemindStrategy HPUParallelStrategy IPUStrategy ParallelStrategy diff --git a/docs/source-pytorch/common_usecases.rst b/docs/source-pytorch/common_usecases.rst index 307a32f03b..606eea96d7 100644 --- a/docs/source-pytorch/common_usecases.rst +++ b/docs/source-pytorch/common_usecases.rst @@ -123,13 +123,6 @@ Customize and extend Lightning for things like custom hardware or distributed st :button_link: clouds/cloud_training.html :height: 100 -.. displayitem:: - :header: Train on multiple machines over the internet - :description: Train on local machines or unreliable GPUs across the internet. - :col_css: col-md-12 - :button_link: strategies/hivemind - :height: 100 - .. displayitem:: :header: Train on single or multiple GPUs :description: Train models faster with GPUs. diff --git a/docs/source-pytorch/extensions/strategy.rst b/docs/source-pytorch/extensions/strategy.rst index 82d1d5e103..abc5b87300 100644 --- a/docs/source-pytorch/extensions/strategy.rst +++ b/docs/source-pytorch/extensions/strategy.rst @@ -72,9 +72,6 @@ The below table lists all relevant strategies available in Lightning with their * - bagua - :class:`~pytorch_lightning.strategies.BaguaStrategy` - Strategy for training using the Bagua library, with advanced distributed training algorithms and system optimizations. :ref:`Learn more. ` - * - collaborative - - :class:`~pytorch_lightning.strategies.HivemindStrategy` - - Strategy for training collaboratively on local machines or unreliable GPUs across the internet. :ref:`Learn more. ` * - colossalai - :class:`~pytorch_lightning.strategies.ColossalAIStrategy` - Colossal-AI provides a collection of parallel components for you. It aims to support you to write your distributed deep learning models just like how you write your model on your laptop. `Learn more. `__ diff --git a/docs/source-pytorch/index.rst b/docs/source-pytorch/index.rst index b1a490ae40..3f743b448b 100644 --- a/docs/source-pytorch/index.rst +++ b/docs/source-pytorch/index.rst @@ -200,7 +200,6 @@ Current Lightning Users clouds/cluster Save and load model progress Save memory with half-precision - Training over the internet advanced/model_parallel clouds/cloud_training Train on single or multiple GPUs @@ -246,7 +245,6 @@ Current Lightning Users Metrics Model Model Parallel - Collaborative Training Plugins Progress bar Production diff --git a/docs/source-pytorch/strategies/hivemind.rst b/docs/source-pytorch/strategies/hivemind.rst deleted file mode 100644 index 5695f5695f..0000000000 --- a/docs/source-pytorch/strategies/hivemind.rst +++ /dev/null @@ -1,44 +0,0 @@ -.. _hivemind: - -##################################################### -Training on unreliable mixed GPUs across the internet -##################################################### -**Audience:** Users who do not have access to top-tier multi-gpu/multi-node servers and want to scale training across different GPU types, or across the internet. - ----- - -.. raw:: html - -
-
- -.. Add callout items below this line -.. displayitem:: - :header: 1: Training across multiple machines over the internet - :description: Quick setup to start training on multiple machines. - :col_css: col-md-4 - :button_link: hivemind_basic.html - :height: 200 - :tag: basic - -.. displayitem:: - :header: 2: Speed up training by enabling under-the-hood optimizations - :description: Learn which flags to use with the HivemindStrategy to speed up training. - :col_css: col-md-4 - :button_link: hivemind_intermediate.html - :height: 200 - :tag: intermediate - -.. displayitem:: - :header: 3: Optimize Memory and Communication using compression hooks - :description: Enable gradient buffer optimizations and communication improvements to reduce bottlenecks in communication. - :col_css: col-md-4 - :button_link: hivemind_expert.html - :height: 200 - :tag: expert - - -.. raw:: html - -
-
diff --git a/docs/source-pytorch/strategies/hivemind_basic.rst b/docs/source-pytorch/strategies/hivemind_basic.rst deleted file mode 100644 index 98e90cbfe9..0000000000 --- a/docs/source-pytorch/strategies/hivemind_basic.rst +++ /dev/null @@ -1,43 +0,0 @@ -:orphan: - -.. _hivemind_basic: - -Training on unreliable mixed GPUs across the internet (Basic) -============================================================= - -Collaborative Training tries to solve the need for top-tier multi-GPU servers by allowing you to train across unreliable machines, -such as local machines or even preemptible cloud compute across the internet. - -Under the hood, we use `Hivemind `_ which provides de-centralized training across the internet. - -To use Collaborative Training, you need to first install Hivemind. - -.. code-block:: bash - - pip install hivemind - -The ``HivemindStrategy`` accumulates gradients from all processes that are collaborating until they reach a ``target_batch_size``. By default, we use the batch size -of the first batch to determine what each local machine batch contributes towards the ``target_batch_size``. Once the ``target_batch_size`` is reached, an optimizer step -is made on all processes. - -.. warning:: - - When using ``HivemindStrategy`` note that you cannot use gradient accumulation (``accumulate_grad_batches``). This is because Hivemind manages accumulation internally. - -.. code-block:: python - - import pytorch_lightning as pl - from pytorch_lightning.strategies import HivemindStrategy - - trainer = pl.Trainer(strategy=HivemindStrategy(target_batch_size=8192), accelerator="gpu", devices=1) - -.. code-block:: bash - - python train.py - # Other machines can connect running the same command: - # INITIAL_PEERS=... python train.py - # or passing the peers to the strategy:" - # HivemindStrategy(initial_peers=...)" - - -A helper message is printed once your training begins, which shows you how to start training on other machines using the same code. diff --git a/docs/source-pytorch/strategies/hivemind_expert.rst b/docs/source-pytorch/strategies/hivemind_expert.rst deleted file mode 100644 index 3fa55afb13..0000000000 --- a/docs/source-pytorch/strategies/hivemind_expert.rst +++ /dev/null @@ -1,87 +0,0 @@ -:orphan: - -.. _hivemind_expert: - -Training on unreliable mixed GPUs across the internet (Expert) -============================================================== - -Using Compression to Optimize Communications -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Below are some ways to reduce communication when training collaboratively. As the size of your model increase, bottlenecks in communication become more apparent. - -Compress Gradients & State -"""""""""""""""""""""""""" - -Hivemind allows you to compress gradients and states before sending them to other machines. This helps reduce the communication overhead substantially when training across the internet. - -Below, we enable Float16 compression, which compresses gradients and states to Float16 before sending it to other machines. - -.. note:: - Compressing gradients can affect convergence if you're lowering the precision (i.e training in Float32, but compressing gradients to FP16). - -.. code-block:: python - - from hivemind import Float16Compression - import pytorch_lightning as pl - from pytorch_lightning.strategies import HivemindStrategy - - trainer = pl.Trainer( - strategy=HivemindStrategy( - target_batch_size=target_batch_size, - grad_compression=Float16Compression(), - state_averaging_compression=Float16Compression(), - ), - accelerator="gpu", - devices=1, - ) - -A slightly more advanced scheme is dynamic compression based on value size. Below, we enable 8-bit quantization for large numbers, and Float16 compression for small values, reducing communication bottlenecks even further. - -Size Adaptive Compression has been used in a variety of Hivemind applications and has shown success, but does quantize gradients further, meaning we lose precision when compressing. - -.. code-block:: python - - from hivemind import Float16Compression, Uniform8BitQuantization - import pytorch_lightning as pl - from pytorch_lightning.strategies import HivemindStrategy - - # compresses values above threshold with 8bit Quantization, lower with Float16 - compression = SizeAdaptiveCompression( - threshold=2 ** 16 + 1, less=Float16Compression(), greater_equal=Uniform8BitQuantization() - ) - trainer = pl.Trainer( - strategy=HivemindStrategy( - target_batch_size=target_batch_size, - grad_compression=compression, - state_averaging_compression=compression, - ), - accelerator="gpu", - devices=1, - ) - - -PowerSGD -"""""""" - -`PowerSGD `_ is a technique to reduce distributed communication of gradients across processes. -In short, PowerSGD uses a low-rank approximation to compress gradients before running an `all-reduce` step to sync gradients across all processes. - -.. note:: - Though PowerSGD can impact convergence, it can also substantially reduce communication between processes. - -.. code-block:: python - - import pytorch_lightning as pl - from pytorch_lightning.strategies import HivemindStrategy - from functools import partial - from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager - - trainer = pl.Trainer( - strategy=HivemindStrategy( - target_batch_size=8192, - grad_averager_factory=partial(PowerSGDGradientAverager, averager_rank=32, min_compression_ratio=0.5), - ), - accelerator="gpu", - devices=1, - ) diff --git a/docs/source-pytorch/strategies/hivemind_intermediate.rst b/docs/source-pytorch/strategies/hivemind_intermediate.rst deleted file mode 100644 index cec004219f..0000000000 --- a/docs/source-pytorch/strategies/hivemind_intermediate.rst +++ /dev/null @@ -1,99 +0,0 @@ -:orphan: - -.. _hivemind_intermediate: - -Training on unreliable mixed GPUs across the internet (Intermediate) -==================================================================== - -Reducing Communication By Overlapping Communication -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -We can reduce the impact of communication across all machines by overlapping communication with our training iterations. In short, we enable communication to happen -in the background of training. - -Overlap Gradient and State Averaging -"""""""""""""""""""""""""""""""""""" - -When the target batch size is reached, all processes that are included in the step send gradients and model states to each other. By enabling some flags through -the strategy, communication can happen in the background. This allows training to continue (with slightly outdated weights) but provides us the means -to overlap communication with computation. - -.. warning:: - Enabling overlapping communication means convergence will slightly be affected. - -.. note:: - Enabling these flags means that you must pass in a ``scheduler_fn`` to the ``HivemindStrategy`` instead of relying on a scheduler from ``configure_optimizers``. - The optimizer is re-created by Hivemind, and as a result, the scheduler has to be re-created. - -.. code-block:: python - - import torch - from functools import partial - import pytorch_lightning as pl - from pytorch_lightning.strategies import HivemindStrategy - - trainer = pl.Trainer( - strategy=HivemindStrategy( - target_batch_size=8192, - delay_state_averaging=True, - delay_grad_averaging=True, - delay_optimizer_step=True, - offload_optimizer=True, # required to delay averaging - scheduler_fn=partial(torch.optim.lr_scheduler.ExponentialLR, gamma=...), - ), - accelerator="gpu", - devices=1, - ) - - -Reducing GPU Memory requirements by re-using buffers & CPU offloading -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -We can also offload the optimizer state to the CPU whilst re-using gradient buffers to reduce the memory requirement for machines. - -Offloading Optimizer State to the CPU -""""""""""""""""""""""""""""""""""""" - -Offloading the Optimizer state to the CPU works the same as :ref:`deepspeed-zero-stage-2-offload`, where we save GPU memory by keeping all optimizer states on the CPU. - -.. note:: - Enabling these flags means that you must pass in a ``scheduler_fn`` to the ``HivemindStrategy`` instead of relying on a scheduler from ``configure_optimizers``. - The optimizer is re-created by Hivemind, and as a result, the scheduler has to be re-created. - - We suggest enabling offloading and overlapping communication to hide the additional overhead from having to communicate with the CPU. - -.. code-block:: python - - import torch - from functools import partial - import pytorch_lightning as pl - from pytorch_lightning.strategies import HivemindStrategy - - trainer = pl.Trainer( - strategy=HivemindStrategy( - target_batch_size=8192, - offload_optimizer=True, - scheduler_fn=partial(torch.optim.lr_scheduler.ExponentialLR, gamma=...), - ), - accelerator="gpu", - devices=1, - ) - - -Re-using Gradient Buffers -""""""""""""""""""""""""" - -By default, Hivemind accumulates gradients in a separate buffer. This means additional GPU memory is required to store gradients. You can enable re-using the model parameter gradient buffers by passing ``reuse_grad_buffers=True`` to the ``HivemindStrategy``. - -.. warning:: - The ``HivemindStrategy`` will override ``zero_grad`` in your ``LightningModule`` to have no effect. This is because gradients are accumulated in the model - and Hivemind manages when they need to be cleared. - -.. code-block:: python - - import pytorch_lightning as pl - from pytorch_lightning.strategies import HivemindStrategy - - trainer = pl.Trainer( - strategy=HivemindStrategy(target_batch_size=8192, reuse_grad_buffers=True), accelerator="gpu", devices=1 - ) diff --git a/requirements/pytorch/strategies.txt b/requirements/pytorch/strategies.txt index d43e6d725b..4de4dc15f5 100644 --- a/requirements/pytorch/strategies.txt +++ b/requirements/pytorch/strategies.txt @@ -4,4 +4,3 @@ # colossalai>=0.1.10 # TODO: uncomment when there's a stable version released fairscale>=0.4.5, <0.4.13 deepspeed>=0.6.0, <=0.7.0 -hivemind==1.1.5; sys_platform == 'linux' diff --git a/src/pytorch_lightning/strategies/__init__.py b/src/pytorch_lightning/strategies/__init__.py index 5dc757fd65..dcfb11eecb 100644 --- a/src/pytorch_lightning/strategies/__init__.py +++ b/src/pytorch_lightning/strategies/__init__.py @@ -20,7 +20,6 @@ from pytorch_lightning.strategies.deepspeed import DeepSpeedStrategy # noqa: F4 from pytorch_lightning.strategies.dp import DataParallelStrategy # noqa: F401 from pytorch_lightning.strategies.fully_sharded import DDPFullyShardedStrategy # noqa: F401 from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy # noqa: F401 -from pytorch_lightning.strategies.hivemind import HivemindStrategy # noqa: F401 from pytorch_lightning.strategies.hpu_parallel import HPUParallelStrategy # noqa: F401 from pytorch_lightning.strategies.ipu import IPUStrategy # noqa: F401 from pytorch_lightning.strategies.parallel import ParallelStrategy # noqa: F401 diff --git a/src/pytorch_lightning/strategies/hivemind.py b/src/pytorch_lightning/strategies/hivemind.py deleted file mode 100644 index e3c9e9f649..0000000000 --- a/src/pytorch_lightning/strategies/hivemind.py +++ /dev/null @@ -1,332 +0,0 @@ -import ipaddress -import logging -import os -import platform -from typing import Any, Callable, Dict, List, Optional, Union - -import torch -from torch import Tensor - -import pytorch_lightning as pl -from lightning_fabric.utilities.types import LRScheduler, ReduceLROnPlateau -from pytorch_lightning.strategies.strategy import Strategy, TBroadcast -from pytorch_lightning.utilities.data import extract_batch_size -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _HIVEMIND_AVAILABLE -from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.rank_zero import rank_zero_warn - -if _HIVEMIND_AVAILABLE: - import hivemind -else: - hivemind = None - -log = logging.getLogger(__name__) - - -class HivemindStrategy(Strategy): - INITIAL_PEERS_ENV: str = "PL_INITIAL_PEERS" - - def __init__( - self, - target_batch_size: int, - run_id: str = "lightning_run", - batch_size: Optional[int] = None, - delay_state_averaging: bool = False, - delay_optimizer_step: Optional[bool] = None, - delay_grad_averaging: bool = False, - offload_optimizer: Optional[bool] = None, - reuse_grad_buffers: bool = False, - scheduler_fn: Optional[Callable] = None, - matchmaking_time: float = 5.0, - averaging_timeout: float = 30.0, - verbose: bool = False, - averager_opts: Optional[Dict] = None, - host_maddrs: Optional[List] = None, - initial_peers: Optional[Union[str, List]] = None, - **optimizer_kwargs: Any, - ): - """Provides capabilities to train using the Hivemind Library, training collaboratively across the internet - with unreliable machines. For more information, `refer to the docs `__. - - .. warning:: ``HivemindStrategy`` is experimental and subject to change. - - Arguments: - - target_batch_size: When training, the batch size to accumulate to before running a step. The larger this - batch size, the more work can be done asynchronously without communication. - - run_id: A unique identifier of this training run, used as a common prefix for all DHT keys. - See ``https://learning-at-home.readthedocs.io/en/latest/user/dht.html``. - - batch_size: The local batch size per process. If not provided, we infer this from the first batch of data - passed in at training (lazy). Note that this should not change throughout training. - - delay_state_averaging: If enabled (default), average parameters and extra tensors in a background thread; - if set to False, average parameters synchronously within the - corresponding :meth:`hivemind.Optimizer.step` call. - - delay_optimizer_step: Run optimizer in background, apply results in future .step. requires - :paramref:`~pytorch_lightning.strategies.hivemind.HivemindStrategy.offload_optimizer`. - - delay_grad_averaging: Average gradients in background; requires - :paramref:`~pytorch_lightning.strategies.hivemind.HivemindStrategy.offload_optimizer` and - :paramref:`~pytorch_lightning.strategies.hivemind.HivemindStrategy.delay_optimizer_step`. - - offload_optimizer: Offload the optimizer to host memory, saving GPU memory for parameters and gradients. - - reuse_grad_buffers: Use the model's gradient buffers (params.grad) for gradient accumulation - which is more memory efficient. Lightning will automatically disable ``zero_grad`` - in the ``LightningModule``. - - scheduler_fn: callable(optimizer) -> PyTorch LRScheduler or a pre-initialized PyTorch scheduler. - When using `offload_optimizer`/`delay_optimizer_step`/`delay_state_averaging` ``scheduler_fn`` - is required to be passed to the ``HivemindStrategy``. This is because the optimizer - is re-created and the scheduler needs to be re-created as well. - - matchmaking_time: When looking for group, wait for peers to join for up to this many seconds. - Increase if you see "averaged gradients with N peers" where N is below 0.9x on >=25% of epochs. - Training with low-latency network, decreasing matchmaking_time allows training with smaller batch sizes. - - averaging_timeout: If an averaging step hangs for this long, it will be cancelled automatically. - Increase averaging_timeout if you see "Proceeding with local gradients" at least 25% of the time. - Do not set this timeout too high, as it may cause your optimizer to hang - after some types of network errors. - - verbose: Report internal Hivemind events such as accumulating gradients and running background tasks. - - averager_opts: Additional keyword arguments forwarded to both - ``GradientAverager`` and ``TrainingStateAverager``. - - host_maddrs: List of multi-addrs to create visible peers for other processes. - `https://learning-at-home.readthedocs.io/en/latest/user/dht.html#running-across-the-internet` - - initial_peers: If connecting to a running process, a list of initial peers needs to be passed in. - This can also be set via the env variable ``INITIAL_PEERS``. - - **optimizer_kwargs: kwargs are passed to the :class:`hivemind.Optimizer` class. - """ - if not _HIVEMIND_AVAILABLE or platform.system() != "Linux": - raise MisconfigurationException( - "To use the `HivemindStrategy`, you must have Hivemind installed and be running on Linux." - " Install it by running `pip install -U hivemind`." - ) - - super().__init__() - self._initial_peers = initial_peers - self._target_batch_size = target_batch_size - self._batch_size = batch_size - self._scheduler_fn = scheduler_fn - self._require_scheduler_fn = delay_optimizer_step or delay_state_averaging or offload_optimizer - self._opt = None - self._optimizer_zero_grad_original: Optional[Callable] = None - self._run_id = run_id - self._reuse_grad_buffers = reuse_grad_buffers - self._optimizer_kwargs = dict( - matchmaking_time=matchmaking_time, - averaging_timeout=averaging_timeout, - delay_optimizer_step=delay_optimizer_step, - delay_state_averaging=delay_state_averaging, - delay_grad_averaging=delay_grad_averaging, - offload_optimizer=offload_optimizer, - averager_opts=averager_opts if averaging_timeout is not None else dict(request_timeout=1.0), - verbose=verbose, - reuse_grad_buffers=reuse_grad_buffers, - **optimizer_kwargs, - ) - - self._parse_env_initial_peers() - - self.dht = hivemind.DHT( - start=True, - initial_peers=initial_peers, - host_maddrs=host_maddrs if host_maddrs is not None else ["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"], - ) - - visible_addresses = [ - str(a) for a in self.dht.get_visible_maddrs() if not ipaddress.ip_address(a.values()[0]).is_loopback - ] - - if initial_peers is None: - log.info( - "\nOther machines can connect running the same command:\n" - f"INITIAL_PEERS={','.join(visible_addresses)} python ...\n" - "or passing the peers to the strategy:\n" - f"HivemindStrategy(initial_peers='{','.join(visible_addresses)}')" - ) - - self._hivemind_initialized = False - - def _parse_env_initial_peers(self) -> None: - initial_peers = os.environ.get(self.INITIAL_PEERS_ENV, self._initial_peers) - self._initial_peers = initial_peers.split(",") if isinstance(initial_peers, str) else self._initial_peers - - @property - def num_peers(self) -> int: - if self._opt: - return self._opt.tracker.global_progress.num_peers - return 1 - - @property - def root_device(self) -> torch.device: - from pytorch_lightning.accelerators.cpu import CPUAccelerator - from pytorch_lightning.accelerators.cuda import CUDAAccelerator - - if isinstance(self.accelerator, CUDAAccelerator): - return torch.device(f"cuda:{torch.cuda.current_device()}") - elif isinstance(self.accelerator, CPUAccelerator): - return torch.device("cpu") - raise MisconfigurationException( - f"Was unable to infer device type from the accelerator: {self.accelerator.__class__.__name__}." - ) - - @property - def global_rank(self) -> int: - return 0 - - @property - def is_global_zero(self) -> bool: - return True - - def setup(self, trainer: "pl.Trainer") -> None: - self.model_to_device() - super().setup(trainer) - if self.precision_plugin.precision == "16": - self.precision_plugin.scaler = hivemind.GradScaler() - - def _initialize_hivemind(self) -> None: - if len(self.optimizers) > 1: - raise MisconfigurationException("Hivemind only supports training with one optimizer.") - optimizer = self.optimizers[0] - - if self._require_scheduler_fn and self._scheduler_fn is None: - rank_zero_warn( - "Enabling `delay_optimizer_step`, `delay_state_averaging` or `offload_optimizer` " - "requires a `scheduler_fn` to be passed to the strategy if a scheduler is being used " - "(this is because the optimizer is re-created within Hivemind)." - ) - - scheduler = self._scheduler_fn if self._require_scheduler_fn else None - params = optimizer.param_groups if self._require_scheduler_fn else None - optimizer = type(optimizer) if self._require_scheduler_fn else optimizer - opt = hivemind.Optimizer( - dht=self.dht, - run_id=self._run_id, - params=params, - optimizer=optimizer, - scheduler=scheduler, - target_batch_size=self._target_batch_size, - batch_size_per_step=self._batch_size, - **self._optimizer_kwargs, - ) - - if not self._scheduler_fn: - self._wrap_schedulers(opt) - opt.load_state_from_peers() - self.optimizers = [opt] - self._opt = opt - - if self._reuse_grad_buffers: - assert self.lightning_module is not None - self._optimizer_zero_grad_original = self.lightning_module.optimizer_zero_grad - self._disable_zero_grad() - - def _disable_zero_grad(self) -> None: - lightning_module = self.lightning_module - if is_overridden("optimizer_zero_grad", lightning_module): - assert lightning_module is not None # `is_overridden` returns False otherwise - rank_zero_warn( - "You have overridden `optimizer_zero_grad` which will be disabled." - " When `HivemindStrategy(reuse_grad_buffers=True)`, the optimizer cannot call zero grad," - " as this would delete the gradients before they are averaged." - ) - assert lightning_module is not None - lightning_module.optimizer_zero_grad = None # type: ignore[assignment] - - def _wrap_schedulers(self, opt: "hivemind.Optimizer") -> None: - # wrap schedulers so that they only update when the hivemind optimizer updates - for scheduler_config in self.lr_scheduler_configs: - scheduler = scheduler_config.scheduler - if isinstance(scheduler, ReduceLROnPlateau): - raise ValueError( - f"The `ReduceLROnPlateau` scheduler is not currently supported with `{self.__class__.__name__}`." - ) - scheduler_config.scheduler = HiveMindScheduler( - optimizer=opt, - scheduler=scheduler, - ) - - def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: - if not self._hivemind_initialized: - self._hivemind_initialized = True - # todo (sean): we could technically support a dynamic batch size by inferring each step - # and passing it to the ``hivemind.Optimizer``. - if self._batch_size is None: - try: - self._batch_size = extract_batch_size(batch) - log.info(f"Found per machine batch size automatically from the batch: {self._batch_size}") - except (MisconfigurationException, RecursionError) as e: - raise MisconfigurationException( - "We tried to infer the batch size from the first batch of data. " - "Please provide the batch size to the Strategy by " - "``Trainer(strategy=HivemindStrategy(batch_size=x))``. " - ) from e - self._initialize_hivemind() - - def reduce(self, tensor: Union[Any, Tensor], *args: Any, **kwargs: Any) -> Union[Any, Tensor]: - return tensor - - def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: - return tensor - - def model_to_device(self) -> None: - assert self.model is not None - self.model.to(self.root_device) - - def barrier(self, *args: Any, **kwargs: Any) -> None: - pass - - def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: - return obj - - def teardown(self) -> None: - - if self._optimizer_zero_grad_original is not None and self.lightning_module is not None: - # re-enable `optimizer_zero_grad` - self.lightning_module.optimizer_zero_grad = self._optimizer_zero_grad_original # type: ignore[assignment] - - if self._opt: - self._opt.shutdown() - log.info("Shutting down hivemind DHT.") - self.dht.shutdown() - super().teardown() - - -class HiveMindScheduler: - """Wrapper for schedulers to prevent Lightning from stepping the scheduler too soon. - - This code ensures that we only step when the HiveMind optimizer reaches the global step. - """ - - base_lrs: List[float] - - def __init__(self, optimizer: "hivemind.Optimizer", scheduler: LRScheduler) -> None: - # copy most of the `Scheduler` methods into this instance. `__del__` is skipped in case the scheduler has - # implemented custom logic which we would not want to call on destruction of the `HiveMindScheduler` - self.__dict__ = {k: v for k, v in scheduler.__dict__.items() if k not in ("step", "__del__")} - - self.optimizer = optimizer - self.scheduler = scheduler - self.current_step = -1 - - def step(self, epoch: Optional[int] = None) -> None: - while self.current_step < self.optimizer.local_epoch: - self.scheduler.step(epoch=epoch) - self.current_step += 1 - - def load_state_dict(self, state_dict: Dict) -> None: - self.scheduler.load_state_dict(state_dict) - - def state_dict(self) -> Dict: - return self.scheduler.state_dict() diff --git a/src/pytorch_lightning/utilities/__init__.py b/src/pytorch_lightning/utilities/__init__.py index e6737542b4..332d4db34b 100644 --- a/src/pytorch_lightning/utilities/__init__.py +++ b/src/pytorch_lightning/utilities/__init__.py @@ -20,7 +20,6 @@ from lightning_fabric.utilities import move_data_to_device # noqa: F401 from pytorch_lightning.utilities.enums import GradClipAlgorithmType # noqa: F401 from pytorch_lightning.utilities.grads import grad_norm # noqa: F401 from pytorch_lightning.utilities.imports import ( # noqa: F401 - _HIVEMIND_AVAILABLE, _HPU_AVAILABLE, _IPU_AVAILABLE, _OMEGACONF_AVAILABLE, diff --git a/src/pytorch_lightning/utilities/imports.py b/src/pytorch_lightning/utilities/imports.py index 95def583f7..2e841e6d45 100644 --- a/src/pytorch_lightning/utilities/imports.py +++ b/src/pytorch_lightning/utilities/imports.py @@ -26,7 +26,6 @@ _TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0") _TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1") _HABANA_FRAMEWORK_AVAILABLE = package_available("habana_frameworks") -_HIVEMIND_AVAILABLE = package_available("hivemind") _KINETO_AVAILABLE = torch.profiler.kineto_available() _OMEGACONF_AVAILABLE = package_available("omegaconf") _POPTORCH_AVAILABLE = package_available("poptorch") diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py index d4d855389f..87f5b91f48 100644 --- a/tests/tests_pytorch/helpers/runif.py +++ b/tests/tests_pytorch/helpers/runif.py @@ -29,7 +29,6 @@ from pytorch_lightning.strategies.bagua import _BAGUA_AVAILABLE from pytorch_lightning.strategies.colossalai import _COLOSSALAI_AVAILABLE from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities.imports import ( - _HIVEMIND_AVAILABLE, _HPU_AVAILABLE, _IPU_AVAILABLE, _OMEGACONF_AVAILABLE, @@ -71,7 +70,6 @@ class RunIf: bagua: bool = False, colossalai: bool = False, psutil: bool = False, - hivemind: bool = False, sklearn: bool = False, **kwargs, ): @@ -100,7 +98,6 @@ class RunIf: This requires that the ``PL_RUN_SLOW_TESTS=1`` environment variable is set. bagua: Require that BaguaSys/bagua is installed. psutil: Require that psutil is installed. - hivemind: Require that Hivemind is installed. sklearn: Require that scikit-learn is installed. **kwargs: Any :class:`pytest.mark.skipif` keyword arguments. """ @@ -221,10 +218,6 @@ class RunIf: conditions.append(not _PSUTIL_AVAILABLE) reasons.append("psutil") - if hivemind: - conditions.append(not _HIVEMIND_AVAILABLE or sys.platform in ("win32", "darwin")) - reasons.append("Hivemind") - if sklearn: conditions.append(not _SKLEARN_AVAILABLE) reasons.append("scikit-learn") diff --git a/tests/tests_pytorch/strategies/test_hivemind.py b/tests/tests_pytorch/strategies/test_hivemind.py deleted file mode 100644 index a75d13676f..0000000000 --- a/tests/tests_pytorch/strategies/test_hivemind.py +++ /dev/null @@ -1,313 +0,0 @@ -import multiprocessing as mp -import os -import time -from typing import Any -from unittest import mock -from unittest.mock import PropertyMock - -import pytest -import torch -from torch import Tensor -from torch.optim import Optimizer - -import pytorch_lightning as pl -from pytorch_lightning.demos.boring_classes import BoringModel -from pytorch_lightning.strategies import HivemindStrategy -from pytorch_lightning.strategies.hivemind import HiveMindScheduler -from pytorch_lightning.utilities import _HIVEMIND_AVAILABLE -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import STEP_OUTPUT -from tests_pytorch.helpers.runif import RunIf - -if _HIVEMIND_AVAILABLE: - import hivemind - - -@mock.patch("pytorch_lightning.strategies.hivemind._HIVEMIND_AVAILABLE", False) -def test_raise_exception_if_hivemind_unavailable(): - """Test that we raise an exception when Hivemind is not available.""" - with pytest.raises(MisconfigurationException, match="you must have Hivemind installed"): - HivemindStrategy(target_batch_size=1) - - -@RunIf(hivemind=True) -@mock.patch("hivemind.DHT", autospec=True) -def test_strategy(mock_dht): - strategy = HivemindStrategy(target_batch_size=1) - trainer = pl.Trainer(strategy=strategy) - assert trainer.strategy == strategy - - -@RunIf(hivemind=True) -@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True) -def test_optimizer_wrapped(): - class TestModel(BoringModel): - def on_before_backward(self, loss: Tensor) -> None: - optimizer = self.trainer.optimizers[0] - assert isinstance(optimizer, hivemind.Optimizer) - - model = TestModel() - trainer = pl.Trainer(strategy=HivemindStrategy(target_batch_size=1), fast_dev_run=True) - trainer.fit(model) - - -@RunIf(hivemind=True) -@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True) -def test_scheduler_wrapped(): - class TestModel(BoringModel): - def on_before_backward(self, loss: Tensor) -> None: - scheduler = self.trainer.lr_scheduler_configs[0].scheduler - assert isinstance(scheduler, HiveMindScheduler) - - def configure_optimizers(self): - optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) - return [optimizer], [torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)] - - model = TestModel() - trainer = pl.Trainer( - strategy=HivemindStrategy(target_batch_size=1), - fast_dev_run=True, - ) - trainer.fit(model) - - -@RunIf(hivemind=True) -@mock.patch.dict( - os.environ, - { - "HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor", - "PL_INITIAL_PEERS": "TEST_PEERS", - }, - clear=True, -) -@mock.patch("hivemind.DHT", autospec=True) -def test_env_variables_parsed(mock_dht): - """Test that env variables are parsed correctly.""" - strategy = HivemindStrategy(target_batch_size=1) - assert strategy._initial_peers == ["TEST_PEERS"] - - -@RunIf(hivemind=True) -@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True) -def test_reuse_grad_buffers_warning(): - """Test to ensure we warn when a user overrides `optimizer_zero_grad` and `reuse_grad_buffers` is True.""" - - class TestModel(BoringModel): - def on_before_backward(self, loss: Tensor) -> None: - optimizer = self.trainer.optimizers[0] - assert isinstance(optimizer, hivemind.Optimizer) - - def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int): - pass - - model = TestModel() - trainer = pl.Trainer(strategy=HivemindStrategy(target_batch_size=1, reuse_grad_buffers=True), fast_dev_run=True) - - with pytest.warns(UserWarning, match="You have overridden `optimizer_zero_grad` which will be disabled."): - trainer.fit(model) - - -@RunIf(hivemind=True) -def test_raise_exception_multiple_optimizers(): - """Test that we raise an exception when multiple optimizers are provided.""" - - class TestModel(BoringModel): - def configure_optimizers(self): - optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) - return [optimizer, optimizer], [lr_scheduler] - - model = TestModel() - trainer = pl.Trainer(strategy=HivemindStrategy(target_batch_size=1), fast_dev_run=True) - - with pytest.raises(MisconfigurationException, match="Hivemind only supports training with one optimizer."): - trainer.fit(model) - - -@RunIf(hivemind=True) -@mock.patch("pytorch_lightning.utilities.data._extract_batch_size", autospec=True, return_value=[None]) -def test_raise_exception_no_batch_size(mock_extract_batch_size): - """Test that we raise an exception when no batch size is automatically found.""" - - model = BoringModel() - trainer = pl.Trainer(strategy=HivemindStrategy(target_batch_size=1), fast_dev_run=True) - - with pytest.raises(MisconfigurationException, match="Please provide the batch size to the Strategy."): - trainer.fit(model) - - -@RunIf(hivemind=True) -@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True) -@pytest.mark.parametrize( - "delay_grad_averaging, delay_state_averaging, delay_optimizer_step", - [(True, True, True), (False, True, False)], -) -def test_warn_if_argument_passed(delay_grad_averaging, delay_state_averaging, delay_optimizer_step): - """Test ensures that valid combination of HiveMind delay arguments warn if scheduler isn't passed in as a - function.""" - model = BoringModel() - trainer = pl.Trainer( - strategy=HivemindStrategy( - target_batch_size=1, - delay_grad_averaging=delay_grad_averaging, - delay_state_averaging=delay_state_averaging, - delay_optimizer_step=delay_optimizer_step, - ), - fast_dev_run=True, - ) - - with pytest.warns(UserWarning, match="requires a `scheduler_fn` to be passed to the strategy"): - trainer.fit(model) - - -@RunIf(hivemind=True) -@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True) -@mock.patch("pytorch_lightning.strategies.hivemind.HivemindStrategy.num_peers", new_callable=PropertyMock) -def test_args_passed_to_optimizer(mock_peers): - """Test to ensure arguments are correctly passed to the hivemind optimizer wrapper.""" - mock_peers.return_value = 1 - compression = hivemind.ScaledFloat16Compression() - with mock.patch("hivemind.Optimizer", wraps=hivemind.Optimizer) as mock_optimizer: - - class TestModel(BoringModel): - def on_before_backward(self, loss: Tensor) -> None: - args, kwargs = mock_optimizer.call_args - mock_optimizer.assert_called() - arguments = dict( - delay_optimizer_step=True, - delay_state_averaging=True, - state_averaging_compression=compression, - grad_compression=compression, - offload_optimizer=True, - reuse_grad_buffers=True, - target_batch_size=1, - ) - - for key, value in arguments.items(): - assert key in kwargs - assert value == kwargs[key] - - model = TestModel() - trainer = pl.Trainer( - strategy=HivemindStrategy( - target_batch_size=1, - reuse_grad_buffers=True, - delay_state_averaging=True, - delay_optimizer_step=True, - offload_optimizer=True, - grad_compression=compression, - state_averaging_compression=compression, - ), - fast_dev_run=True, - ) - trainer.fit(model) - # ensures that after training with `reuse_grad_buffers` we restore the hook - assert model.optimizer_zero_grad is not None - - -@RunIf(hivemind=True) -@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True) -@pytest.mark.parametrize( - "host_maddrs,expected_maddrs", - [(None, ["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"]), (["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"])], -) -def test_maddrs(host_maddrs, expected_maddrs): - """Test that the multiple addresses are correctly assigned.""" - strategy = HivemindStrategy(target_batch_size=1, host_maddrs=host_maddrs) - assert strategy.dht.kwargs["host_maddrs"] == expected_maddrs - - -def _run_collab_training_fn(initial_peers, wait_seconds, barrier, recorded_process_peers, recorded_process_steps): - recorded_peers = [] - recorded_global_steps = [] - - class TestModel(BoringModel): - def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, unused: int = 0) -> None: - time.sleep(wait_seconds) # add an additional delay to give processes time to sync - recorded_peers.append(self.trainer.strategy.num_peers) - recorded_global_steps.append(self.trainer.optimizers[0].local_epoch) - - def on_train_end(self) -> None: - # wait for all processes to get to the end of training before teardown - barrier.wait() - - model = TestModel() - trainer = pl.Trainer( - max_epochs=1, - limit_train_batches=16, - limit_val_batches=0, - strategy=HivemindStrategy( - delay_state_averaging=True, - offload_optimizer=True, - delay_optimizer_step=True, - delay_grad_averaging=True, - target_batch_size=8, - initial_peers=initial_peers, - verbose=False, - ), - ) - trainer.fit(model) - - recorded_process_peers.append(recorded_peers) - recorded_process_steps.append(recorded_global_steps) - - -# TODO: check why it fails with PT 1.12 -@RunIf(hivemind=True, max_torch="1.12") -@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True) -@pytest.mark.parametrize( - "num_processes, wait_seconds", - [(2, 0.25)], -) -def test_multiple_peers(num_processes, wait_seconds): - """Test to ensure that if we have two running processes with the same peers, they connect and train - successfully.""" - dht_root = hivemind.DHT(start=True) - barrier = mp.Barrier(num_processes) - initial_peers = dht_root.get_visible_maddrs() - - with mp.Manager() as manager: - # allows processes to return their recorded logged peers/steps - recorded_process_peers = manager.list() - recorded_process_steps = manager.list() - processes = [ - mp.Process( - target=_run_collab_training_fn, - kwargs=dict( - initial_peers=initial_peers, - wait_seconds=wait_seconds, - barrier=barrier, - recorded_process_peers=recorded_process_peers, - recorded_process_steps=recorded_process_steps, - ), - ) - for x in range(num_processes) - ] - for process in processes: - process.start() - for process in processes: - process.join() - # assert that peers increase as expected and we run at-least 1 global step. - for process_peers, process_steps in zip(recorded_process_peers, recorded_process_steps): - assert any(num_peer == num_processes for num_peer in process_peers) - assert any(global_step > 0 for global_step in process_steps) - - -@RunIf(hivemind=True, min_cuda_gpus=1) -@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True) -def test_scaler_updated_precision_16(): - class TestModel(BoringModel): - def on_fit_start(self) -> None: - assert isinstance(self.trainer.precision_plugin.scaler, hivemind.GradScaler) - raise SystemExit - - model = TestModel() - trainer = pl.Trainer( - strategy=HivemindStrategy(target_batch_size=1), - fast_dev_run=True, - precision=16, - accelerator="gpu", - devices=1, - ) - with pytest.raises(SystemExit): - trainer.fit(model)