From 1a502c061cf43f60cbf0f2089981ef6b8a57419c Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Thu, 5 May 2022 17:06:26 +0100 Subject: [PATCH] [1/2] Collaborative Strategy (#12842) --- CHANGELOG.md | 3 + pytorch_lightning/strategies/__init__.py | 1 + pytorch_lightning/strategies/collaborative.py | 529 ++++++++++++++++++ pytorch_lightning/utilities/__init__.py | 1 + pytorch_lightning/utilities/imports.py | 1 + pytorch_lightning/utilities/types.py | 6 + requirements/strategies.txt | 1 + tests/helpers/runif.py | 7 + tests/strategies/test_collaborative.py | 372 ++++++++++++ 9 files changed, 921 insertions(+) create mode 100644 pytorch_lightning/strategies/collaborative.py create mode 100644 tests/strategies/test_collaborative.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 5bccace101..0a39151dee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for `Trainer(deterministic="warn")` to warn instead of fail when a non-deterministic operation is encountered ([#12588](https://github.com/PyTorchLightning/pytorch-lightning/pull/12588)) +- Added `CollaborativeStrategy` ([#12842](https://github.com/PyTorchLightning/pytorch-lightning/pull/12842)) + + - Include a version suffix for new "last" checkpoints of later runs in the same directory ([#12902](https://github.com/PyTorchLightning/pytorch-lightning/pull/12902)) diff --git a/pytorch_lightning/strategies/__init__.py b/pytorch_lightning/strategies/__init__.py index ca6f28338e..0de3e51a0f 100644 --- a/pytorch_lightning/strategies/__init__.py +++ b/pytorch_lightning/strategies/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from pytorch_lightning.strategies.bagua import BaguaStrategy # noqa: F401 +from pytorch_lightning.strategies.collaborative import CollaborativeStrategy # noqa: F401 from pytorch_lightning.strategies.ddp import DDPStrategy # noqa: F401 from pytorch_lightning.strategies.ddp2 import DDP2Strategy # noqa: F401 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy # noqa: F401 diff --git a/pytorch_lightning/strategies/collaborative.py b/pytorch_lightning/strategies/collaborative.py new file mode 100644 index 0000000000..5774fdaca5 --- /dev/null +++ b/pytorch_lightning/strategies/collaborative.py @@ -0,0 +1,529 @@ +import http +import ipaddress +import logging +import os +import platform +import re +import threading +import time +import warnings +from http.server import BaseHTTPRequestHandler +from typing import Any, Callable, Dict, List, Optional, Union + +import requests +import torch + +import pytorch_lightning as pl +from pytorch_lightning.strategies.strategy import Strategy, TBroadcast +from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities.data import extract_batch_size +from pytorch_lightning.utilities.enums import PrecisionType +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _HIVEMIND_AVAILABLE +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.types import _LRScheduler, ReduceLROnPlateau + +if _HIVEMIND_AVAILABLE: + import hivemind + +log = logging.getLogger(__name__) + + +class CollaborativeStrategy(Strategy): + def __init__( + self, + target_batch_size: int, + run_id: str = "lightning_run", + batch_size: Optional[int] = None, + delay_state_averaging: bool = False, + delay_optimizer_step: Optional[bool] = None, + delay_grad_averaging: bool = False, + offload_optimizer: Optional[bool] = None, + reuse_grad_buffers: bool = False, + scheduler_fn: Optional[Callable] = None, + 
matchmaking_time: float = 5.0, + averaging_timeout: float = 30.0, + verbose: bool = False, + averager_opts: Optional[Dict] = None, + host_maddrs: Optional[List] = None, + initial_peers: Optional[Union[str, List]] = None, + endpoint: Optional[bool] = None, + peer_endpoint: Optional[str] = None, + persistent: bool = True, + host: Optional[str] = None, + port: Optional[int] = None, + retry_endpoint_attempts: int = 5, + retry_endpoint_sleep_duration: int = 5, + **optimizer_kwargs: Any, + ): + """Provides capabilities to train using the Hivemind Library, training collaboratively across the internet + with unreliable machines. For more information, `refer to the docs `__. + + .. warning:: ``CollaborativeStrategy`` is experimental and subject to change. + + Arguments: + + target_batch_size: When training, the batch size to accumulate to before running a step. The larger this + batch size, the more work can be done asynchronously without communication. + + run_id: A unique identifier of this training run, used as a common prefix for all DHT keys. + See ``https://learning-at-home.readthedocs.io/en/latest/user/dht.html``. + + batch_size: The local batch size per process. If not provided, we infer this from the first batch of data + passed in at training (lazy). Note that this should not change throughout training. + + delay_state_averaging: If enabled, average parameters and extra tensors in a background thread; + if set to False (default), average parameters synchronously within the + corresponding :meth:`hivemind.Optimizer.step` call. + + delay_optimizer_step: Run optimizer in background, apply results in a future ``.step`` call. Requires + :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.offload_optimizer`.
+ + delay_grad_averaging: Average gradients in background; requires + :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.offload_optimizer` and + :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.delay_optimizer_step`. + + offload_optimizer: Offload the optimizer to host memory, saving GPU memory for parameters and gradients. + + reuse_grad_buffers: Use the model's gradient buffers (params.grad) for gradient accumulation + which is more memory efficient. Lightning will automatically disable ``zero_grad`` + in the ``LightningModule``. + + scheduler_fn: callable(optimizer) -> PyTorch LRScheduler or a pre-initialized PyTorch scheduler. + When using `offload_optimizer`/`delay_optimizer_step`/`delay_state_averaging` ``scheduler_fn`` + is required to be passed to the ``CollaborativeStrategy``. This is because the optimizer + is re-created and the scheduler needs to be re-created as well. + + matchmaking_time: When looking for group, wait for peers to join for up to this many seconds. + Increase if you see "averaged gradients with N peers" where N is below 0.9x on >=25% of epochs. + Training with low-latency network, decreasing matchmaking_time allows training with smaller batch sizes. + + averaging_timeout: If an averaging step hangs for this long, it will be cancelled automatically. + Increase averaging_timeout if you see "Proceeding with local gradients" at least 25% of the time. + Do not set this timeout too high, as it may cause your optimizer to hang + after some types of network errors. + + verbose: Report internal Hivemind events such as accumulating gradients and running background tasks. + + averager_opts: Additional keyword arguments forwarded to both + ``GradientAverager`` and ``TrainingStateAverager``. + + host_maddrs: List of multi-addrs to create visible peers for other processes. 
+ `https://learning-at-home.readthedocs.io/en/latest/user/dht.html#running-across-the-internet` + + initial_peers: If connecting to a running process, a list of initial peers needs to be passed in. + This can also be set via the env variable ``PL_INITIAL_PEERS``. + + endpoint: Enable if a side-car endpoint server is required on the process to serve initial peers. + This is useful when using some form of orchestration such as torchelastic. + + peer_endpoint: The endpoint to request initial peers from. + + persistent: When using an endpoint, this controls whether other processes that are not the endpoint + server log/checkpoint. If ``persistent`` is True, we do not log/checkpoint from other processes. + + host: When creating the endpoint, the host IP to use. + + port: When creating the endpoint, the host port to use. + + retry_endpoint_attempts: When connecting to the + :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.peer_endpoint`, + how many times to retry before raising an exception. + + retry_endpoint_sleep_duration: When connecting to the + :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.peer_endpoint`, + how long to wait (in seconds) between retries. + + **optimizer_kwargs: kwargs are passed to the :class:`hivemind.Optimizer` class. + """ + if not _HIVEMIND_AVAILABLE or platform.system() != "Linux": + raise MisconfigurationException( + "To use the `CollaborativeStrategy`, you must have Hivemind installed and be running on Linux." + " Install it by running `pip install -U hivemind`."
+ ) + + super().__init__() + self.dht_manager = DHTManager( + persistent=persistent, + endpoint=endpoint, + peer_endpoint=peer_endpoint, + host=host, + port=port, + host_maddrs=host_maddrs, + initial_peers=initial_peers, + retry_endpoint_attempts=retry_endpoint_attempts, + retry_endpoint_sleep_duration=retry_endpoint_sleep_duration, + ) + self._target_batch_size = target_batch_size + self._batch_size = batch_size + self._scheduler_fn = scheduler_fn + self._require_scheduler_fn = delay_optimizer_step or delay_state_averaging or offload_optimizer + self._opt = None + self._optimizer_zero_grad_original: Optional[Callable] = None + self._run_id = run_id + self._reuse_grad_buffers = reuse_grad_buffers + self._optimizer_kwargs = dict( + matchmaking_time=matchmaking_time, + averaging_timeout=averaging_timeout, + delay_optimizer_step=delay_optimizer_step, + delay_state_averaging=delay_state_averaging, + delay_grad_averaging=delay_grad_averaging, + offload_optimizer=offload_optimizer, + averager_opts=averager_opts if averager_opts is not None else dict(request_timeout=1.0), + verbose=verbose, + reuse_grad_buffers=reuse_grad_buffers, + **optimizer_kwargs, + ) + + # a bit of a hack to only log from the stable server + if self.dht_manager.disable_logging_checkpointing: + warnings.warn( + "This machine is not a persistent machine. Checkpointing/Logging has been disabled.", UserWarning + ) + rank_zero_only.rank = 1 if self.dht_manager.disable_logging_checkpointing else 0 + self._hivemind_initialized = False + + @property + def num_peers(self) -> int: + if self._opt: + return self._opt.tracker.global_progress.num_peers + return 1 + + @property + def dht(self) -> "hivemind.DHT": + """Hivemind Distributed Hash Table which stores values across all peers.
+ + See documentation for more details: `https://learning-at-home.readthedocs.io/en/latest/modules/dht.html` + """ + return self.dht_manager.dht + + @property + def root_device(self) -> torch.device: + from pytorch_lightning.accelerators.cpu import CPUAccelerator + from pytorch_lightning.accelerators.gpu import GPUAccelerator + + if isinstance(self.accelerator, GPUAccelerator): + return torch.device(f"cuda:{torch.cuda.current_device()}") + elif isinstance(self.accelerator, CPUAccelerator): + return torch.device("cpu") + raise MisconfigurationException( + f"Was unable to infer device type from the accelerator: {self.accelerator.__class__.__name__}." + ) + + @property + def global_rank(self) -> int: + return 0 + + @property + def is_global_zero(self) -> bool: + return True + + def setup(self, trainer: "pl.Trainer") -> None: + self.model_to_device() + super().setup(trainer) + if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED): + self.precision_plugin.scaler = hivemind.GradScaler() + + def _initialize_hivemind(self) -> None: + if len(self.optimizers) > 1: + raise MisconfigurationException("Hivemind only supports training with one optimizer.") + optimizer = self.optimizers[0] + + if self._require_scheduler_fn and self._scheduler_fn is None: + rank_zero_warn( + "Enabling `delay_optimizer_step`, `delay_state_averaging` or `offload_optimizer` " + "requires a `scheduler_fn` to be passed to the strategy if a scheduler is being used " + "(this is because the optimizer is re-created within Hivemind)." 
+ ) + + scheduler = self._scheduler_fn if self._require_scheduler_fn else None + params = optimizer.param_groups if self._require_scheduler_fn else None + optimizer = type(optimizer) if self._require_scheduler_fn else optimizer + opt = hivemind.Optimizer( + dht=self.dht, + run_id=self._run_id, + params=params, + optimizer=optimizer, + scheduler=scheduler, + target_batch_size=self._target_batch_size, + batch_size_per_step=self._batch_size, + **self._optimizer_kwargs, + ) + + if not self._scheduler_fn: + self._wrap_schedulers(opt) + opt.load_state_from_peers() + self.optimizers = [opt] + self._opt = opt + + if self._reuse_grad_buffers: + assert self.lightning_module is not None + self._optimizer_zero_grad_original = self.lightning_module.optimizer_zero_grad + self._disable_zero_grad() + + def _disable_zero_grad(self) -> None: + lightning_module = self.lightning_module + if is_overridden("optimizer_zero_grad", lightning_module): + assert lightning_module is not None # `is_overridden` returns False otherwise + rank_zero_warn( + "You have overridden `optimizer_zero_grad` which will be disabled." + " When `CollaborativeStrategy(reuse_grad_buffers=True)`, the optimizer cannot call zero grad," + " as this would delete the gradients before they are averaged." + ) + assert lightning_module is not None + lightning_module.optimizer_zero_grad = None # type: ignore[assignment] + + def _wrap_schedulers(self, opt: "hivemind.Optimizer") -> None: + # wrap schedulers so that they only update when the hivemind optimizer updates + for scheduler_config in self.lr_scheduler_configs: + scheduler = scheduler_config.scheduler + if isinstance(scheduler, ReduceLROnPlateau): + raise ValueError( + f"The `ReduceLROnPlateau` scheduler is not currently supported with `{self.__class__.__name__}`." 
+ ) + scheduler_config.scheduler = HiveMindScheduler( + optimizer=opt, + scheduler=scheduler, + ) + + def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + if not self._hivemind_initialized: + self._hivemind_initialized = True + # todo (sean): we could technically support a dynamic batch size by inferring each step + # and passing it to the ``hivemind.Optimizer``. + if self._batch_size is None: + try: + self._batch_size = extract_batch_size(batch) + log.info(f"Found per machine batch size automatically from the batch: {self._batch_size}") + except (MisconfigurationException, RecursionError) as e: + raise MisconfigurationException( + "We tried to infer the batch size from the first batch of data. " + "Please provide the batch size to the Strategy by " + "``Trainer(strategy=CollaborativeStrategy(batch_size=x))``. " + ) from e + self._initialize_hivemind() + + def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]: + return tensor + + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + return tensor + + def model_to_device(self) -> None: + assert self.model is not None + self.model.to(self.root_device) + + def barrier(self, *args: Any, **kwargs: Any) -> None: + pass + + def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: + return obj + + def teardown(self) -> None: + super().teardown() + + if self._optimizer_zero_grad_original is not None and self.lightning_module is not None: + # re-enable `optimizer_zero_grad` + self.lightning_module.optimizer_zero_grad = self._optimizer_zero_grad_original # type: ignore[assignment] + + if self.root_device.type == "cuda" and self.lightning_module is not None: + # GPU teardown + self.lightning_module.cpu() + # clean up memory + torch.cuda.empty_cache() + if self._opt: + self._opt.shutdown() + log.info("Shutting down hivemind DHT.") + self.dht.shutdown() + + +class 
HiveMindScheduler: + """Wrapper for schedulers to prevent Lightning from stepping the scheduler too soon. + + This code ensures that we only step when the HiveMind optimizer reaches the global step. + """ + + def __init__(self, optimizer: "hivemind.Optimizer", scheduler: _LRScheduler) -> None: + # copy most of the `Scheduler` methods into this instance. `__del__` is skipped in case the scheduler has + # implemented custom logic which we would not want to call on destruction of the `HiveMindScheduler` + self.__dict__ = {k: v for k, v in scheduler.__dict__.items() if k not in ("step", "__del__")} + + self.optimizer = optimizer + self.scheduler = scheduler + self.current_step = -1 + + def step(self, epoch: Optional[int] = None) -> None: + while self.current_step < self.optimizer.local_epoch: + self.scheduler.step(epoch=epoch) + self.current_step += 1 + + def load_state_dict(self, state_dict: Dict) -> None: + self.scheduler.load_state_dict(state_dict) + + def state_dict(self) -> Dict: + return self.scheduler.state_dict() + + +class DHTManager: + ENDPOINT_ENV: str = "PL_ENDPOINT" + PEER_ENDPOINT_ENV: str = "PL_PEER_ENDPOINT" + INITIAL_PEERS_ENV: str = "PL_INITIAL_PEERS" + HOST_ENV: str = "PL_HOST" + PORT_ENV: str = "PL_PORT" + DEFAULT_HOST: str = "0.0.0.0" + DEFAULT_PORT: int = 1440 + + def __init__( + self, + host_maddrs: Optional[List], + initial_peers: Optional[Union[str, List]], + persistent: bool, + endpoint: Optional[bool], + peer_endpoint: Optional[str], + host: Optional[str], + port: Optional[int], + retry_endpoint_attempts: int = 5, + retry_endpoint_sleep_duration: int = 5, + ) -> None: + """Manages the `hivemind.DHT` connection and provides a side-car endpoint server for initial peer access. 
+ + Arguments: + + host_maddrs: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.host_maddrs` + + initial_peers: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.initial_peers` + + persistent: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.persistent` + + endpoint: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.endpoint` + + peer_endpoint: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.peer_endpoint` + + host: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.host` + + port: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.port` + + retry_endpoint_attempts: + :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.retry_endpoint_attempts` + + retry_endpoint_sleep_duration: + :paramref: + `~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.retry_endpoint_sleep_duration` + """ + self._persistent = persistent + self._endpoint = endpoint + self._initial_peers = initial_peers + self._peer_endpoint = peer_endpoint + self._host = host + self._port = port + + self._parse_env_vars() + + if self._peer_endpoint and self._initial_peers is None: + self._initial_peers = self._get_initial_peers_from_endpoint( + retry_initial_peers=retry_endpoint_attempts, retry_peer_sleep_duration=retry_endpoint_sleep_duration + ) + + self.dht = hivemind.DHT( + start=True, + initial_peers=self._initial_peers, + host_maddrs=host_maddrs if host_maddrs is not None else ["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"], + ) + + visible_addresses = [ + str(a) for a in self.dht.get_visible_maddrs() if not ipaddress.ip_address(a.values()[0]).is_loopback + ] + + if self._endpoint: + self._host = self._host if self._host is not None else self.DEFAULT_HOST + self._port = self._port if self._port is not None else self.DEFAULT_PORT + self._start_server_process(self._host, 
self._port) + self._log_endpoint_helper_message(visible_addresses) + elif self._peer_endpoint: + log.info("Machine received initial peers from endpoint.") + elif self._initial_peers is None: + log.info( + "\nOther machines can connect running the same command:\n" + f"PL_INITIAL_PEERS={','.join(visible_addresses)} python ...\n" + "or passing the peers to the strategy:\n" + f"CollaborativeStrategy(initial_peers='{','.join(visible_addresses)}')" + ) + + def _log_endpoint_helper_message(self, visible_addresses: List[str]) -> None: + assert self._host is not None + resolved_host = self._host + if "0.0.0.0" in self._host: + # use the visible multi-addresses to figure out the IP that has been exposed + # todo (sean): this is pretty hacky, worth investigating. + p = re.compile(r"[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+") + # todo (sean): we select one address from here, could we have multiple? + resolved_host = next((m.group() for m in map(p.search, visible_addresses) if m), resolved_host) + log.info( + "\nSidecar endpoint enabled to serve peers.\n" + "Other peers can connect via:\n" + f"PL_PEER_ENDPOINT={resolved_host}:{self._port} python ...\n" + "or pass the peer endpoint address to the strategy:\n" + f"CollaborativeStrategy(peer_endpoint='{resolved_host}:{self._port}')" + ) + + def _start_server_process(self, host: str, port: int) -> None: + dht = self.dht + + class DHTHandler(BaseHTTPRequestHandler): + def do_GET(self) -> None: + """Respond to a GET request.""" + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + + visible_peers = [ + str(a) for a in dht.get_visible_maddrs() if not ipaddress.ip_address(a.values()[0]).is_loopback + ] + + self.wfile.write("\n".join(visible_peers).encode()) + + server = http.server.ThreadingHTTPServer((host, int(port)), DHTHandler) + thread = threading.Thread(target=server.serve_forever) + thread.daemon = True + thread.start() + + def _get_initial_peers_from_endpoint(self, retry_initial_peers: int, retry_peer_sleep_duration: int)
-> List: + peers = None + for _ in range(retry_initial_peers): + try: + peers = self._get_peers() + break + except requests.exceptions.RequestException: + log.info(f"Failed to get peers, retrying in {retry_peer_sleep_duration} seconds...") + time.sleep(retry_peer_sleep_duration) + if peers is None: + raise MisconfigurationException( + f"Unable to get peers. Tried {retry_initial_peers} times waiting {retry_peer_sleep_duration}s. " + f"These parameters can be extended by passing " + "to the strategy (CollaborativeStrategy(retry_endpoint_attempts=x, retry_endpoint_sleep_duration=y))." + ) + log.info(f"Received initial peers from collaborative server: {peers}") + return peers + + def _get_peers(self) -> List[str]: + assert self._peer_endpoint is not None + url = f"http://{self._peer_endpoint}" if not self._peer_endpoint.startswith("http://") else self._peer_endpoint + r = requests.get(url) + return [peer for peer in re.split(r"[\s,]+", r.text) if peer] + + def _parse_env_vars(self) -> None: + endpoint = os.environ.get(self.ENDPOINT_ENV, self._endpoint) + self._endpoint = endpoint == "1" if isinstance(endpoint, str) else endpoint + self._peer_endpoint = os.environ.get(self.PEER_ENDPOINT_ENV, self._peer_endpoint) + initial_peers = os.environ.get(self.INITIAL_PEERS_ENV, self._initial_peers) + self._initial_peers = initial_peers.split(",") if isinstance(initial_peers, str) else initial_peers + + port = os.environ.get(self.PORT_ENV, self._port) + self._port = int(port) if isinstance(port, str) else port + self._host = os.environ.get(self.HOST_ENV, self._host) + + @property + def disable_logging_checkpointing(self) -> bool: + # if this node is a peer, we do not log/checkpoint in persistent mode.
+ return self._persistent and (self._initial_peers is not None or self._peer_endpoint is not None) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 87947ac9a1..61e075fd01 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -35,6 +35,7 @@ from pytorch_lightning.utilities.imports import ( # noqa: F401 _FAIRSCALE_FULLY_SHARDED_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, _GROUP_AVAILABLE, + _HIVEMIND_AVAILABLE, _HOROVOD_AVAILABLE, _HPU_AVAILABLE, _HYDRA_AVAILABLE, diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 835e56f181..f2f73a7d89 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -105,6 +105,7 @@ _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn") _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.3") _FAIRSCALE_FULLY_SHARDED_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.4") _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available("torch.distributed.group") +_HIVEMIND_AVAILABLE = _package_available("hivemind") _HOROVOD_AVAILABLE = _module_available("horovod.torch") _HYDRA_AVAILABLE = _package_available("hydra") _HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental") diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index c5e384117f..c65ac4b39e 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -68,6 +68,9 @@ class _LRScheduler(_Stateful, Protocol): def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None: ... + def step(self, epoch: Optional[int] = None) -> None: + ... 
+ # Inferred from `torch.optim.lr_scheduler.pyi` # Missing attributes were added to improve typing @@ -91,6 +94,9 @@ class ReduceLROnPlateau(_Stateful, Protocol): ) -> None: ... + def step(self, metrics: Union[float, int, torch.Tensor], epoch: Optional[int] = None) -> None: + ... + # todo: improve LRSchedulerType naming/typing LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) diff --git a/requirements/strategies.txt b/requirements/strategies.txt index ae6648a6eb..7846a297e3 100644 --- a/requirements/strategies.txt +++ b/requirements/strategies.txt @@ -1,3 +1,4 @@ fairscale>=0.4.5 deepspeed<0.6.0 horovod>=0.21.2,!=0.24.0 # no need to install with [pytorch] as pytorch is already installed +hivemind>=1.0.1; sys_platform == 'linux' diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py index 5a2464f6fd..99d64ebd01 100644 --- a/tests/helpers/runif.py +++ b/tests/helpers/runif.py @@ -26,6 +26,7 @@ from pytorch_lightning.utilities import ( _DEEPSPEED_AVAILABLE, _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULLY_SHARDED_AVAILABLE, + _HIVEMIND_AVAILABLE, _HOROVOD_AVAILABLE, _HPU_AVAILABLE, _IPU_AVAILABLE, @@ -84,6 +85,7 @@ class RunIf: omegaconf: bool = False, slow: bool = False, bagua: bool = False, + hivemind: bool = False, **kwargs, ): """ @@ -111,6 +113,7 @@ class RunIf: omegaconf: Require that omry/omegaconf is installed. slow: Mark the test as slow, our CI will run it in a separate job. bagua: Require that BaguaSys/bagua is installed. + hivemind: Require that Hivemind is installed. **kwargs: Any :class:`pytest.mark.skipif` keyword arguments. 
""" conditions = [] @@ -231,6 +234,10 @@ class RunIf: conditions.append(not _BAGUA_AVAILABLE or sys.platform in ("win32", "darwin")) reasons.append("Bagua") + if hivemind: + conditions.append(not _HIVEMIND_AVAILABLE or sys.platform in ("win32", "darwin")) + reasons.append("Hivemind") + reasons = [rs for cond, rs in zip(conditions, reasons) if cond] return pytest.mark.skipif( *args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs diff --git a/tests/strategies/test_collaborative.py b/tests/strategies/test_collaborative.py new file mode 100644 index 0000000000..6787f30c38 --- /dev/null +++ b/tests/strategies/test_collaborative.py @@ -0,0 +1,372 @@ +import multiprocessing as mp +import os +import time +from typing import Any +from unittest import mock +from unittest.mock import PropertyMock + +import pytest +import requests +import torch +from torch.optim import Optimizer + +import pytorch_lightning as pl +from pytorch_lightning.plugins.environments.lightning_environment import find_free_network_port +from pytorch_lightning.strategies import CollaborativeStrategy +from pytorch_lightning.strategies.collaborative import HiveMindScheduler +from pytorch_lightning.utilities import _HIVEMIND_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import STEP_OUTPUT +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + +if _HIVEMIND_AVAILABLE: + import hivemind + + +@mock.patch("pytorch_lightning.strategies.collaborative._HIVEMIND_AVAILABLE", False) +def test_raise_exception_if_hivemind_unavailable(): + """Test that we raise an exception when Hivemind is not available.""" + with pytest.raises(MisconfigurationException, match="you must have Hivemind installed"): + CollaborativeStrategy(target_batch_size=1) + + +@RunIf(hivemind=True) +@mock.patch("hivemind.DHT", autospec=True) +def test_strategy(mock_dht): + strategy = 
CollaborativeStrategy(target_batch_size=1) + trainer = pl.Trainer(strategy=strategy) + assert trainer.strategy == strategy + + +@RunIf(hivemind=True) +@mock.patch("hivemind.DHT", autospec=True) +@mock.patch("pytorch_lightning.strategies.collaborative.DHTManager._get_peers", autospec=True) +@pytest.mark.parametrize( + "initial_peers,peer_endpoint", + [(["TEST"], None), (None, "localhost:153")], +) +def test_logging_disabled_when_second_peer(mock_dht, mock_http, initial_peers, peer_endpoint): + """Test when we are a second peer (passing initial peers or peer endpoint) we warn the user that + logging/checkpointing will be disabled.""" + with pytest.warns(UserWarning, match="This machine is not a persistent machine"): + CollaborativeStrategy(target_batch_size=1, initial_peers=initial_peers, peer_endpoint=peer_endpoint) + + +@RunIf(hivemind=True) +@mock.patch.dict( + os.environ, + {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor", "PL_PORT": str(find_free_network_port())}, + clear=True, +) +@pytest.mark.parametrize( + "endpoint,expected_message", + [(False, "INITIAL_PEERS"), (True, "Sidecar endpoint enabled to serve peers.")], +) +def test_initial_peer_message(caplog, endpoint, expected_message): + model = BoringModel() + trainer = pl.Trainer(strategy=CollaborativeStrategy(target_batch_size=1, endpoint=endpoint), fast_dev_run=True) + trainer.fit(model) + assert expected_message in caplog.text + + +@RunIf(hivemind=True) +@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True) +def test_optimizer_wrapped(): + class TestModel(BoringModel): + def on_before_backward(self, loss: torch.Tensor) -> None: + optimizer = self.trainer.optimizers[0] + assert isinstance(optimizer, hivemind.Optimizer) + + model = TestModel() + trainer = pl.Trainer(strategy=CollaborativeStrategy(target_batch_size=1), fast_dev_run=True) + trainer.fit(model) + + +@RunIf(hivemind=True) +@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": 
"file_descriptor"}, clear=True) +def test_scheduler_wrapped(): + class TestModel(BoringModel): + def on_before_backward(self, loss: torch.Tensor) -> None: + scheduler = self.trainer.lr_scheduler_configs[0].scheduler + assert isinstance(scheduler, HiveMindScheduler) + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return [optimizer], [torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)] + + model = TestModel() + trainer = pl.Trainer( + strategy=CollaborativeStrategy(target_batch_size=1), + fast_dev_run=True, + ) + trainer.fit(model) + + +@RunIf(hivemind=True) +@mock.patch.dict( + os.environ, + { + "HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor", + "PL_INITIAL_PEERS": "TEST_PEERS", + "PL_HOST": "TEST_HOST", + "PL_PORT": "1300", + "PL_ENDPOINT": "1", + "PL_PEER_ENDPOINT": "TEST_PEER_ENDPOINT", + }, + clear=True, +) +@mock.patch("hivemind.DHT", autospec=True) +@mock.patch("pytorch_lightning.strategies.collaborative.DHTManager._get_peers", autospec=True) +@mock.patch("http.server.ThreadingHTTPServer", autospec=True) +def test_env_variables_parsed(mock_dht, mock_peers, mock_server): + """Test that env variables are parsed correctly.""" + strategy = CollaborativeStrategy(target_batch_size=1) + assert strategy.dht_manager._initial_peers == ["TEST_PEERS"] + assert strategy.dht_manager._host == "TEST_HOST" + assert strategy.dht_manager._port == 1300 + assert strategy.dht_manager._endpoint + assert strategy.dht_manager._peer_endpoint == "TEST_PEER_ENDPOINT" + + +@RunIf(hivemind=True) +@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True) +def test_reuse_grad_buffers_warning(): + """Test to ensure we warn when a user overrides `optimizer_zero_grad` and `reuse_grad_buffers` is True.""" + + class TestModel(BoringModel): + def on_before_backward(self, loss: torch.Tensor) -> None: + optimizer = self.trainer.optimizers[0] + assert isinstance(optimizer, hivemind.Optimizer) 
+ + def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int): + pass + + model = TestModel() + trainer = pl.Trainer( + strategy=CollaborativeStrategy(target_batch_size=1, reuse_grad_buffers=True), fast_dev_run=True + ) + + with pytest.warns(UserWarning, match="You have overridden `optimizer_zero_grad` which will be disabled."): + trainer.fit(model) + + +@RunIf(hivemind=True) +def test_raise_exception_multiple_optimizers(): + """Test that we raise an exception when multiple optimizers are provided.""" + + class TestModel(BoringModel): + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer, optimizer], [lr_scheduler] + + model = TestModel() + trainer = pl.Trainer(strategy=CollaborativeStrategy(target_batch_size=1), fast_dev_run=True) + + with pytest.raises(MisconfigurationException, match="Hivemind only supports training with one optimizer."): + trainer.fit(model) + + +@RunIf(hivemind=True) +@mock.patch("pytorch_lightning.utilities.data._extract_batch_size", autospec=True, return_value=[None]) +def test_raise_exception_no_batch_size(mock_extract_batch_size): + """Test that we raise an exception when no batch size is automatically found.""" + + model = BoringModel() + trainer = pl.Trainer(strategy=CollaborativeStrategy(target_batch_size=1), fast_dev_run=True) + + with pytest.raises(MisconfigurationException, match="Please provide the batch size to the Strategy."): + trainer.fit(model) + + +@RunIf(hivemind=True) +@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True) +@pytest.mark.parametrize( + "delay_grad_averaging, delay_state_averaging, delay_optimizer_step", + [(True, True, True), (False, True, False)], +) +def test_warn_if_argument_passed(delay_grad_averaging, delay_state_averaging, delay_optimizer_step): + """Test ensures that valid combination 
of HiveMind delay arguments warn if scheduler isn't passed in as a + function.""" + model = BoringModel() + trainer = pl.Trainer( + strategy=CollaborativeStrategy( + target_batch_size=1, + delay_grad_averaging=delay_grad_averaging, + delay_state_averaging=delay_state_averaging, + delay_optimizer_step=delay_optimizer_step, + ), + fast_dev_run=True, + ) + + with pytest.warns(UserWarning, match="requires a `scheduler_fn` to be passed to the strategy"): + trainer.fit(model) + + +@RunIf(hivemind=True) +@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True) +@mock.patch("http.server.ThreadingHTTPServer", autospec=True) +@mock.patch("pytorch_lightning.strategies.collaborative.CollaborativeStrategy.num_peers", new_callable=PropertyMock) +def test_args_passed_to_optimizer(mock_peers, mock_server): + """Test to ensure arguments are correctly passed to the hivemind optimizer wrapper.""" + mock_peers.return_value = 1 + compression = hivemind.ScaledFloat16Compression() + with mock.patch("hivemind.Optimizer", wraps=hivemind.Optimizer) as mock_optimizer: + + class TestModel(BoringModel): + def on_before_backward(self, loss: torch.Tensor) -> None: + args, kwargs = mock_optimizer.call_args + mock_optimizer.assert_called() + arguments = dict( + delay_optimizer_step=True, + delay_state_averaging=True, + state_averaging_compression=compression, + grad_compression=compression, + offload_optimizer=True, + reuse_grad_buffers=True, + target_batch_size=1, + ) + + for key, value in arguments.items(): + assert key in kwargs + assert value == kwargs[key] + + model = TestModel() + trainer = pl.Trainer( + strategy=CollaborativeStrategy( + target_batch_size=1, + reuse_grad_buffers=True, + delay_state_averaging=True, + delay_optimizer_step=True, + offload_optimizer=True, + grad_compression=compression, + state_averaging_compression=compression, + ), + fast_dev_run=True, + ) + trainer.fit(model) + # ensures that after training with `reuse_grad_buffers` we 
restore the hook + assert model.optimizer_zero_grad is not None + + +@RunIf(hivemind=True) +@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True) +@pytest.mark.parametrize( + "host_maddrs,expected_maddrs", + [(None, ["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"]), (["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"])], +) +def test_maddrs(host_maddrs, expected_maddrs): + """Test that the multiple addresses are correctly assigned.""" + strategy = CollaborativeStrategy(target_batch_size=1, host_maddrs=host_maddrs) + assert strategy.dht.kwargs["host_maddrs"] == expected_maddrs + + +def _run_collab_training_fn(initial_peers, wait_seconds, barrier, recorded_process_peers, recorded_process_steps): + recorded_peers = [] + recorded_global_steps = [] + + class TestModel(BoringModel): + def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, unused: int = 0) -> None: + time.sleep(wait_seconds) # add an additional delay to give processes time to sync + recorded_peers.append(self.trainer.strategy.num_peers) + recorded_global_steps.append(self.trainer.optimizers[0].local_epoch) + + def on_train_end(self) -> None: + # wait for all processes to get to the end of training before teardown + barrier.wait() + + model = TestModel() + trainer = pl.Trainer( + max_epochs=1, + limit_train_batches=16, + limit_val_batches=0, + strategy=CollaborativeStrategy( + delay_state_averaging=True, + offload_optimizer=True, + delay_optimizer_step=True, + delay_grad_averaging=True, + target_batch_size=8, + initial_peers=initial_peers, + verbose=False, + ), + ) + trainer.fit(model) + + recorded_process_peers.append(recorded_peers) + recorded_process_steps.append(recorded_global_steps) + + +@RunIf(hivemind=True) +@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True) +@pytest.mark.parametrize( + "num_processes, wait_seconds", + [(2, 0.25)], +) +def test_multiple_peers(num_processes, 
wait_seconds): + """Test to ensure that if we have two running processes with the same peers, they connect and train + successfully.""" + dht_root = hivemind.DHT(start=True) + barrier = mp.Barrier(num_processes) + initial_peers = dht_root.get_visible_maddrs() + + with mp.Manager() as manager: + # allows processes to return their recorded logged peers/steps + recorded_process_peers = manager.list() + recorded_process_steps = manager.list() + processes = [ + mp.Process( + target=_run_collab_training_fn, + kwargs=dict( + initial_peers=initial_peers, + wait_seconds=wait_seconds, + barrier=barrier, + recorded_process_peers=recorded_process_peers, + recorded_process_steps=recorded_process_steps, + ), + ) + for x in range(num_processes) + ] + for process in processes: + process.start() + for process in processes: + process.join() + # assert that peers increase as expected and we run at-least 1 global step. + for process_peers, process_steps in zip(recorded_process_peers, recorded_process_steps): + assert any(num_peer == num_processes for num_peer in process_peers) + assert any(global_step > 0 for global_step in process_steps) + + +@RunIf(hivemind=True, min_gpus=1) +@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True) +def test_scaler_updated_precision_16(): + class TestModel(BoringModel): + def on_fit_start(self) -> None: + assert isinstance(self.trainer.precision_plugin.scaler, hivemind.GradScaler) + raise SystemExit + + model = TestModel() + trainer = pl.Trainer( + strategy=CollaborativeStrategy(target_batch_size=1), + fast_dev_run=True, + precision=16, + accelerator="gpu", + devices=1, + ) + with pytest.raises(SystemExit): + trainer.fit(model) + + +@RunIf(hivemind=True) +def test_raise_when_peer_endpoint_unsuccessful(caplog): + port = find_free_network_port() + with pytest.raises(MisconfigurationException, match="Unable to get peers"): + with mock.patch("requests.get", wraps=requests.get) as requests_mock: + 
CollaborativeStrategy( + target_batch_size=1, + peer_endpoint=f"localhost:{port}", + retry_endpoint_attempts=10, + retry_endpoint_sleep_duration=0, + ) + assert "Failed to get peers, retrying" in caplog.text + assert requests_mock.call_count == 10