Add MPI cluster environment (#16570)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
e20172d370
commit
0f75dce8b4
|
@ -139,6 +139,7 @@ environments
|
|||
KubeflowEnvironment
|
||||
LightningEnvironment
|
||||
LSFEnvironment
|
||||
MPIEnvironment
|
||||
SLURMEnvironment
|
||||
TorchElasticEnvironment
|
||||
XLAEnvironment
|
||||
|
|
|
@ -89,6 +89,7 @@ Environments
|
|||
~kubeflow.KubeflowEnvironment
|
||||
~lightning.LightningEnvironment
|
||||
~lsf.LSFEnvironment
|
||||
~mpi.MPIEnvironment
|
||||
~slurm.SLURMEnvironment
|
||||
~torchelastic.TorchElasticEnvironment
|
||||
~xla.XLAEnvironment
|
||||
|
|
|
@ -180,12 +180,20 @@ Choose from the following options based on your expertise level and available in
|
|||
|
||||
.. displayitem::
|
||||
:header: Bare Bones Cluster
|
||||
:description: Train across machines on a network.
|
||||
:description: Train across machines on a network using `torchrun`.
|
||||
:col_css: col-md-4
|
||||
:button_link: ../guide/multi_node/barebones.html
|
||||
:height: 160
|
||||
:tag: advanced
|
||||
|
||||
.. displayitem::
|
||||
:header: Other Cluster Environments
|
||||
:description: MPI, LSF, Kubeflow
|
||||
:col_css: col-md-4
|
||||
:button_link: ../guide/multi_node/other.html
|
||||
:height: 160
|
||||
:tag: advanced
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
:orphan:
|
||||
|
||||
##########################
|
||||
Other Cluster Environments
|
||||
##########################
|
||||
|
||||
**Audience**: Users who want to run on a cluster that launches the training script via MPI, LSF, Kubeflow, etc.
|
||||
|
||||
Lightning automates the details behind training on the most common cluster environments.
|
||||
While :doc:`SLURM <./slurm>` is the most popular choice for on-prem clusters, there are other systems that Lightning can detect automatically.
|
||||
|
||||
Don't have access to an enterprise cluster? Try the :doc:`Lightning cloud <./cloud>`.
|
||||
|
||||
|
||||
----
|
||||
|
||||
|
||||
***
|
||||
MPI
|
||||
***
|
||||
|
||||
`MPI (Message Passing Interface) <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ is a communication system for parallel computing.
|
||||
There are many implementations available; the most popular among them are `OpenMPI <https://www.open-mpi.org/>`_ and `MPICH <https://www.mpich.org/>`_.
|
||||
To support all these, Lightning relies on the `mpi4py package <https://github.com/mpi4py/mpi4py>`_:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install mpi4py
|
||||
|
||||
If the package is installed and the Python script gets launched by MPI, Fabric will automatically detect it and parse the process information from the environment.
|
||||
There is nothing you have to change in your code:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
fabric = Fabric(...) # automatically detects MPI
|
||||
print(fabric.world_size) # world size provided by MPI
|
||||
print(fabric.global_rank) # rank provided by MPI
|
||||
...
|
||||
|
||||
If you want to bypass the automatic detection, you can explicitly set the MPI environment as a plugin:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from lightning.fabric.plugins.environments import MPIEnvironment
|
||||
|
||||
fabric = Fabric(..., plugins=[MPIEnvironment()])
|
||||
|
||||
|
||||
----
|
||||
|
||||
|
||||
***
|
||||
LSF
|
||||
***
|
||||
|
||||
Coming soon.
|
||||
|
||||
|
||||
----
|
||||
|
||||
|
||||
********
|
||||
Kubeflow
|
||||
********
|
||||
|
||||
Coming soon.
|
|
@ -36,6 +36,7 @@ from lightning.fabric.plugins.environments import (
|
|||
KubeflowEnvironment,
|
||||
LightningEnvironment,
|
||||
LSFEnvironment,
|
||||
MPIEnvironment,
|
||||
SLURMEnvironment,
|
||||
TorchElasticEnvironment,
|
||||
)
|
||||
|
@ -374,6 +375,7 @@ class _Connector:
|
|||
TorchElasticEnvironment,
|
||||
KubeflowEnvironment,
|
||||
LSFEnvironment,
|
||||
MPIEnvironment,
|
||||
):
|
||||
if env_type.detect():
|
||||
return env_type()
|
||||
|
@ -414,6 +416,7 @@ class _Connector:
|
|||
or KubeflowEnvironment.detect()
|
||||
or SLURMEnvironment.detect()
|
||||
or LSFEnvironment.detect()
|
||||
or MPIEnvironment.detect()
|
||||
):
|
||||
strategy_flag = "ddp"
|
||||
if strategy_flag == "dp" and self._accelerator_flag == "cpu":
|
||||
|
|
|
@ -15,6 +15,7 @@ from lightning.fabric.plugins.environments.cluster_environment import ClusterEnv
|
|||
from lightning.fabric.plugins.environments.kubeflow import KubeflowEnvironment # noqa: F401
|
||||
from lightning.fabric.plugins.environments.lightning import LightningEnvironment # noqa: F401
|
||||
from lightning.fabric.plugins.environments.lsf import LSFEnvironment # noqa: F401
|
||||
from lightning.fabric.plugins.environments.mpi import MPIEnvironment # noqa: F401
|
||||
from lightning.fabric.plugins.environments.slurm import SLURMEnvironment # noqa: F401
|
||||
from lightning.fabric.plugins.environments.torchelastic import TorchElasticEnvironment # noqa: F401
|
||||
from lightning.fabric.plugins.environments.xla import XLAEnvironment # noqa: F401
|
||||
|
|
|
@ -0,0 +1,116 @@
|
|||
# Copyright The Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import socket
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from lightning_utilities.core.imports import RequirementCache
|
||||
|
||||
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
|
||||
from lightning.fabric.plugins.environments.lightning import find_free_network_port
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_MPI4PY_AVAILABLE = RequirementCache("mpi4py")
|
||||
|
||||
|
||||
class MPIEnvironment(ClusterEnvironment):
    """An environment for running on clusters with processes created through MPI.

    Requires the installation of the `mpi4py` package. See also: https://github.com/mpi4py/mpi4py

    World size and global rank are read from ``MPI.COMM_WORLD``. Node rank and local rank are derived lazily by
    exchanging hostnames between all processes and splitting the world communicator per host.

    Raises:
        ModuleNotFoundError: On instantiation, if the `mpi4py` package is not available.
    """

    def __init__(self) -> None:
        if not _MPI4PY_AVAILABLE:
            raise ModuleNotFoundError(str(_MPI4PY_AVAILABLE))

        # imported lazily so that this module can be imported without mpi4py installed
        from mpi4py import MPI

        self._comm_world = MPI.COMM_WORLD
        # the attributes below are populated on first access
        self._comm_local: Optional[MPI.Comm] = None
        self._node_rank: Optional[int] = None
        self._main_address: Optional[str] = None
        self._main_port: Optional[int] = None
        # per-instance caches for the values queried from the communicators; kept on the instance (rather than
        # in an `lru_cache` on the methods) so that no class-level cache holds a reference to `self`
        self._world_size: Optional[int] = None
        self._global_rank: Optional[int] = None
        self._local_rank: Optional[int] = None

    @property
    def creates_processes_externally(self) -> bool:
        """Always ``True``: the processes are expected to be launched by MPI."""
        return True

    @property
    def main_address(self) -> str:
        """The hostname reported by rank 0, broadcast to all ranks on first access and then cached."""
        if self._main_address is None:
            self._main_address = self._get_main_address()
        return self._main_address

    @property
    def main_port(self) -> int:
        """The port broadcast from rank 0 on first access and then cached."""
        if self._main_port is None:
            self._main_port = self._get_main_port()
        return self._main_port

    @staticmethod
    def detect() -> bool:
        """Returns ``True`` if the `mpi4py` package is installed and MPI returns a world size greater than 1."""
        if not _MPI4PY_AVAILABLE:
            return False

        from mpi4py import MPI

        return MPI.COMM_WORLD.Get_size() > 1

    def world_size(self) -> int:
        """Return the size of the world communicator (queried once, then cached on the instance)."""
        if self._world_size is None:
            self._world_size = self._comm_world.Get_size()
        return self._world_size

    def set_world_size(self, size: int) -> None:
        """No-op: the world size is dictated by MPI. The call is only logged at debug level."""
        log.debug("MPIEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

    def global_rank(self) -> int:
        """Return this process' rank in the world communicator (queried once, then cached on the instance)."""
        if self._global_rank is None:
            self._global_rank = self._comm_world.Get_rank()
        return self._global_rank

    def set_global_rank(self, rank: int) -> None:
        """No-op: the global rank is dictated by MPI. The call is only logged at debug level."""
        log.debug("MPIEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

    def local_rank(self) -> int:
        """Return this process' rank in the per-host communicator (queried once, then cached on the instance)."""
        if self._local_rank is None:
            if self._comm_local is None:
                self._init_comm_local()
            assert self._comm_local is not None
            self._local_rank = self._comm_local.Get_rank()
        return self._local_rank

    def node_rank(self) -> int:
        """Return the index of this process' hostname among the sorted unique hostnames of all processes."""
        if self._node_rank is None:
            self._init_comm_local()
        assert self._node_rank is not None
        return self._node_rank

    def _get_main_address(self) -> str:
        # every rank passes its own hostname, but `bcast` distributes only the value supplied by the root rank
        return self._comm_world.bcast(socket.gethostname(), root=0)

    def _get_main_port(self) -> int:
        # every rank evaluates the argument, but `bcast` distributes only the value supplied by the root rank
        return self._comm_world.bcast(find_free_network_port(), root=0)

    def _init_comm_local(self) -> None:
        """Determine the node rank and create the per-host communicator.

        This is a collective operation on the world communicator (gather, broadcast, split).
        """
        hostname = socket.gethostname()
        # `gather` hands the collected list to the root rank only; all other ranks receive `None`
        all_hostnames = self._comm_world.gather(hostname, root=0)
        # sort all the hostnames, and find unique ones (only meaningful on the root rank)
        unique_hosts = sorted(set(all_hostnames)) if all_hostnames is not None else None
        unique_hosts = self._comm_world.bcast(unique_hosts, root=0)
        # find the integer for this host in the list of hosts; `list(...)` accepts any sequence of hostnames
        # and avoids converting a one-element array to a scalar, which newer NumPy versions deprecate
        self._node_rank = list(unique_hosts).index(hostname)
        self._comm_local = self._comm_world.Split(color=self._node_rank)
|
|
@ -16,6 +16,7 @@ from lightning.fabric.plugins.environments import ( # noqa: F401
|
|||
KubeflowEnvironment,
|
||||
LightningEnvironment,
|
||||
LSFEnvironment,
|
||||
MPIEnvironment,
|
||||
SLURMEnvironment,
|
||||
TorchElasticEnvironment,
|
||||
XLAEnvironment,
|
||||
|
|
|
@ -25,6 +25,7 @@ from lightning.fabric.plugins.environments import (
|
|||
KubeflowEnvironment,
|
||||
LightningEnvironment,
|
||||
LSFEnvironment,
|
||||
MPIEnvironment,
|
||||
SLURMEnvironment,
|
||||
TorchElasticEnvironment,
|
||||
)
|
||||
|
@ -449,6 +450,7 @@ class AcceleratorConnector:
|
|||
TorchElasticEnvironment,
|
||||
KubeflowEnvironment,
|
||||
LSFEnvironment,
|
||||
MPIEnvironment,
|
||||
):
|
||||
if env_type.detect():
|
||||
return env_type()
|
||||
|
@ -499,6 +501,7 @@ class AcceleratorConnector:
|
|||
or KubeflowEnvironment.detect()
|
||||
or SLURMEnvironment.detect()
|
||||
or LSFEnvironment.detect()
|
||||
or MPIEnvironment.detect()
|
||||
):
|
||||
strategy_flag = "ddp"
|
||||
if strategy_flag == "dp" and self._accelerator_flag == "cpu":
|
||||
|
|
|
@ -0,0 +1,131 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import lightning.fabric.plugins.environments.mpi
|
||||
from lightning.fabric.plugins.environments import MPIEnvironment
|
||||
|
||||
|
||||
def test_dependencies(monkeypatch):
    """Check that instantiating the MPI environment is refused without `mpi4py` and works once it is present."""
    mpi_module = lightning.fabric.plugins.environments.mpi

    # package reported as missing -> the constructor must raise
    monkeypatch.setattr(mpi_module, "_MPI4PY_AVAILABLE", False)
    with pytest.raises(ModuleNotFoundError):
        MPIEnvironment()

    # package reported as available and a stand-in module registered -> the constructor must succeed
    monkeypatch.setattr(mpi_module, "_MPI4PY_AVAILABLE", True)
    fake_modules = {"mpi4py": MagicMock()}
    with mock.patch.dict("sys.modules", fake_modules):
        MPIEnvironment()
|
||||
|
||||
|
||||
def test_detect(monkeypatch):
    """Check that detection requires `mpi4py` and a reported world size above one."""
    mpi_module = lightning.fabric.plugins.environments.mpi

    monkeypatch.setattr(mpi_module, "_MPI4PY_AVAILABLE", False)
    assert not MPIEnvironment.detect()

    # pretend mpi4py is available
    monkeypatch.setattr(mpi_module, "_MPI4PY_AVAILABLE", True)
    mpi4py_mock = MagicMock()
    get_size = mpi4py_mock.MPI.COMM_WORLD.Get_size

    with mock.patch.dict("sys.modules", {"mpi4py": mpi4py_mock}):
        # (reported world size, whether an MPI environment should be detected)
        for reported_size, should_detect in ((0, False), (1, False), (2, True)):
            get_size.return_value = reported_size
            assert bool(MPIEnvironment.detect()) is should_detect
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
def test_default_attributes(monkeypatch):
    """Check the state of a freshly created environment when the process environment is empty."""
    # pretend mpi4py is available
    mpi_module = lightning.fabric.plugins.environments.mpi
    monkeypatch.setattr(mpi_module, "_MPI4PY_AVAILABLE", True)

    fake_modules = {"mpi4py": MagicMock()}
    with mock.patch.dict("sys.modules", fake_modules):
        env = MPIEnvironment()

    # nothing has been resolved lazily yet
    for lazy_attribute in (env._node_rank, env._main_address, env._main_port):
        assert lazy_attribute is None
    assert env.creates_processes_externally
|
||||
|
||||
|
||||
def test_init_local_comm(monkeypatch):
    """Check that node rank and local rank are derived from the hostnames of the participating nodes."""
    # pretend mpi4py is available
    mpi_module = lightning.fabric.plugins.environments.mpi
    monkeypatch.setattr(mpi_module, "_MPI4PY_AVAILABLE", True)
    mpi4py_mock = MagicMock()
    hostname_mock = MagicMock()
    mpi4py_mock.MPI.COMM_WORLD.Get_size.return_value = 4

    patched_modules = mock.patch.dict("sys.modules", {"mpi4py": mpi4py_mock})
    patched_hostname = mock.patch("socket.gethostname", hostname_mock)
    with patched_modules, patched_hostname:
        env = MPIEnvironment()

        # this process runs on the first of the two hosts
        hostname_mock.return_value = "host1"
        env._comm_world.bcast.return_value = np.array(["host1", "host2"])
        assert env.node_rank() == 0

        # reset the cached value and pretend to run on the second host
        env._node_rank = None
        hostname_mock.return_value = "host2"
        env._comm_world.bcast.return_value = np.array(["host1", "host2"])
        assert env.node_rank() == 1

        # the local communicator was created as part of resolving the node rank
        assert env._comm_local is not None
        env._comm_local.Get_rank.return_value = 33
        assert env.local_rank() == 33
|
||||
|
||||
|
||||
def test_world_comm(monkeypatch):
    """Check that world size and global rank are taken from the world communicator."""
    # pretend mpi4py is available
    mpi_module = lightning.fabric.plugins.environments.mpi
    monkeypatch.setattr(mpi_module, "_MPI4PY_AVAILABLE", True)

    fake_modules = {"mpi4py": MagicMock()}
    with mock.patch.dict("sys.modules", fake_modules):
        env = MPIEnvironment()

    world = env._comm_world
    world.Get_size.return_value = 8
    assert env.world_size() == 8
    world.Get_rank.return_value = 3
    assert env.global_rank() == 3
|
||||
|
||||
|
||||
def test_setters(monkeypatch, caplog):
    """Check that the rank and world size setters do nothing except emit a debug message."""
    # pretend mpi4py is available
    mpi_module = lightning.fabric.plugins.environments.mpi
    monkeypatch.setattr(mpi_module, "_MPI4PY_AVAILABLE", True)

    fake_modules = {"mpi4py": MagicMock()}
    with mock.patch.dict("sys.modules", fake_modules):
        env = MPIEnvironment()

    # setter should be no-op
    expectations = (
        (env.set_global_rank, "setting global rank is not allowed"),
        (env.set_world_size, "setting world size is not allowed"),
    )
    for setter, expected_message in expectations:
        with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"):
            setter(100)
        assert expected_message in caplog.text
        # start the next check from an empty log
        caplog.clear()
|
Loading…
Reference in New Issue