[1/2] Collaborative Strategy (#12842)

Sean Naren 2022-05-05 17:06:26 +01:00 committed by GitHub
parent d337374da7
commit 1a502c061c
9 changed files with 921 additions and 0 deletions

View File

@@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `Trainer(deterministic="warn")` to warn instead of fail when a non-deterministic operation is encountered ([#12588](https://github.com/PyTorchLightning/pytorch-lightning/pull/12588))
- Added `CollaborativeStrategy` ([#12842](https://github.com/PyTorchLightning/pytorch-lightning/pull/12842))
- Include a version suffix for new "last" checkpoints of later runs in the same directory ([#12902](https://github.com/PyTorchLightning/pytorch-lightning/pull/12902))

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.strategies.bagua import BaguaStrategy # noqa: F401
from pytorch_lightning.strategies.collaborative import CollaborativeStrategy # noqa: F401
from pytorch_lightning.strategies.ddp import DDPStrategy # noqa: F401
from pytorch_lightning.strategies.ddp2 import DDP2Strategy # noqa: F401
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy # noqa: F401
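
With the import registered above, the strategy can be selected like any other built-in strategy. A minimal usage sketch, assuming ``MyLightningModule`` is any existing ``LightningModule`` (it is a placeholder, not part of this commit):

    import pytorch_lightning as pl
    from pytorch_lightning.strategies import CollaborativeStrategy

    # accumulate gradients across all peers until 8192 samples have been processed,
    # then take a synchronized optimizer step
    trainer = pl.Trainer(strategy=CollaborativeStrategy(target_batch_size=8192))
    trainer.fit(MyLightningModule())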

View File

@@ -0,0 +1,529 @@
import http
import ipaddress
import logging
import os
import platform
import re
import threading
import time
import warnings
from http.server import BaseHTTPRequestHandler
from typing import Any, Callable, Dict, List, Optional, Union
import requests
import torch
import pytorch_lightning as pl
from pytorch_lightning.strategies.strategy import Strategy, TBroadcast
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.data import extract_batch_size
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _HIVEMIND_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _LRScheduler, ReduceLROnPlateau
if _HIVEMIND_AVAILABLE:
import hivemind
log = logging.getLogger(__name__)
class CollaborativeStrategy(Strategy):
def __init__(
self,
target_batch_size: int,
run_id: str = "lightning_run",
batch_size: Optional[int] = None,
delay_state_averaging: bool = False,
delay_optimizer_step: Optional[bool] = None,
delay_grad_averaging: bool = False,
offload_optimizer: Optional[bool] = None,
reuse_grad_buffers: bool = False,
scheduler_fn: Optional[Callable] = None,
matchmaking_time: float = 5.0,
averaging_timeout: float = 30.0,
verbose: bool = False,
averager_opts: Optional[Dict] = None,
host_maddrs: Optional[List] = None,
initial_peers: Optional[Union[str, List]] = None,
endpoint: Optional[bool] = None,
peer_endpoint: Optional[str] = None,
persistent: bool = True,
host: Optional[str] = None,
port: Optional[int] = None,
retry_endpoint_attempts: int = 5,
retry_endpoint_sleep_duration: int = 5,
**optimizer_kwargs: Any,
):
"""Provides capabilities to train using the Hivemind Library, training collaboratively across the internet
with unreliable machines. For more information, `refer to the docs <https://pytorch-
lightning.readthedocs.io/en/latest/strategies/collaborative_training.html>`__.
.. warning:: ``CollaborativeStrategy`` is experimental and subject to change.
Arguments:
target_batch_size: When training, the batch size to accumulate to before running a step. The larger this
batch size, the more work can be done asynchronously without communication.
run_id: A unique identifier of this training run, used as a common prefix for all DHT keys.
See ``https://learning-at-home.readthedocs.io/en/latest/user/dht.html``.
batch_size: The local batch size per process. If not provided, we infer this from the first batch of data
passed in at training (lazy). Note that this should not change throughout training.
delay_state_averaging: If enabled, average parameters and extra tensors in a background thread;
if disabled (the default), average parameters synchronously within the
corresponding :meth:`hivemind.Optimizer.step` call.
delay_optimizer_step: Run the optimizer in the background and apply the results in a future ``step`` call.
Requires :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.offload_optimizer`.
delay_grad_averaging: Average gradients in background; requires
:paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.offload_optimizer` and
:paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.delay_optimizer_step`.
offload_optimizer: Offload the optimizer to host memory, saving GPU memory for parameters and gradients.
reuse_grad_buffers: Use the model's gradient buffers (params.grad) for gradient accumulation
which is more memory efficient. Lightning will automatically disable ``zero_grad``
in the ``LightningModule``.
scheduler_fn: A ``callable(optimizer) -> PyTorch LRScheduler``, or a pre-initialized PyTorch scheduler.
When using ``offload_optimizer``/``delay_optimizer_step``/``delay_state_averaging``, a ``scheduler_fn``
must be passed to the ``CollaborativeStrategy``, because the optimizer is re-created by Hivemind
and the scheduler needs to be re-created along with it.
matchmaking_time: When looking for a group, wait for peers to join for up to this many seconds.
Increase it if you see "averaged gradients with N peers" where N is below 0.9x the expected number
of peers on >=25% of epochs. On low-latency networks, decreasing ``matchmaking_time`` allows
training with smaller batch sizes.
averaging_timeout: If an averaging step hangs for this long, it will be cancelled automatically.
Increase averaging_timeout if you see "Proceeding with local gradients" at least 25% of the time.
Do not set this timeout too high, as it may cause your optimizer to hang
after some types of network errors.
verbose: Report internal Hivemind events such as accumulating gradients and running background tasks.
averager_opts: Additional keyword arguments forwarded to both
``GradientAverager`` and ``TrainingStateAverager``.
host_maddrs: List of multi-addresses used to create visible peers for other processes.
See `https://learning-at-home.readthedocs.io/en/latest/user/dht.html#running-across-the-internet`.
initial_peers: If connecting to a running process, a list of initial peers needs to be passed in.
This can also be set via the environment variable ``PL_INITIAL_PEERS``.
endpoint: Enable if a side-car endpoint server is required on this process to serve initial peers.
This is useful when using some form of orchestration such as torchelastic.
peer_endpoint: The endpoint to request initial peers from.
persistent: When using an endpoint, this controls whether other processes that are not the endpoint
server log/checkpoint. If ``persistent`` is True, we do not log/checkpoint from other processes.
host: When creating the endpoint, the host IP to use.
port: When creating the endpoint, the host port to use.
retry_endpoint_attempts: When connecting to the
:paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.peer_endpoint`,
how many times to retry before raising an exception.
retry_endpoint_sleep_duration: When connecting to the
:paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.peer_endpoint`,
how long to wait between retries.
**optimizer_kwargs: Additional keyword arguments passed to the :class:`hivemind.Optimizer`.
"""
if not _HIVEMIND_AVAILABLE or platform.system() != "Linux":
raise MisconfigurationException(
"To use the `CollaborativeStrategy`, you must have Hivemind installed and be running on Linux."
" Install it by running `pip install -U hivemind`."
)
super().__init__()
self.dht_manager = DHTManager(
persistent=persistent,
endpoint=endpoint,
peer_endpoint=peer_endpoint,
host=host,
port=port,
host_maddrs=host_maddrs,
initial_peers=initial_peers,
retry_endpoint_attempts=retry_endpoint_attempts,
retry_endpoint_sleep_duration=retry_endpoint_sleep_duration,
)
self._target_batch_size = target_batch_size
self._batch_size = batch_size
self._scheduler_fn = scheduler_fn
self._require_scheduler_fn = delay_optimizer_step or delay_state_averaging or offload_optimizer
self._opt = None
self._optimizer_zero_grad_original: Optional[Callable] = None
self._run_id = run_id
self._reuse_grad_buffers = reuse_grad_buffers
self._optimizer_kwargs = dict(
matchmaking_time=matchmaking_time,
averaging_timeout=averaging_timeout,
delay_optimizer_step=delay_optimizer_step,
delay_state_averaging=delay_state_averaging,
delay_grad_averaging=delay_grad_averaging,
offload_optimizer=offload_optimizer,
averager_opts=averager_opts if averager_opts is not None else dict(request_timeout=1.0),
verbose=verbose,
reuse_grad_buffers=reuse_grad_buffers,
**optimizer_kwargs,
)
# a bit of a hack to only log from the stable server
if self.dht_manager.disable_logging_checkpointing:
warnings.warn(
"This machine is not a persistent machine. Checkpointing/Logging has been disabled.", UserWarning
)
rank_zero_only.rank = 1 if self.dht_manager.disable_logging_checkpointing else 0
self._hivemind_initialized = False
@property
def num_peers(self) -> int:
if self._opt:
return self._opt.tracker.global_progress.num_peers
return 1
@property
def dht(self) -> "hivemind.DHT":
"""Hivemind Distributed Hash Table which stores values across all peers.
See documentation for more details: `https://learning-at-home.readthedocs.io/en/latest/modules/dht.html`
"""
return self.dht_manager.dht
@property
def root_device(self) -> torch.device:
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.accelerators.gpu import GPUAccelerator
if isinstance(self.accelerator, GPUAccelerator):
return torch.device(f"cuda:{torch.cuda.current_device()}")
elif isinstance(self.accelerator, CPUAccelerator):
return torch.device("cpu")
raise MisconfigurationException(
f"Was unable to infer device type from the accelerator: {self.accelerator.__class__.__name__}."
)
@property
def global_rank(self) -> int:
return 0
@property
def is_global_zero(self) -> bool:
return True
def setup(self, trainer: "pl.Trainer") -> None:
self.model_to_device()
super().setup(trainer)
if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED):
self.precision_plugin.scaler = hivemind.GradScaler()
def _initialize_hivemind(self) -> None:
if len(self.optimizers) > 1:
raise MisconfigurationException("Hivemind only supports training with one optimizer.")
optimizer = self.optimizers[0]
if self._require_scheduler_fn and self._scheduler_fn is None:
rank_zero_warn(
"Enabling `delay_optimizer_step`, `delay_state_averaging` or `offload_optimizer` "
"requires a `scheduler_fn` to be passed to the strategy if a scheduler is being used "
"(this is because the optimizer is re-created within Hivemind)."
)
scheduler = self._scheduler_fn if self._require_scheduler_fn else None
params = optimizer.param_groups if self._require_scheduler_fn else None
optimizer = type(optimizer) if self._require_scheduler_fn else optimizer
opt = hivemind.Optimizer(
dht=self.dht,
run_id=self._run_id,
params=params,
optimizer=optimizer,
scheduler=scheduler,
target_batch_size=self._target_batch_size,
batch_size_per_step=self._batch_size,
**self._optimizer_kwargs,
)
if not self._scheduler_fn:
self._wrap_schedulers(opt)
opt.load_state_from_peers()
self.optimizers = [opt]
self._opt = opt
if self._reuse_grad_buffers:
assert self.lightning_module is not None
self._optimizer_zero_grad_original = self.lightning_module.optimizer_zero_grad
self._disable_zero_grad()
def _disable_zero_grad(self) -> None:
lightning_module = self.lightning_module
if is_overridden("optimizer_zero_grad", lightning_module):
assert lightning_module is not None # `is_overridden` returns False otherwise
rank_zero_warn(
"You have overridden `optimizer_zero_grad` which will be disabled."
" When `CollaborativeStrategy(reuse_grad_buffers=True)`, the optimizer cannot call zero grad,"
" as this would delete the gradients before they are averaged."
)
assert lightning_module is not None
lightning_module.optimizer_zero_grad = None # type: ignore[assignment]
def _wrap_schedulers(self, opt: "hivemind.Optimizer") -> None:
# wrap schedulers so that they only update when the hivemind optimizer updates
for scheduler_config in self.lr_scheduler_configs:
scheduler = scheduler_config.scheduler
if isinstance(scheduler, ReduceLROnPlateau):
raise ValueError(
f"The `ReduceLROnPlateau` scheduler is not currently supported with `{self.__class__.__name__}`."
)
scheduler_config.scheduler = HiveMindScheduler(
optimizer=opt,
scheduler=scheduler,
)
def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
if not self._hivemind_initialized:
self._hivemind_initialized = True
# todo (sean): we could technically support a dynamic batch size by inferring each step
# and passing it to the ``hivemind.Optimizer``.
if self._batch_size is None:
try:
self._batch_size = extract_batch_size(batch)
log.info(f"Found per machine batch size automatically from the batch: {self._batch_size}")
except (MisconfigurationException, RecursionError) as e:
raise MisconfigurationException(
"We tried to infer the batch size from the first batch of data. "
"Please provide the batch size to the Strategy by "
"``Trainer(strategy=CollaborativeStrategy(batch_size=x))``. "
) from e
self._initialize_hivemind()
def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
return tensor
def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
return tensor
def model_to_device(self) -> None:
assert self.model is not None
self.model.to(self.root_device)
def barrier(self, *args: Any, **kwargs: Any) -> None:
pass
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
return obj
def teardown(self) -> None:
super().teardown()
if self._optimizer_zero_grad_original is not None and self.lightning_module is not None:
# re-enable `optimizer_zero_grad`
self.lightning_module.optimizer_zero_grad = self._optimizer_zero_grad_original # type: ignore[assignment]
if self.root_device.type == "cuda" and self.lightning_module is not None:
# GPU teardown
self.lightning_module.cpu()
# clean up memory
torch.cuda.empty_cache()
if self._opt:
self._opt.shutdown()
log.info("Shutting down hivemind DHT.")
self.dht.shutdown()
class HiveMindScheduler:
"""Wrapper for schedulers to prevent Lightning from stepping the scheduler too soon.
This code ensures that we only step when the HiveMind optimizer reaches the global step.
"""
def __init__(self, optimizer: "hivemind.Optimizer", scheduler: _LRScheduler) -> None:
# copy most of the `Scheduler` methods into this instance. `__del__` is skipped in case the scheduler has
# implemented custom logic which we would not want to call on destruction of the `HiveMindScheduler`
self.__dict__ = {k: v for k, v in scheduler.__dict__.items() if k not in ("step", "__del__")}
self.optimizer = optimizer
self.scheduler = scheduler
self.current_step = -1
def step(self, epoch: Optional[int] = None) -> None:
while self.current_step < self.optimizer.local_epoch:
self.scheduler.step(epoch=epoch)
self.current_step += 1
def load_state_dict(self, state_dict: Dict) -> None:
self.scheduler.load_state_dict(state_dict)
def state_dict(self) -> Dict:
return self.scheduler.state_dict()
class DHTManager:
ENDPOINT_ENV: str = "PL_ENDPOINT"
PEER_ENDPOINT_ENV: str = "PL_PEER_ENDPOINT"
INITIAL_PEERS_ENV: str = "PL_INITIAL_PEERS"
HOST_ENV: str = "PL_HOST"
PORT_ENV: str = "PL_PORT"
DEFAULT_HOST: str = "0.0.0.0"
DEFAULT_PORT: int = 1440
def __init__(
self,
host_maddrs: Optional[List],
initial_peers: Optional[Union[str, List]],
persistent: bool,
endpoint: Optional[bool],
peer_endpoint: Optional[str],
host: Optional[str],
port: Optional[int],
retry_endpoint_attempts: int = 5,
retry_endpoint_sleep_duration: int = 5,
) -> None:
"""Manages the `hivemind.DHT` connection and provides a side-car endpoint server for initial peer access.
Arguments:
host_maddrs: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.host_maddrs`
initial_peers: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.initial_peers`
persistent: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.persistent`
endpoint: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.endpoint`
peer_endpoint: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.peer_endpoint`
host: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.host`
port: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.port`
retry_endpoint_attempts:
:paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.retry_endpoint_attempts`
retry_endpoint_sleep_duration:
:paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.retry_endpoint_sleep_duration`
"""
self._persistent = persistent
self._endpoint = endpoint
self._initial_peers = initial_peers
self._peer_endpoint = peer_endpoint
self._host = host
self._port = port
self._parse_env_vars()
if self._peer_endpoint and self._initial_peers is None:
self._initial_peers = self._get_initial_peers_from_endpoint(
retry_initial_peers=retry_endpoint_attempts, retry_peer_sleep_duration=retry_endpoint_sleep_duration
)
self.dht = hivemind.DHT(
start=True,
initial_peers=self._initial_peers,
host_maddrs=host_maddrs if host_maddrs is not None else ["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"],
)
visible_addresses = [
str(a) for a in self.dht.get_visible_maddrs() if not ipaddress.ip_address(a.values()[0]).is_loopback
]
if self._endpoint:
self._host = self._host if self._host is not None else self.DEFAULT_HOST
self._port = self._port if self._port is not None else self.DEFAULT_PORT
self._start_server_process(self._host, self._port)
self._log_endpoint_helper_message(visible_addresses)
elif self._peer_endpoint:
log.info("Machine received initial peers from endpoint.")
elif self._initial_peers is None:
log.info(
"\nOther machines can connect running the same command:\n"
f"INITIAL_PEERS={','.join(visible_addresses)} python ...\n"
"or passing the peers to the strategy:\n"
f"CollaborativeStrategy(initial_peers='{','.join(visible_addresses)}')"
)
def _log_endpoint_helper_message(self, visible_addresses: List[str]) -> None:
assert self._host is not None
resolved_host = self._host
if "0.0.0.0" in self._host:
# use the visible multi-addresses to figure out the IP that has been exposed
# todo (sean): this is pretty hacky, worth investigating.
p = re.compile(r"[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+")
# todo (sean): we select one address from here, could we have multiple?
resolved_host = {p.findall(maddr)[0] for maddr in visible_addresses}.pop()
log.info(
"\nSidecar endpoint enabled to serve peers.\n"
"Other peers can connect via:\n"
f"PEER_ENDPOINT={resolved_host}:{self._port} python ...\n"
"or pass the peer endpoint address to the strategy:\n"
f"CollaborativeStrategy(peer_endpoint='{resolved_host}:{self._port}')"
)
def _start_server_process(self, host: str, port: int) -> None:
dht = self.dht
class DHTHandler(BaseHTTPRequestHandler):
def do_GET(self) -> None:
"""Respond to a GET request."""
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
visible_peers = [
str(a) for a in dht.get_visible_maddrs() if not ipaddress.ip_address(a.values()[0]).is_loopback
]
self.wfile.write("\n".join(visible_peers).encode())
server = http.server.ThreadingHTTPServer((host, int(port)), DHTHandler)
thread = threading.Thread(target=server.serve_forever)
thread.daemon = True
thread.start()
def _get_initial_peers_from_endpoint(self, retry_initial_peers: int, retry_peer_sleep_duration: int) -> List:
peers = None
for _ in range(retry_initial_peers):
try:
peers = self._get_peers()
break
except requests.exceptions.RequestException:
log.info(f"Failed to get peers, retrying in {retry_peer_sleep_duration} seconds...")
time.sleep(retry_peer_sleep_duration)
if peers is None:
raise MisconfigurationException(
f"Unable to get peers. Tried {retry_initial_peers} times waiting {retry_peer_sleep_duration}s."
f"These parameters can be extended by passing "
"to the strategy (CollaborativeStrategy(retry_connection=x, retry_sleep_duration=y))."
)
log.info(f"Received initial peers from collaborative server: {peers}")
return peers
def _get_peers(self) -> List[str]:
assert self._peer_endpoint is not None
url = f"http://{self._peer_endpoint}" if not self._peer_endpoint.startswith("http://") else self._peer_endpoint
r = requests.get(url)
return r.text.split(",")
def _parse_env_vars(self) -> None:
endpoint = os.environ.get(self.ENDPOINT_ENV, self._endpoint)
self._endpoint = endpoint == "1" if isinstance(endpoint, str) else endpoint
self._peer_endpoint = os.environ.get(self.PEER_ENDPOINT_ENV, self._peer_endpoint)
initial_peers = os.environ.get(self.INITIAL_PEERS_ENV, self._initial_peers)
self._initial_peers = initial_peers.split(",") if isinstance(initial_peers, str) else initial_peers
port = os.environ.get(self.PORT_ENV, self._port)
self._port = int(port) if isinstance(port, str) else port
self._host = os.environ.get(self.HOST_ENV, self._host)
@property
def disable_logging_checkpointing(self) -> bool:
# if this node is a peer, we do not log/checkpoint in persistent mode.
return self._persistent and (self._initial_peers is not None or self._peer_endpoint is not None)
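
Taken together, the pieces above support a simple multi-peer workflow: the first peer starts training with no ``initial_peers`` and logs the visible multi-addresses other machines should use; every additional peer passes those addresses back in (directly, or via the ``PL_INITIAL_PEERS`` environment variable parsed by ``DHTManager``) and has logging/checkpointing disabled because it is not the persistent machine. A hedged sketch, with the peer address and ``LitModel`` as placeholders:

    import pytorch_lightning as pl
    from pytorch_lightning.strategies import CollaborativeStrategy

    # peer 1: starts the run and logs the multi-addresses other machines should connect to
    trainer = pl.Trainer(strategy=CollaborativeStrategy(target_batch_size=8192), max_epochs=1)

    # peer 2..N: joins the run using the addresses logged by peer 1
    trainer = pl.Trainer(
        strategy=CollaborativeStrategy(
            target_batch_size=8192,
            initial_peers=["/ip4/192.0.2.10/tcp/41000/p2p/QmPlaceholderPeerId"],  # placeholder address
        ),
        max_epochs=1,
    )
    trainer.fit(LitModel())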

View File

@@ -35,6 +35,7 @@ from pytorch_lightning.utilities.imports import ( # noqa: F401
_FAIRSCALE_FULLY_SHARDED_AVAILABLE,
_FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE,
_GROUP_AVAILABLE,
_HIVEMIND_AVAILABLE,
_HOROVOD_AVAILABLE,
_HPU_AVAILABLE,
_HYDRA_AVAILABLE,

View File

@@ -105,6 +105,7 @@ _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn")
_FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.3")
_FAIRSCALE_FULLY_SHARDED_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.4")
_GROUP_AVAILABLE = not _IS_WINDOWS and _module_available("torch.distributed.group")
_HIVEMIND_AVAILABLE = _package_available("hivemind")
_HOROVOD_AVAILABLE = _module_available("horovod.torch")
_HYDRA_AVAILABLE = _package_available("hydra")
_HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental")
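
The new ``_HIVEMIND_AVAILABLE`` flag follows the same guard pattern as the other availability checks in this module. A small sketch of how calling code can use it (the fallback message is illustrative only):

    from pytorch_lightning.utilities.imports import _HIVEMIND_AVAILABLE

    if _HIVEMIND_AVAILABLE:
        import hivemind  # only imported when the package is actually installed
    else:
        print("hivemind is not installed; install it with `pip install -U hivemind` on Linux")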

View File

@@ -68,6 +68,9 @@ class _LRScheduler(_Stateful, Protocol):
def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None:
...
def step(self, epoch: Optional[int] = None) -> None:
...
# Inferred from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
@@ -91,6 +94,9 @@ class ReduceLROnPlateau(_Stateful, Protocol):
) -> None:
...
def step(self, metrics: Union[float, int, torch.Tensor], epoch: Optional[int] = None) -> None:
...
# todo: improve LRSchedulerType naming/typing
LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
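
These protocol additions are what ``HiveMindScheduler`` relies on to wrap a standard PyTorch scheduler. A related usage note: because Hivemind re-creates the optimizer when ``offload_optimizer``/``delay_optimizer_step``/``delay_state_averaging`` are enabled, the scheduler must be supplied to the strategy as a factory via ``scheduler_fn`` rather than through ``configure_optimizers``. A sketch under that assumption, with hyperparameter values as placeholders:

    from functools import partial

    import torch
    from pytorch_lightning.strategies import CollaborativeStrategy

    strategy = CollaborativeStrategy(
        target_batch_size=8192,
        offload_optimizer=True,
        delay_optimizer_step=True,
        delay_grad_averaging=True,
        # called with the optimizer that Hivemind re-creates internally
        scheduler_fn=partial(torch.optim.lr_scheduler.ExponentialLR, gamma=0.9),
    )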

View File

@@ -1,3 +1,4 @@
fairscale>=0.4.5
deepspeed<0.6.0
horovod>=0.21.2,!=0.24.0 # no need to install with [pytorch] as pytorch is already installed
hivemind>=1.0.1; sys_platform == 'linux'

View File

@@ -26,6 +26,7 @@ from pytorch_lightning.utilities import (
_DEEPSPEED_AVAILABLE,
_FAIRSCALE_AVAILABLE,
_FAIRSCALE_FULLY_SHARDED_AVAILABLE,
_HIVEMIND_AVAILABLE,
_HOROVOD_AVAILABLE,
_HPU_AVAILABLE,
_IPU_AVAILABLE,
@@ -84,6 +85,7 @@ class RunIf:
omegaconf: bool = False,
slow: bool = False,
bagua: bool = False,
hivemind: bool = False,
**kwargs,
):
"""
@@ -111,6 +113,7 @@ class RunIf:
omegaconf: Require that omry/omegaconf is installed.
slow: Mark the test as slow, our CI will run it in a separate job.
bagua: Require that BaguaSys/bagua is installed.
hivemind: Require that Hivemind is installed.
**kwargs: Any :class:`pytest.mark.skipif` keyword arguments.
"""
conditions = []
@@ -231,6 +234,10 @@ class RunIf:
conditions.append(not _BAGUA_AVAILABLE or sys.platform in ("win32", "darwin"))
reasons.append("Bagua")
if hivemind:
conditions.append(not _HIVEMIND_AVAILABLE or sys.platform in ("win32", "darwin"))
reasons.append("Hivemind")
reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
return pytest.mark.skipif(
*args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs
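
With the flag wired into ``RunIf``, the tests below can opt in declaratively. A minimal sketch of the decorator in use (the test body is illustrative only):

    from tests.helpers.runif import RunIf


    @RunIf(hivemind=True)  # skipped on Windows/macOS or when hivemind is not installed
    def test_something_that_needs_hivemind():
        import hivemind  # safe: the decorator guarantees the package is available

        assert hivemind.DHT is not None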

View File

@@ -0,0 +1,372 @@
import multiprocessing as mp
import os
import time
from typing import Any
from unittest import mock
from unittest.mock import PropertyMock
import pytest
import requests
import torch
from torch.optim import Optimizer
import pytorch_lightning as pl
from pytorch_lightning.plugins.environments.lightning_environment import find_free_network_port
from pytorch_lightning.strategies import CollaborativeStrategy
from pytorch_lightning.strategies.collaborative import HiveMindScheduler
from pytorch_lightning.utilities import _HIVEMIND_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
if _HIVEMIND_AVAILABLE:
import hivemind
@mock.patch("pytorch_lightning.strategies.collaborative._HIVEMIND_AVAILABLE", False)
def test_raise_exception_if_hivemind_unavailable():
"""Test that we raise an exception when Hivemind is not available."""
with pytest.raises(MisconfigurationException, match="you must have Hivemind installed"):
CollaborativeStrategy(target_batch_size=1)
@RunIf(hivemind=True)
@mock.patch("hivemind.DHT", autospec=True)
def test_strategy(mock_dht):
strategy = CollaborativeStrategy(target_batch_size=1)
trainer = pl.Trainer(strategy=strategy)
assert trainer.strategy == strategy
@RunIf(hivemind=True)
@mock.patch("hivemind.DHT", autospec=True)
@mock.patch("pytorch_lightning.strategies.collaborative.DHTManager._get_peers", autospec=True)
@pytest.mark.parametrize(
"initial_peers,peer_endpoint",
[(["TEST"], None), (None, "localhost:153")],
)
def test_logging_disabled_when_second_peer(mock_dht, mock_http, initial_peers, peer_endpoint):
"""Test when we are a second peer (passing initial peers or peer endpoint) we warn the user that
logging/checkpointing will be disabled."""
with pytest.warns(UserWarning, match="This machine is not a persistent machine"):
CollaborativeStrategy(target_batch_size=1, initial_peers=initial_peers, peer_endpoint=peer_endpoint)
@RunIf(hivemind=True)
@mock.patch.dict(
os.environ,
{"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor", "PL_PORT": str(find_free_network_port())},
clear=True,
)
@pytest.mark.parametrize(
"endpoint,expected_message",
[(False, "INITIAL_PEERS"), (True, "Sidecar endpoint enabled to serve peers.")],
)
def test_initial_peer_message(caplog, endpoint, expected_message):
model = BoringModel()
trainer = pl.Trainer(strategy=CollaborativeStrategy(target_batch_size=1, endpoint=endpoint), fast_dev_run=True)
trainer.fit(model)
assert expected_message in caplog.text
@RunIf(hivemind=True)
@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)
def test_optimizer_wrapped():
class TestModel(BoringModel):
def on_before_backward(self, loss: torch.Tensor) -> None:
optimizer = self.trainer.optimizers[0]
assert isinstance(optimizer, hivemind.Optimizer)
model = TestModel()
trainer = pl.Trainer(strategy=CollaborativeStrategy(target_batch_size=1), fast_dev_run=True)
trainer.fit(model)
@RunIf(hivemind=True)
@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)
def test_scheduler_wrapped():
class TestModel(BoringModel):
def on_before_backward(self, loss: torch.Tensor) -> None:
scheduler = self.trainer.lr_scheduler_configs[0].scheduler
assert isinstance(scheduler, HiveMindScheduler)
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return [optimizer], [torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)]
model = TestModel()
trainer = pl.Trainer(
strategy=CollaborativeStrategy(target_batch_size=1),
fast_dev_run=True,
)
trainer.fit(model)
@RunIf(hivemind=True)
@mock.patch.dict(
os.environ,
{
"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor",
"PL_INITIAL_PEERS": "TEST_PEERS",
"PL_HOST": "TEST_HOST",
"PL_PORT": "1300",
"PL_ENDPOINT": "1",
"PL_PEER_ENDPOINT": "TEST_PEER_ENDPOINT",
},
clear=True,
)
@mock.patch("hivemind.DHT", autospec=True)
@mock.patch("pytorch_lightning.strategies.collaborative.DHTManager._get_peers", autospec=True)
@mock.patch("http.server.ThreadingHTTPServer", autospec=True)
def test_env_variables_parsed(mock_dht, mock_peers, mock_server):
"""Test that env variables are parsed correctly."""
strategy = CollaborativeStrategy(target_batch_size=1)
assert strategy.dht_manager._initial_peers == ["TEST_PEERS"]
assert strategy.dht_manager._host == "TEST_HOST"
assert strategy.dht_manager._port == 1300
assert strategy.dht_manager._endpoint
assert strategy.dht_manager._peer_endpoint == "TEST_PEER_ENDPOINT"
@RunIf(hivemind=True)
@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)
def test_reuse_grad_buffers_warning():
"""Test to ensure we warn when a user overrides `optimizer_zero_grad` and `reuse_grad_buffers` is True."""
class TestModel(BoringModel):
def on_before_backward(self, loss: torch.Tensor) -> None:
optimizer = self.trainer.optimizers[0]
assert isinstance(optimizer, hivemind.Optimizer)
def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
pass
model = TestModel()
trainer = pl.Trainer(
strategy=CollaborativeStrategy(target_batch_size=1, reuse_grad_buffers=True), fast_dev_run=True
)
with pytest.warns(UserWarning, match="You have overridden `optimizer_zero_grad` which will be disabled."):
trainer.fit(model)
@RunIf(hivemind=True)
def test_raise_exception_multiple_optimizers():
"""Test that we raise an exception when multiple optimizers are provided."""
class TestModel(BoringModel):
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer, optimizer], [lr_scheduler]
model = TestModel()
trainer = pl.Trainer(strategy=CollaborativeStrategy(target_batch_size=1), fast_dev_run=True)
with pytest.raises(MisconfigurationException, match="Hivemind only supports training with one optimizer."):
trainer.fit(model)
@RunIf(hivemind=True)
@mock.patch("pytorch_lightning.utilities.data._extract_batch_size", autospec=True, return_value=[None])
def test_raise_exception_no_batch_size(mock_extract_batch_size):
"""Test that we raise an exception when no batch size is automatically found."""
model = BoringModel()
trainer = pl.Trainer(strategy=CollaborativeStrategy(target_batch_size=1), fast_dev_run=True)
with pytest.raises(MisconfigurationException, match="Please provide the batch size to the Strategy."):
trainer.fit(model)
@RunIf(hivemind=True)
@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)
@pytest.mark.parametrize(
"delay_grad_averaging, delay_state_averaging, delay_optimizer_step",
[(True, True, True), (False, True, False)],
)
def test_warn_if_argument_passed(delay_grad_averaging, delay_state_averaging, delay_optimizer_step):
"""Test ensures that valid combination of HiveMind delay arguments warn if scheduler isn't passed in as a
function."""
model = BoringModel()
trainer = pl.Trainer(
strategy=CollaborativeStrategy(
target_batch_size=1,
delay_grad_averaging=delay_grad_averaging,
delay_state_averaging=delay_state_averaging,
delay_optimizer_step=delay_optimizer_step,
),
fast_dev_run=True,
)
with pytest.warns(UserWarning, match="requires a `scheduler_fn` to be passed to the strategy"):
trainer.fit(model)
@RunIf(hivemind=True)
@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)
@mock.patch("http.server.ThreadingHTTPServer", autospec=True)
@mock.patch("pytorch_lightning.strategies.collaborative.CollaborativeStrategy.num_peers", new_callable=PropertyMock)
def test_args_passed_to_optimizer(mock_peers, mock_server):
"""Test to ensure arguments are correctly passed to the hivemind optimizer wrapper."""
mock_peers.return_value = 1
compression = hivemind.ScaledFloat16Compression()
with mock.patch("hivemind.Optimizer", wraps=hivemind.Optimizer) as mock_optimizer:
class TestModel(BoringModel):
def on_before_backward(self, loss: torch.Tensor) -> None:
args, kwargs = mock_optimizer.call_args
mock_optimizer.assert_called()
arguments = dict(
delay_optimizer_step=True,
delay_state_averaging=True,
state_averaging_compression=compression,
grad_compression=compression,
offload_optimizer=True,
reuse_grad_buffers=True,
target_batch_size=1,
)
for key, value in arguments.items():
assert key in kwargs
assert value == kwargs[key]
model = TestModel()
trainer = pl.Trainer(
strategy=CollaborativeStrategy(
target_batch_size=1,
reuse_grad_buffers=True,
delay_state_averaging=True,
delay_optimizer_step=True,
offload_optimizer=True,
grad_compression=compression,
state_averaging_compression=compression,
),
fast_dev_run=True,
)
trainer.fit(model)
# ensures that after training with `reuse_grad_buffers` we restore the hook
assert model.optimizer_zero_grad is not None
@RunIf(hivemind=True)
@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)
@pytest.mark.parametrize(
"host_maddrs,expected_maddrs",
[(None, ["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"]), (["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"])],
)
def test_maddrs(host_maddrs, expected_maddrs):
"""Test that the multiple addresses are correctly assigned."""
strategy = CollaborativeStrategy(target_batch_size=1, host_maddrs=host_maddrs)
assert strategy.dht.kwargs["host_maddrs"] == expected_maddrs
def _run_collab_training_fn(initial_peers, wait_seconds, barrier, recorded_process_peers, recorded_process_steps):
recorded_peers = []
recorded_global_steps = []
class TestModel(BoringModel):
def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, unused: int = 0) -> None:
time.sleep(wait_seconds) # add an additional delay to give processes time to sync
recorded_peers.append(self.trainer.strategy.num_peers)
recorded_global_steps.append(self.trainer.optimizers[0].local_epoch)
def on_train_end(self) -> None:
# wait for all processes to get to the end of training before teardown
barrier.wait()
model = TestModel()
trainer = pl.Trainer(
max_epochs=1,
limit_train_batches=16,
limit_val_batches=0,
strategy=CollaborativeStrategy(
delay_state_averaging=True,
offload_optimizer=True,
delay_optimizer_step=True,
delay_grad_averaging=True,
target_batch_size=8,
initial_peers=initial_peers,
verbose=False,
),
)
trainer.fit(model)
recorded_process_peers.append(recorded_peers)
recorded_process_steps.append(recorded_global_steps)
@RunIf(hivemind=True)
@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)
@pytest.mark.parametrize(
"num_processes, wait_seconds",
[(2, 0.25)],
)
def test_multiple_peers(num_processes, wait_seconds):
"""Test to ensure that if we have two running processes with the same peers, they connect and train
successfully."""
dht_root = hivemind.DHT(start=True)
barrier = mp.Barrier(num_processes)
initial_peers = dht_root.get_visible_maddrs()
with mp.Manager() as manager:
# allows processes to return their recorded logged peers/steps
recorded_process_peers = manager.list()
recorded_process_steps = manager.list()
processes = [
mp.Process(
target=_run_collab_training_fn,
kwargs=dict(
initial_peers=initial_peers,
wait_seconds=wait_seconds,
barrier=barrier,
recorded_process_peers=recorded_process_peers,
recorded_process_steps=recorded_process_steps,
),
)
for x in range(num_processes)
]
for process in processes:
process.start()
for process in processes:
process.join()
# assert that peers increase as expected and we run at least 1 global step.
for process_peers, process_steps in zip(recorded_process_peers, recorded_process_steps):
assert any(num_peer == num_processes for num_peer in process_peers)
assert any(global_step > 0 for global_step in process_steps)
@RunIf(hivemind=True, min_gpus=1)
@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)
def test_scaler_updated_precision_16():
class TestModel(BoringModel):
def on_fit_start(self) -> None:
assert isinstance(self.trainer.precision_plugin.scaler, hivemind.GradScaler)
raise SystemExit
model = TestModel()
trainer = pl.Trainer(
strategy=CollaborativeStrategy(target_batch_size=1),
fast_dev_run=True,
precision=16,
accelerator="gpu",
devices=1,
)
with pytest.raises(SystemExit):
trainer.fit(model)
@RunIf(hivemind=True)
def test_raise_when_peer_endpoint_unsuccessful(caplog):
port = find_free_network_port()
with pytest.raises(MisconfigurationException, match="Unable to get peers"):
with mock.patch("requests.get", wraps=requests.get) as requests_mock:
CollaborativeStrategy(
target_batch_size=1,
peer_endpoint=f"localhost:{port}",
retry_endpoint_attempts=10,
retry_endpoint_sleep_duration=0,
)
assert "Failed to get peers, retrying" in caplog.text
assert requests_mock.call_count == 10
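
The final tests above exercise the retry path of the side-car server. For reference, a hedged sketch of the two-machine handshake that ``DHTManager`` implements, with hosts and ports as placeholders:

    from pytorch_lightning.strategies import CollaborativeStrategy

    # machine A: expose a small HTTP endpoint that returns this peer's visible multi-addresses
    strategy_a = CollaborativeStrategy(target_batch_size=8192, endpoint=True, host="0.0.0.0", port=1440)

    # machine B: fetch the initial peers from machine A's endpoint, retrying while it starts up
    strategy_b = CollaborativeStrategy(
        target_batch_size=8192,
        peer_endpoint="192.0.2.10:1440",  # placeholder address of machine A
        retry_endpoint_attempts=5,
        retry_endpoint_sleep_duration=5,
    )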