[1/2] Collaborative Strategy (#12842)
This commit is contained in:
parent
d337374da7
commit
1a502c061c
|
@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added support for `Trainer(deterministic="warn")` to warn instead of fail when a non-deterministic operation is encountered ([#12588](https://github.com/PyTorchLightning/pytorch-lightning/pull/12588))
|
||||
|
||||
|
||||
- Added `CollaborativeStrategy` ([#12842](https://github.com/PyTorchLightning/pytorch-lightning/pull/12842))
|
||||
|
||||
|
||||
- Include a version suffix for new "last" checkpoints of later runs in the same directory ([#12902](https://github.com/PyTorchLightning/pytorch-lightning/pull/12902))
|
||||
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pytorch_lightning.strategies.bagua import BaguaStrategy # noqa: F401
|
||||
from pytorch_lightning.strategies.collaborative import CollaborativeStrategy # noqa: F401
|
||||
from pytorch_lightning.strategies.ddp import DDPStrategy # noqa: F401
|
||||
from pytorch_lightning.strategies.ddp2 import DDP2Strategy # noqa: F401
|
||||
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy # noqa: F401
|
||||
|
|
|
@ -0,0 +1,529 @@
|
|||
import http
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from http.server import BaseHTTPRequestHandler
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.strategies.strategy import Strategy, TBroadcast
|
||||
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
|
||||
from pytorch_lightning.utilities.data import extract_batch_size
|
||||
from pytorch_lightning.utilities.enums import PrecisionType
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _HIVEMIND_AVAILABLE
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.types import _LRScheduler, ReduceLROnPlateau
|
||||
|
||||
if _HIVEMIND_AVAILABLE:
|
||||
import hivemind
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CollaborativeStrategy(Strategy):
|
||||
def __init__(
|
||||
self,
|
||||
target_batch_size: int,
|
||||
run_id: str = "lightning_run",
|
||||
batch_size: Optional[int] = None,
|
||||
delay_state_averaging: bool = False,
|
||||
delay_optimizer_step: Optional[bool] = None,
|
||||
delay_grad_averaging: bool = False,
|
||||
offload_optimizer: Optional[bool] = None,
|
||||
reuse_grad_buffers: bool = False,
|
||||
scheduler_fn: Optional[Callable] = None,
|
||||
matchmaking_time: float = 5.0,
|
||||
averaging_timeout: float = 30.0,
|
||||
verbose: bool = False,
|
||||
averager_opts: Optional[Dict] = None,
|
||||
host_maddrs: Optional[List] = None,
|
||||
initial_peers: Optional[Union[str, List]] = None,
|
||||
endpoint: Optional[bool] = None,
|
||||
peer_endpoint: Optional[str] = None,
|
||||
persistent: bool = True,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[int] = None,
|
||||
retry_endpoint_attempts: int = 5,
|
||||
retry_endpoint_sleep_duration: int = 5,
|
||||
**optimizer_kwargs: Any,
|
||||
):
|
||||
"""Provides capabilities to train using the Hivemind Library, training collaboratively across the internet
|
||||
with unreliable machines. For more information, `refer to the docs <https://pytorch-
|
||||
lightning.readthedocs.io/en/latest/strategies/collaborative_training.html>`__.
|
||||
|
||||
.. warning:: ``CollaborativeStrategy`` is experimental and subject to change.
|
||||
|
||||
Arguments:
|
||||
|
||||
target_batch_size: When training, the batch size to accumulate to before running a step. The larger this
|
||||
batch size, the more work can be done asynchronously without communication.
|
||||
|
||||
run_id: A unique identifier of this training run, used as a common prefix for all DHT keys.
|
||||
See ``https://learning-at-home.readthedocs.io/en/latest/user/dht.html``.
|
||||
|
||||
batch_size: The local batch size per process. If not provided, we infer this from the first batch of data
|
||||
passed in at training (lazy). Note that this should not change throughout training.
|
||||
|
||||
delay_state_averaging: If enabled (default), average parameters and extra tensors in a background thread;
|
||||
if set to False, average parameters synchronously within the
|
||||
corresponding :meth:`hivemind.Optimizer.step` call.
|
||||
|
||||
delay_optimizer_step: Run optimizer in background, apply results in future .step. requires
|
||||
:paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.offload_optimizer`.
|
||||
|
||||
delay_grad_averaging: Average gradients in background; requires
|
||||
:paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.offload_optimizer` and
|
||||
:paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.delay_optimizer_step`.
|
||||
|
||||
offload_optimizer: Offload the optimizer to host memory, saving GPU memory for parameters and gradients.
|
||||
|
||||
reuse_grad_buffers: Use the model's gradient buffers (params.grad) for gradient accumulation
|
||||
which is more memory efficient. Lightning will automatically disable ``zero_grad``
|
||||
in the ``LightningModule``.
|
||||
|
||||
scheduler_fn: callable(optimizer) -> PyTorch LRScheduler or a pre-initialized PyTorch scheduler.
|
||||
When using `offload_optimizer`/`delay_optimizer_step`/`delay_state_averaging` ``scheduler_fn``
|
||||
is required to be passed to the ``CollaborativeStrategy``. This is because the optimizer
|
||||
is re-created and the scheduler needs to be re-created as well.
|
||||
|
||||
matchmaking_time: When looking for group, wait for peers to join for up to this many seconds.
|
||||
Increase if you see "averaged gradients with N peers" where N is below 0.9x on >=25% of epochs.
|
||||
Training with low-latency network, decreasing matchmaking_time allows training with smaller batch sizes.
|
||||
|
||||
averaging_timeout: If an averaging step hangs for this long, it will be cancelled automatically.
|
||||
Increase averaging_timeout if you see "Proceeding with local gradients" at least 25% of the time.
|
||||
Do not set this timeout too high, as it may cause your optimizer to hang
|
||||
after some types of network errors.
|
||||
|
||||
verbose: Report internal Hivemind events such as accumulating gradients and running background tasks.
|
||||
|
||||
averager_opts: Additional keyword arguments forwarded to both
|
||||
``GradientAverager`` and ``TrainingStateAverager``.
|
||||
|
||||
host_maddrs: List of multi-addrs to create visible peers for other processes.
|
||||
`https://learning-at-home.readthedocs.io/en/latest/user/dht.html#running-across-the-internet`
|
||||
|
||||
initial_peers: If connecting to a running process, a list of initial peers needs to be passed in.
|
||||
This can also be set via the env variable ``INITIAL_PEERS``.
|
||||
|
||||
endpoint: Enable if a side-car endpoint server is required on the process to server initial peers.
|
||||
This is useful when using some form of orchestration such as torchelastic.
|
||||
|
||||
peer_endpoint: The endpoint to request initial peers from.
|
||||
|
||||
persistent: When using an endpoint, this controls whether other processes that are not the endpoint
|
||||
server log/checkpoint. If ``persistent`` is True, we do not log/checkpoint from other processes.
|
||||
|
||||
host: When creating the endpoint, the host IP to use.
|
||||
|
||||
port: When creating the endpoint, the host port to use.
|
||||
|
||||
retry_endpoint_attempts: When connecting to the
|
||||
:paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.peer_endpoint`,
|
||||
how many time to retry before raising an exception.
|
||||
|
||||
retry_endpoint_sleep_duration: When connecting to the
|
||||
:paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.peer_endpoint`,
|
||||
how long to wait between retries.
|
||||
|
||||
**optimizer_kwargs: kwargs are passed to the :class:`hivemind.Optimizer` class.
|
||||
"""
|
||||
if not _HIVEMIND_AVAILABLE or platform.system() != "Linux":
|
||||
raise MisconfigurationException(
|
||||
"To use the `CollaborativeStrategy`, you must have Hivemind installed and be running on Linux."
|
||||
" Install it by running `pip install -U hivemind`."
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
self.dht_manager = DHTManager(
|
||||
persistent=persistent,
|
||||
endpoint=endpoint,
|
||||
peer_endpoint=peer_endpoint,
|
||||
host=host,
|
||||
port=port,
|
||||
host_maddrs=host_maddrs,
|
||||
initial_peers=initial_peers,
|
||||
retry_endpoint_attempts=retry_endpoint_attempts,
|
||||
retry_endpoint_sleep_duration=retry_endpoint_sleep_duration,
|
||||
)
|
||||
self._target_batch_size = target_batch_size
|
||||
self._batch_size = batch_size
|
||||
self._scheduler_fn = scheduler_fn
|
||||
self._require_scheduler_fn = delay_optimizer_step or delay_state_averaging or offload_optimizer
|
||||
self._opt = None
|
||||
self._optimizer_zero_grad_original: Optional[Callable] = None
|
||||
self._run_id = run_id
|
||||
self._reuse_grad_buffers = reuse_grad_buffers
|
||||
self._optimizer_kwargs = dict(
|
||||
matchmaking_time=matchmaking_time,
|
||||
averaging_timeout=averaging_timeout,
|
||||
delay_optimizer_step=delay_optimizer_step,
|
||||
delay_state_averaging=delay_state_averaging,
|
||||
delay_grad_averaging=delay_grad_averaging,
|
||||
offload_optimizer=offload_optimizer,
|
||||
averager_opts=averager_opts if averaging_timeout is not None else dict(request_timeout=1.0),
|
||||
verbose=verbose,
|
||||
reuse_grad_buffers=reuse_grad_buffers,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
|
||||
# a bit of a hack to only log from the stable server
|
||||
if self.dht_manager.disable_logging_checkpointing:
|
||||
warnings.warn(
|
||||
"This machine is not a persistent machine. Checkpointing/Logging has been disabled.", UserWarning
|
||||
)
|
||||
rank_zero_only.rank = 1 if self.dht_manager.disable_logging_checkpointing else 0
|
||||
self._hivemind_initialized = False
|
||||
|
||||
@property
|
||||
def num_peers(self) -> int:
|
||||
if self._opt:
|
||||
return self._opt.tracker.global_progress.num_peers
|
||||
return 1
|
||||
|
||||
@property
|
||||
def dht(self) -> "hivemind.DHT":
|
||||
"""Hivemind Distributed Hash Table which stores values across all peers.
|
||||
|
||||
See documentation for more details: `https://learning-at-home.readthedocs.io/en/latest/modules/dht.html`
|
||||
"""
|
||||
return self.dht_manager.dht
|
||||
|
||||
@property
|
||||
def root_device(self) -> torch.device:
|
||||
from pytorch_lightning.accelerators.cpu import CPUAccelerator
|
||||
from pytorch_lightning.accelerators.gpu import GPUAccelerator
|
||||
|
||||
if isinstance(self.accelerator, GPUAccelerator):
|
||||
return torch.device(f"cuda:{torch.cuda.current_device()}")
|
||||
elif isinstance(self.accelerator, CPUAccelerator):
|
||||
return torch.device("cpu")
|
||||
raise MisconfigurationException(
|
||||
f"Was unable to infer device type from the accelerator: {self.accelerator.__class__.__name__}."
|
||||
)
|
||||
|
||||
@property
|
||||
def global_rank(self) -> int:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def is_global_zero(self) -> bool:
|
||||
return True
|
||||
|
||||
def setup(self, trainer: "pl.Trainer") -> None:
|
||||
self.model_to_device()
|
||||
super().setup(trainer)
|
||||
if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED):
|
||||
self.precision_plugin.scaler = hivemind.GradScaler()
|
||||
|
||||
def _initialize_hivemind(self) -> None:
|
||||
if len(self.optimizers) > 1:
|
||||
raise MisconfigurationException("Hivemind only supports training with one optimizer.")
|
||||
optimizer = self.optimizers[0]
|
||||
|
||||
if self._require_scheduler_fn and self._scheduler_fn is None:
|
||||
rank_zero_warn(
|
||||
"Enabling `delay_optimizer_step`, `delay_state_averaging` or `offload_optimizer` "
|
||||
"requires a `scheduler_fn` to be passed to the strategy if a scheduler is being used "
|
||||
"(this is because the optimizer is re-created within Hivemind)."
|
||||
)
|
||||
|
||||
scheduler = self._scheduler_fn if self._require_scheduler_fn else None
|
||||
params = optimizer.param_groups if self._require_scheduler_fn else None
|
||||
optimizer = type(optimizer) if self._require_scheduler_fn else optimizer
|
||||
opt = hivemind.Optimizer(
|
||||
dht=self.dht,
|
||||
run_id=self._run_id,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
target_batch_size=self._target_batch_size,
|
||||
batch_size_per_step=self._batch_size,
|
||||
**self._optimizer_kwargs,
|
||||
)
|
||||
|
||||
if not self._scheduler_fn:
|
||||
self._wrap_schedulers(opt)
|
||||
opt.load_state_from_peers()
|
||||
self.optimizers = [opt]
|
||||
self._opt = opt
|
||||
|
||||
if self._reuse_grad_buffers:
|
||||
assert self.lightning_module is not None
|
||||
self._optimizer_zero_grad_original = self.lightning_module.optimizer_zero_grad
|
||||
self._disable_zero_grad()
|
||||
|
||||
def _disable_zero_grad(self) -> None:
|
||||
lightning_module = self.lightning_module
|
||||
if is_overridden("optimizer_zero_grad", lightning_module):
|
||||
assert lightning_module is not None # `is_overridden` returns False otherwise
|
||||
rank_zero_warn(
|
||||
"You have overridden `optimizer_zero_grad` which will be disabled."
|
||||
" When `CollaborativeStrategy(reuse_grad_buffers=True)`, the optimizer cannot call zero grad,"
|
||||
" as this would delete the gradients before they are averaged."
|
||||
)
|
||||
assert lightning_module is not None
|
||||
lightning_module.optimizer_zero_grad = None # type: ignore[assignment]
|
||||
|
||||
def _wrap_schedulers(self, opt: "hivemind.Optimizer") -> None:
|
||||
# wrap schedulers so that they only update when the hivemind optimizer updates
|
||||
for scheduler_config in self.lr_scheduler_configs:
|
||||
scheduler = scheduler_config.scheduler
|
||||
if isinstance(scheduler, ReduceLROnPlateau):
|
||||
raise ValueError(
|
||||
f"The `ReduceLROnPlateau` scheduler is not currently supported with `{self.__class__.__name__}`."
|
||||
)
|
||||
scheduler_config.scheduler = HiveMindScheduler(
|
||||
optimizer=opt,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
|
||||
if not self._hivemind_initialized:
|
||||
self._hivemind_initialized = True
|
||||
# todo (sean): we could technically support a dynamic batch size by inferring each step
|
||||
# and passing it to the ``hivemind.Optimizer``.
|
||||
if self._batch_size is None:
|
||||
try:
|
||||
self._batch_size = extract_batch_size(batch)
|
||||
log.info(f"Found per machine batch size automatically from the batch: {self._batch_size}")
|
||||
except (MisconfigurationException, RecursionError) as e:
|
||||
raise MisconfigurationException(
|
||||
"We tried to infer the batch size from the first batch of data. "
|
||||
"Please provide the batch size to the Strategy by "
|
||||
"``Trainer(strategy=CollaborativeStrategy(batch_size=x))``. "
|
||||
) from e
|
||||
self._initialize_hivemind()
|
||||
|
||||
def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
|
||||
return tensor
|
||||
|
||||
def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
|
||||
return tensor
|
||||
|
||||
def model_to_device(self) -> None:
|
||||
assert self.model is not None
|
||||
self.model.to(self.root_device)
|
||||
|
||||
def barrier(self, *args: Any, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
|
||||
return obj
|
||||
|
||||
def teardown(self) -> None:
|
||||
super().teardown()
|
||||
|
||||
if self._optimizer_zero_grad_original is not None and self.lightning_module is not None:
|
||||
# re-enable `optimizer_zero_grad`
|
||||
self.lightning_module.optimizer_zero_grad = self._optimizer_zero_grad_original # type: ignore[assignment]
|
||||
|
||||
if self.root_device.type == "cuda" and self.lightning_module is not None:
|
||||
# GPU teardown
|
||||
self.lightning_module.cpu()
|
||||
# clean up memory
|
||||
torch.cuda.empty_cache()
|
||||
if self._opt:
|
||||
self._opt.shutdown()
|
||||
log.info("Shutting down hivemind DHT.")
|
||||
self.dht.shutdown()
|
||||
|
||||
|
||||
class HiveMindScheduler:
|
||||
"""Wrapper for schedulers to prevent Lightning from stepping the scheduler too soon.
|
||||
|
||||
This code ensures that we only step when the HiveMind optimizer reaches the global step.
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer: "hivemind.Optimizer", scheduler: _LRScheduler) -> None:
|
||||
# copy most of the `Scheduler` methods into this instance. `__del__` is skipped in case the scheduler has
|
||||
# implemented custom logic which we would not want to call on destruction of the `HiveMindScheduler`
|
||||
self.__dict__ = {k: v for k, v in scheduler.__dict__.items() if k not in ("step", "__del__")}
|
||||
|
||||
self.optimizer = optimizer
|
||||
self.scheduler = scheduler
|
||||
self.current_step = -1
|
||||
|
||||
def step(self, epoch: Optional[int] = None) -> None:
|
||||
while self.current_step < self.optimizer.local_epoch:
|
||||
self.scheduler.step(epoch=epoch)
|
||||
self.current_step += 1
|
||||
|
||||
def load_state_dict(self, state_dict: Dict) -> None:
|
||||
self.scheduler.load_state_dict(state_dict)
|
||||
|
||||
def state_dict(self) -> Dict:
|
||||
return self.scheduler.state_dict()
|
||||
|
||||
|
||||
class DHTManager:
|
||||
ENDPOINT_ENV: str = "PL_ENDPOINT"
|
||||
PEER_ENDPOINT_ENV: str = "PL_PEER_ENDPOINT"
|
||||
INITIAL_PEERS_ENV: str = "PL_INITIAL_PEERS"
|
||||
HOST_ENV: str = "PL_HOST"
|
||||
PORT_ENV: str = "PL_PORT"
|
||||
DEFAULT_HOST: str = "0.0.0.0"
|
||||
DEFAULT_PORT: int = 1440
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host_maddrs: Optional[List],
|
||||
initial_peers: Optional[Union[str, List]],
|
||||
persistent: bool,
|
||||
endpoint: Optional[bool],
|
||||
peer_endpoint: Optional[str],
|
||||
host: Optional[str],
|
||||
port: Optional[int],
|
||||
retry_endpoint_attempts: int = 5,
|
||||
retry_endpoint_sleep_duration: int = 5,
|
||||
) -> None:
|
||||
"""Manages the `hivemind.DHT` connection and provides a side-car endpoint server for initial peer access.
|
||||
|
||||
Arguments:
|
||||
|
||||
host_maddrs: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.host_maddrs`
|
||||
|
||||
initial_peers: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.initial_peers`
|
||||
|
||||
persistent: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.persistent`
|
||||
|
||||
endpoint: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.endpoint`
|
||||
|
||||
peer_endpoint: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.peer_endpoint`
|
||||
|
||||
host: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.host`
|
||||
|
||||
port: :paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.port`
|
||||
|
||||
retry_endpoint_attempts:
|
||||
:paramref:`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.retry_endpoint_attempts`
|
||||
|
||||
retry_endpoint_sleep_duration:
|
||||
:paramref:
|
||||
`~pytorch_lightning.strategies.collaborative.CollaborativeStrategy.retry_endpoint_sleep_duration`
|
||||
"""
|
||||
self._persistent = persistent
|
||||
self._endpoint = endpoint
|
||||
self._initial_peers = initial_peers
|
||||
self._peer_endpoint = peer_endpoint
|
||||
self._host = host
|
||||
self._port = port
|
||||
|
||||
self._parse_env_vars()
|
||||
|
||||
if self._peer_endpoint and self._initial_peers is None:
|
||||
self._initial_peers = self._get_initial_peers_from_endpoint(
|
||||
retry_initial_peers=retry_endpoint_attempts, retry_peer_sleep_duration=retry_endpoint_sleep_duration
|
||||
)
|
||||
|
||||
self.dht = hivemind.DHT(
|
||||
start=True,
|
||||
initial_peers=self._initial_peers,
|
||||
host_maddrs=host_maddrs if host_maddrs is not None else ["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"],
|
||||
)
|
||||
|
||||
visible_addresses = [
|
||||
str(a) for a in self.dht.get_visible_maddrs() if not ipaddress.ip_address(a.values()[0]).is_loopback
|
||||
]
|
||||
|
||||
if self._endpoint:
|
||||
self._host = self._host if self._host is not None else self.DEFAULT_HOST
|
||||
self._port = self._port if self._port is not None else self.DEFAULT_PORT
|
||||
self._start_server_process(self._host, self._port)
|
||||
self._log_endpoint_helper_message(visible_addresses)
|
||||
elif self._peer_endpoint:
|
||||
log.info("Machine received initial peers from endpoint.")
|
||||
elif self._initial_peers is None:
|
||||
log.info(
|
||||
"\nOther machines can connect running the same command:\n"
|
||||
f"INITIAL_PEERS={','.join(visible_addresses)} python ...\n"
|
||||
"or passing the peers to the strategy:\n"
|
||||
f"CollaborativeStrategy(initial_peers='{','.join(visible_addresses)}')"
|
||||
)
|
||||
|
||||
def _log_endpoint_helper_message(self, visible_addresses: List[str]) -> None:
|
||||
assert self._host is not None
|
||||
resolved_host = self._host
|
||||
if "0.0.0.0" in self._host:
|
||||
# use the visible multi-addresses to figure out the IP that has been exposed
|
||||
# todo (sean): this is pretty hacky, worth investigating.
|
||||
p = re.compile(r"[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+")
|
||||
# todo (sean): we select one address from here, could we have multiple?
|
||||
resolved_host = {p.findall(maddr)[0] for maddr in visible_addresses}.pop()
|
||||
log.info(
|
||||
"\nSidecar endpoint enabled to serve peers.\n"
|
||||
"Other peers can connect via:\n"
|
||||
f"PEER_ENDPOINT={resolved_host}:{self._port} python ...\n"
|
||||
"or pass the peer endpoint address to the strategy:\n"
|
||||
f"CollaborativeStrategy(peer_endpoint='{resolved_host}:{self._port}')"
|
||||
)
|
||||
|
||||
def _start_server_process(self, host: str, port: int) -> None:
|
||||
dht = self.dht
|
||||
|
||||
class DHTHandler(BaseHTTPRequestHandler):
|
||||
def do_GET(self) -> None:
|
||||
"""Respond to a GET request."""
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
|
||||
visible_peers = [
|
||||
str(a) for a in dht.get_visible_maddrs() if not ipaddress.ip_address(a.values()[0]).is_loopback
|
||||
]
|
||||
|
||||
self.wfile.write("\n".join(visible_peers).encode())
|
||||
|
||||
server = http.server.ThreadingHTTPServer((host, int(port)), DHTHandler)
|
||||
thread = threading.Thread(target=server.serve_forever)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
def _get_initial_peers_from_endpoint(self, retry_initial_peers: int, retry_peer_sleep_duration: int) -> List:
|
||||
peers = None
|
||||
for _ in range(retry_initial_peers):
|
||||
try:
|
||||
peers = self._get_peers()
|
||||
break
|
||||
except requests.exceptions.RequestException:
|
||||
log.info(f"Failed to get peers, retrying in {retry_peer_sleep_duration} seconds...")
|
||||
time.sleep(retry_peer_sleep_duration)
|
||||
if peers is None:
|
||||
raise MisconfigurationException(
|
||||
f"Unable to get peers. Tried {retry_initial_peers} times waiting {retry_peer_sleep_duration}s."
|
||||
f"These parameters can be extended by passing "
|
||||
"to the strategy (CollaborativeStrategy(retry_connection=x, retry_sleep_duration=y))."
|
||||
)
|
||||
log.info(f"Received initial peers from collaborative server: {peers}")
|
||||
return peers
|
||||
|
||||
def _get_peers(self) -> List[str]:
|
||||
assert self._peer_endpoint is not None
|
||||
url = f"http://{self._peer_endpoint}" if not self._peer_endpoint.startswith("http://") else self._peer_endpoint
|
||||
r = requests.get(url)
|
||||
return r.text.split(",")
|
||||
|
||||
def _parse_env_vars(self) -> None:
|
||||
endpoint = os.environ.get(self.ENDPOINT_ENV, self._endpoint)
|
||||
self._endpoint = endpoint == "1" if isinstance(endpoint, str) else endpoint
|
||||
self._peer_endpoint = os.environ.get(self.PEER_ENDPOINT_ENV, self._peer_endpoint)
|
||||
initial_peers = os.environ.get(self.INITIAL_PEERS_ENV, self._initial_peers)
|
||||
self._initial_peers = initial_peers.split(",") if isinstance(initial_peers, str) else initial_peers
|
||||
|
||||
port = os.environ.get(self.PORT_ENV, self._port)
|
||||
self._port = int(port) if isinstance(port, str) else port
|
||||
self._host = os.environ.get(self.HOST_ENV, self._host)
|
||||
|
||||
@property
|
||||
def disable_logging_checkpointing(self) -> bool:
|
||||
# if this node is a peer, we do not log/checkpoint in persistent mode.
|
||||
return self._persistent and (self._initial_peers is not None or self._peer_endpoint is not None)
|
|
@ -35,6 +35,7 @@ from pytorch_lightning.utilities.imports import ( # noqa: F401
|
|||
_FAIRSCALE_FULLY_SHARDED_AVAILABLE,
|
||||
_FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE,
|
||||
_GROUP_AVAILABLE,
|
||||
_HIVEMIND_AVAILABLE,
|
||||
_HOROVOD_AVAILABLE,
|
||||
_HPU_AVAILABLE,
|
||||
_HYDRA_AVAILABLE,
|
||||
|
|
|
@ -105,6 +105,7 @@ _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn")
|
|||
_FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.3")
|
||||
_FAIRSCALE_FULLY_SHARDED_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.4")
|
||||
_GROUP_AVAILABLE = not _IS_WINDOWS and _module_available("torch.distributed.group")
|
||||
_HIVEMIND_AVAILABLE = _package_available("hivemind")
|
||||
_HOROVOD_AVAILABLE = _module_available("horovod.torch")
|
||||
_HYDRA_AVAILABLE = _package_available("hydra")
|
||||
_HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental")
|
||||
|
|
|
@ -68,6 +68,9 @@ class _LRScheduler(_Stateful, Protocol):
|
|||
def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None:
|
||||
...
|
||||
|
||||
def step(self, epoch: Optional[int] = None) -> None:
|
||||
...
|
||||
|
||||
|
||||
# Inferred from `torch.optim.lr_scheduler.pyi`
|
||||
# Missing attributes were added to improve typing
|
||||
|
@ -91,6 +94,9 @@ class ReduceLROnPlateau(_Stateful, Protocol):
|
|||
) -> None:
|
||||
...
|
||||
|
||||
def step(self, metrics: Union[float, int, torch.Tensor], epoch: Optional[int] = None) -> None:
|
||||
...
|
||||
|
||||
|
||||
# todo: improve LRSchedulerType naming/typing
|
||||
LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
fairscale>=0.4.5
|
||||
deepspeed<0.6.0
|
||||
horovod>=0.21.2,!=0.24.0 # no need to install with [pytorch] as pytorch is already installed
|
||||
hivemind>=1.0.1; sys_platform == 'linux'
|
||||
|
|
|
@ -26,6 +26,7 @@ from pytorch_lightning.utilities import (
|
|||
_DEEPSPEED_AVAILABLE,
|
||||
_FAIRSCALE_AVAILABLE,
|
||||
_FAIRSCALE_FULLY_SHARDED_AVAILABLE,
|
||||
_HIVEMIND_AVAILABLE,
|
||||
_HOROVOD_AVAILABLE,
|
||||
_HPU_AVAILABLE,
|
||||
_IPU_AVAILABLE,
|
||||
|
@ -84,6 +85,7 @@ class RunIf:
|
|||
omegaconf: bool = False,
|
||||
slow: bool = False,
|
||||
bagua: bool = False,
|
||||
hivemind: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
@ -111,6 +113,7 @@ class RunIf:
|
|||
omegaconf: Require that omry/omegaconf is installed.
|
||||
slow: Mark the test as slow, our CI will run it in a separate job.
|
||||
bagua: Require that BaguaSys/bagua is installed.
|
||||
hivemind: Require that Hivemind is installed.
|
||||
**kwargs: Any :class:`pytest.mark.skipif` keyword arguments.
|
||||
"""
|
||||
conditions = []
|
||||
|
@ -231,6 +234,10 @@ class RunIf:
|
|||
conditions.append(not _BAGUA_AVAILABLE or sys.platform in ("win32", "darwin"))
|
||||
reasons.append("Bagua")
|
||||
|
||||
if hivemind:
|
||||
conditions.append(not _HIVEMIND_AVAILABLE or sys.platform in ("win32", "darwin"))
|
||||
reasons.append("Hivemind")
|
||||
|
||||
reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
|
||||
return pytest.mark.skipif(
|
||||
*args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs
|
||||
|
|
|
@ -0,0 +1,372 @@
|
|||
import multiprocessing as mp
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest import mock
|
||||
from unittest.mock import PropertyMock
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.plugins.environments.lightning_environment import find_free_network_port
|
||||
from pytorch_lightning.strategies import CollaborativeStrategy
|
||||
from pytorch_lightning.strategies.collaborative import HiveMindScheduler
|
||||
from pytorch_lightning.utilities import _HIVEMIND_AVAILABLE
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
||||
from tests.helpers import BoringModel
|
||||
from tests.helpers.runif import RunIf
|
||||
|
||||
if _HIVEMIND_AVAILABLE:
|
||||
import hivemind
|
||||
|
||||
|
||||
@mock.patch("pytorch_lightning.strategies.collaborative._HIVEMIND_AVAILABLE", False)
|
||||
def test_raise_exception_if_hivemind_unavailable():
|
||||
"""Test that we raise an exception when Hivemind is not available."""
|
||||
with pytest.raises(MisconfigurationException, match="you must have Hivemind installed"):
|
||||
CollaborativeStrategy(target_batch_size=1)
|
||||
|
||||
|
||||
@RunIf(hivemind=True)
|
||||
@mock.patch("hivemind.DHT", autospec=True)
|
||||
def test_strategy(mock_dht):
|
||||
strategy = CollaborativeStrategy(target_batch_size=1)
|
||||
trainer = pl.Trainer(strategy=strategy)
|
||||
assert trainer.strategy == strategy
|
||||
|
||||
|
||||
@RunIf(hivemind=True)
|
||||
@mock.patch("hivemind.DHT", autospec=True)
|
||||
@mock.patch("pytorch_lightning.strategies.collaborative.DHTManager._get_peers", autospec=True)
|
||||
@pytest.mark.parametrize(
|
||||
"initial_peers,peer_endpoint",
|
||||
[(["TEST"], None), (None, "localhost:153")],
|
||||
)
|
||||
def test_logging_disabled_when_second_peer(mock_dht, mock_http, initial_peers, peer_endpoint):
|
||||
"""Test when we are a second peer (passing initial peers or peer endpoint) we warn the user that
|
||||
logging/checkpointing will be disabled."""
|
||||
with pytest.warns(UserWarning, match="This machine is not a persistent machine"):
|
||||
CollaborativeStrategy(target_batch_size=1, initial_peers=initial_peers, peer_endpoint=peer_endpoint)
|
||||
|
||||
|
||||
@RunIf(hivemind=True)
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor", "PL_PORT": str(find_free_network_port())},
|
||||
clear=True,
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"endpoint,expected_message",
|
||||
[(False, "INITIAL_PEERS"), (True, "Sidecar endpoint enabled to serve peers.")],
|
||||
)
|
||||
def test_initial_peer_message(caplog, endpoint, expected_message):
|
||||
model = BoringModel()
|
||||
trainer = pl.Trainer(strategy=CollaborativeStrategy(target_batch_size=1, endpoint=endpoint), fast_dev_run=True)
|
||||
trainer.fit(model)
|
||||
assert expected_message in caplog.text
|
||||
|
||||
|
||||
@RunIf(hivemind=True)
|
||||
@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)
|
||||
def test_optimizer_wrapped():
|
||||
class TestModel(BoringModel):
|
||||
def on_before_backward(self, loss: torch.Tensor) -> None:
|
||||
optimizer = self.trainer.optimizers[0]
|
||||
assert isinstance(optimizer, hivemind.Optimizer)
|
||||
|
||||
model = TestModel()
|
||||
trainer = pl.Trainer(strategy=CollaborativeStrategy(target_batch_size=1), fast_dev_run=True)
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
@RunIf(hivemind=True)
|
||||
@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)
|
||||
def test_scheduler_wrapped():
|
||||
class TestModel(BoringModel):
|
||||
def on_before_backward(self, loss: torch.Tensor) -> None:
|
||||
scheduler = self.trainer.lr_scheduler_configs[0].scheduler
|
||||
assert isinstance(scheduler, HiveMindScheduler)
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
|
||||
return [optimizer], [torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)]
|
||||
|
||||
model = TestModel()
|
||||
trainer = pl.Trainer(
|
||||
strategy=CollaborativeStrategy(target_batch_size=1),
|
||||
fast_dev_run=True,
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
@RunIf(hivemind=True)
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor",
|
||||
"PL_INITIAL_PEERS": "TEST_PEERS",
|
||||
"PL_HOST": "TEST_HOST",
|
||||
"PL_PORT": "1300",
|
||||
"PL_ENDPOINT": "1",
|
||||
"PL_PEER_ENDPOINT": "TEST_PEER_ENDPOINT",
|
||||
},
|
||||
clear=True,
|
||||
)
|
||||
@mock.patch("hivemind.DHT", autospec=True)
|
||||
@mock.patch("pytorch_lightning.strategies.collaborative.DHTManager._get_peers", autospec=True)
|
||||
@mock.patch("http.server.ThreadingHTTPServer", autospec=True)
|
||||
def test_env_variables_parsed(mock_dht, mock_peers, mock_server):
|
||||
"""Test that env variables are parsed correctly."""
|
||||
strategy = CollaborativeStrategy(target_batch_size=1)
|
||||
assert strategy.dht_manager._initial_peers == ["TEST_PEERS"]
|
||||
assert strategy.dht_manager._host == "TEST_HOST"
|
||||
assert strategy.dht_manager._port == 1300
|
||||
assert strategy.dht_manager._endpoint
|
||||
assert strategy.dht_manager._peer_endpoint == "TEST_PEER_ENDPOINT"
|
||||
|
||||
|
||||
@RunIf(hivemind=True)
|
||||
@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)
|
||||
def test_reuse_grad_buffers_warning():
|
||||
"""Test to ensure we warn when a user overrides `optimizer_zero_grad` and `reuse_grad_buffers` is True."""
|
||||
|
||||
class TestModel(BoringModel):
|
||||
def on_before_backward(self, loss: torch.Tensor) -> None:
|
||||
optimizer = self.trainer.optimizers[0]
|
||||
assert isinstance(optimizer, hivemind.Optimizer)
|
||||
|
||||
def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
|
||||
pass
|
||||
|
||||
model = TestModel()
|
||||
trainer = pl.Trainer(
|
||||
strategy=CollaborativeStrategy(target_batch_size=1, reuse_grad_buffers=True), fast_dev_run=True
|
||||
)
|
||||
|
||||
with pytest.warns(UserWarning, match="You have overridden `optimizer_zero_grad` which will be disabled."):
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
@RunIf(hivemind=True)
|
||||
def test_raise_exception_multiple_optimizers():
|
||||
"""Test that we raise an exception when multiple optimizers are provided."""
|
||||
|
||||
class TestModel(BoringModel):
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
|
||||
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
|
||||
return [optimizer, optimizer], [lr_scheduler]
|
||||
|
||||
model = TestModel()
|
||||
trainer = pl.Trainer(strategy=CollaborativeStrategy(target_batch_size=1), fast_dev_run=True)
|
||||
|
||||
with pytest.raises(MisconfigurationException, match="Hivemind only supports training with one optimizer."):
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
@RunIf(hivemind=True)
|
||||
@mock.patch("pytorch_lightning.utilities.data._extract_batch_size", autospec=True, return_value=[None])
|
||||
def test_raise_exception_no_batch_size(mock_extract_batch_size):
|
||||
"""Test that we raise an exception when no batch size is automatically found."""
|
||||
|
||||
model = BoringModel()
|
||||
trainer = pl.Trainer(strategy=CollaborativeStrategy(target_batch_size=1), fast_dev_run=True)
|
||||
|
||||
with pytest.raises(MisconfigurationException, match="Please provide the batch size to the Strategy."):
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
@RunIf(hivemind=True)
|
||||
@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)
|
||||
@pytest.mark.parametrize(
|
||||
"delay_grad_averaging, delay_state_averaging, delay_optimizer_step",
|
||||
[(True, True, True), (False, True, False)],
|
||||
)
|
||||
def test_warn_if_argument_passed(delay_grad_averaging, delay_state_averaging, delay_optimizer_step):
|
||||
"""Test ensures that valid combination of HiveMind delay arguments warn if scheduler isn't passed in as a
|
||||
function."""
|
||||
model = BoringModel()
|
||||
trainer = pl.Trainer(
|
||||
strategy=CollaborativeStrategy(
|
||||
target_batch_size=1,
|
||||
delay_grad_averaging=delay_grad_averaging,
|
||||
delay_state_averaging=delay_state_averaging,
|
||||
delay_optimizer_step=delay_optimizer_step,
|
||||
),
|
||||
fast_dev_run=True,
|
||||
)
|
||||
|
||||
with pytest.warns(UserWarning, match="requires a `scheduler_fn` to be passed to the strategy"):
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
@RunIf(hivemind=True)
|
||||
@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)
|
||||
@mock.patch("http.server.ThreadingHTTPServer", autospec=True)
|
||||
@mock.patch("pytorch_lightning.strategies.collaborative.CollaborativeStrategy.num_peers", new_callable=PropertyMock)
|
||||
def test_args_passed_to_optimizer(mock_peers, mock_server):
|
||||
"""Test to ensure arguments are correctly passed to the hivemind optimizer wrapper."""
|
||||
mock_peers.return_value = 1
|
||||
compression = hivemind.ScaledFloat16Compression()
|
||||
with mock.patch("hivemind.Optimizer", wraps=hivemind.Optimizer) as mock_optimizer:
|
||||
|
||||
class TestModel(BoringModel):
|
||||
def on_before_backward(self, loss: torch.Tensor) -> None:
|
||||
args, kwargs = mock_optimizer.call_args
|
||||
mock_optimizer.assert_called()
|
||||
arguments = dict(
|
||||
delay_optimizer_step=True,
|
||||
delay_state_averaging=True,
|
||||
state_averaging_compression=compression,
|
||||
grad_compression=compression,
|
||||
offload_optimizer=True,
|
||||
reuse_grad_buffers=True,
|
||||
target_batch_size=1,
|
||||
)
|
||||
|
||||
for key, value in arguments.items():
|
||||
assert key in kwargs
|
||||
assert value == kwargs[key]
|
||||
|
||||
model = TestModel()
|
||||
trainer = pl.Trainer(
|
||||
strategy=CollaborativeStrategy(
|
||||
target_batch_size=1,
|
||||
reuse_grad_buffers=True,
|
||||
delay_state_averaging=True,
|
||||
delay_optimizer_step=True,
|
||||
offload_optimizer=True,
|
||||
grad_compression=compression,
|
||||
state_averaging_compression=compression,
|
||||
),
|
||||
fast_dev_run=True,
|
||||
)
|
||||
trainer.fit(model)
|
||||
# ensures that after training with `reuse_grad_buffers` we restore the hook
|
||||
assert model.optimizer_zero_grad is not None
|
||||
|
||||
|
||||
@RunIf(hivemind=True)
|
||||
@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)
|
||||
@pytest.mark.parametrize(
|
||||
"host_maddrs,expected_maddrs",
|
||||
[(None, ["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"]), (["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"])],
|
||||
)
|
||||
def test_maddrs(host_maddrs, expected_maddrs):
|
||||
"""Test that the multiple addresses are correctly assigned."""
|
||||
strategy = CollaborativeStrategy(target_batch_size=1, host_maddrs=host_maddrs)
|
||||
assert strategy.dht.kwargs["host_maddrs"] == expected_maddrs
|
||||
|
||||
|
||||
def _run_collab_training_fn(initial_peers, wait_seconds, barrier, recorded_process_peers, recorded_process_steps):
|
||||
recorded_peers = []
|
||||
recorded_global_steps = []
|
||||
|
||||
class TestModel(BoringModel):
|
||||
def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, unused: int = 0) -> None:
|
||||
time.sleep(wait_seconds) # add an additional delay to give processes time to sync
|
||||
recorded_peers.append(self.trainer.strategy.num_peers)
|
||||
recorded_global_steps.append(self.trainer.optimizers[0].local_epoch)
|
||||
|
||||
def on_train_end(self) -> None:
|
||||
# wait for all processes to get to the end of training before teardown
|
||||
barrier.wait()
|
||||
|
||||
model = TestModel()
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=1,
|
||||
limit_train_batches=16,
|
||||
limit_val_batches=0,
|
||||
strategy=CollaborativeStrategy(
|
||||
delay_state_averaging=True,
|
||||
offload_optimizer=True,
|
||||
delay_optimizer_step=True,
|
||||
delay_grad_averaging=True,
|
||||
target_batch_size=8,
|
||||
initial_peers=initial_peers,
|
||||
verbose=False,
|
||||
),
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
recorded_process_peers.append(recorded_peers)
|
||||
recorded_process_steps.append(recorded_global_steps)
|
||||
|
||||
|
||||
@RunIf(hivemind=True)
|
||||
@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)
|
||||
@pytest.mark.parametrize(
|
||||
"num_processes, wait_seconds",
|
||||
[(2, 0.25)],
|
||||
)
|
||||
def test_multiple_peers(num_processes, wait_seconds):
|
||||
"""Test to ensure that if we have two running processes with the same peers, they connect and train
|
||||
successfully."""
|
||||
dht_root = hivemind.DHT(start=True)
|
||||
barrier = mp.Barrier(num_processes)
|
||||
initial_peers = dht_root.get_visible_maddrs()
|
||||
|
||||
with mp.Manager() as manager:
|
||||
# allows processes to return their recorded logged peers/steps
|
||||
recorded_process_peers = manager.list()
|
||||
recorded_process_steps = manager.list()
|
||||
processes = [
|
||||
mp.Process(
|
||||
target=_run_collab_training_fn,
|
||||
kwargs=dict(
|
||||
initial_peers=initial_peers,
|
||||
wait_seconds=wait_seconds,
|
||||
barrier=barrier,
|
||||
recorded_process_peers=recorded_process_peers,
|
||||
recorded_process_steps=recorded_process_steps,
|
||||
),
|
||||
)
|
||||
for x in range(num_processes)
|
||||
]
|
||||
for process in processes:
|
||||
process.start()
|
||||
for process in processes:
|
||||
process.join()
|
||||
# assert that peers increase as expected and we run at-least 1 global step.
|
||||
for process_peers, process_steps in zip(recorded_process_peers, recorded_process_steps):
|
||||
assert any(num_peer == num_processes for num_peer in process_peers)
|
||||
assert any(global_step > 0 for global_step in process_steps)
|
||||
|
||||
|
||||
@RunIf(hivemind=True, min_gpus=1)
|
||||
@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)
|
||||
def test_scaler_updated_precision_16():
|
||||
class TestModel(BoringModel):
|
||||
def on_fit_start(self) -> None:
|
||||
assert isinstance(self.trainer.precision_plugin.scaler, hivemind.GradScaler)
|
||||
raise SystemExit
|
||||
|
||||
model = TestModel()
|
||||
trainer = pl.Trainer(
|
||||
strategy=CollaborativeStrategy(target_batch_size=1),
|
||||
fast_dev_run=True,
|
||||
precision=16,
|
||||
accelerator="gpu",
|
||||
devices=1,
|
||||
)
|
||||
with pytest.raises(SystemExit):
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
@RunIf(hivemind=True)
|
||||
def test_raise_when_peer_endpoint_unsuccessful(caplog):
|
||||
port = find_free_network_port()
|
||||
with pytest.raises(MisconfigurationException, match="Unable to get peers"):
|
||||
with mock.patch("requests.get", wraps=requests.get) as requests_mock:
|
||||
CollaborativeStrategy(
|
||||
target_batch_size=1,
|
||||
peer_endpoint=f"localhost:{port}",
|
||||
retry_endpoint_attempts=10,
|
||||
retry_endpoint_sleep_duration=0,
|
||||
)
|
||||
assert "Failed to get peers, retrying" in caplog.text
|
||||
assert requests_mock.call_count == 10
|
Loading…
Reference in New Issue