Add MPI cluster environment (#16570)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Adrian Wälchli 2023-02-03 11:45:11 +01:00 committed by GitHub
parent e20172d370
commit 0f75dce8b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 332 additions and 1 deletions

View File

@ -139,6 +139,7 @@ environments
KubeflowEnvironment
LightningEnvironment
LSFEnvironment
MPIEnvironment
SLURMEnvironment
TorchElasticEnvironment
XLAEnvironment

View File

@ -89,6 +89,7 @@ Environments
~kubeflow.KubeflowEnvironment
~lightning.LightningEnvironment
~lsf.LSFEnvironment
~mpi.MPIEnvironment
~slurm.SLURMEnvironment
~torchelastic.TorchElasticEnvironment
~xla.XLAEnvironment

View File

@ -180,12 +180,20 @@ Choose from the following options based on your expertise level and available in
.. displayitem::
:header: Bare Bones Cluster
:description: Train across machines on a network.
:description: Train across machines on a network using `torchrun`.
:col_css: col-md-4
:button_link: ../guide/multi_node/barebones.html
:height: 160
:tag: advanced
.. displayitem::
:header: Other Cluster Environments
:description: MPI, LSF, Kubeflow
:col_css: col-md-4
:button_link: ../guide/multi_node/other.html
:height: 160
:tag: advanced
.. raw:: html
</div>

View File

@ -0,0 +1,66 @@
:orphan:
##########################
Other Cluster Environments
##########################
**Audience**: Users who want to run on a cluster that launches the training script via MPI, LSF, Kubeflow, etc.
Lightning automates the details behind training on the most common cluster environments.
While :doc:`SLURM <./slurm>` is the most popular choice for on-prem clusters, there are other systems that Lightning can detect automatically.
Don't have access to an enterprise cluster? Try the :doc:`Lightning cloud <./cloud>`.
----
***
MPI
***
`MPI (Message Passing Interface) <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ is a communication system for parallel computing.
There are many implementations available; the most popular among them are `OpenMPI <https://www.open-mpi.org/>`_ and `MPICH <https://www.mpich.org/>`_.
To support all these, Lightning relies on the `mpi4py package <https://github.com/mpi4py/mpi4py>`_:
.. code-block:: bash
pip install mpi4py
If the package is installed and the Python script gets launched by MPI, Fabric will automatically detect it and parse the process information from the environment.
There is nothing you have to change in your code:
.. code-block:: python
fabric = Fabric(...) # automatically detects MPI
print(fabric.world_size) # world size provided by MPI
print(fabric.global_rank) # rank provided by MPI
...
If you want to bypass the automatic detection, you can explicitly set the MPI environment as a plugin:
.. code-block:: python
from lightning.fabric.plugins.environments import MPIEnvironment
fabric = Fabric(..., plugins=[MPIEnvironment()])
----
***
LSF
***
Coming soon.
----
********
Kubeflow
********
Coming soon.

View File

@ -36,6 +36,7 @@ from lightning.fabric.plugins.environments import (
KubeflowEnvironment,
LightningEnvironment,
LSFEnvironment,
MPIEnvironment,
SLURMEnvironment,
TorchElasticEnvironment,
)
@ -374,6 +375,7 @@ class _Connector:
TorchElasticEnvironment,
KubeflowEnvironment,
LSFEnvironment,
MPIEnvironment,
):
if env_type.detect():
return env_type()
@ -414,6 +416,7 @@ class _Connector:
or KubeflowEnvironment.detect()
or SLURMEnvironment.detect()
or LSFEnvironment.detect()
or MPIEnvironment.detect()
):
strategy_flag = "ddp"
if strategy_flag == "dp" and self._accelerator_flag == "cpu":

View File

@ -15,6 +15,7 @@ from lightning.fabric.plugins.environments.cluster_environment import ClusterEnv
from lightning.fabric.plugins.environments.kubeflow import KubeflowEnvironment # noqa: F401
from lightning.fabric.plugins.environments.lightning import LightningEnvironment # noqa: F401
from lightning.fabric.plugins.environments.lsf import LSFEnvironment # noqa: F401
from lightning.fabric.plugins.environments.mpi import MPIEnvironment # noqa: F401
from lightning.fabric.plugins.environments.slurm import SLURMEnvironment # noqa: F401
from lightning.fabric.plugins.environments.torchelastic import TorchElasticEnvironment # noqa: F401
from lightning.fabric.plugins.environments.xla import XLAEnvironment # noqa: F401

View File

@ -0,0 +1,116 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import socket
from functools import lru_cache
from typing import Optional
import numpy as np
from lightning_utilities.core.imports import RequirementCache
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.plugins.environments.lightning import find_free_network_port
log = logging.getLogger(__name__)
_MPI4PY_AVAILABLE = RequirementCache("mpi4py")
class MPIEnvironment(ClusterEnvironment):
    """An environment for running on clusters with processes created through MPI.

    Requires the installation of the `mpi4py` package. See also: https://github.com/mpi4py/mpi4py

    Raises:
        ModuleNotFoundError: If the `mpi4py` package is not installed.
    """

    def __init__(self) -> None:
        if not _MPI4PY_AVAILABLE:
            raise ModuleNotFoundError(str(_MPI4PY_AVAILABLE))

        from mpi4py import MPI

        self._comm_world = MPI.COMM_WORLD
        self._comm_local: Optional[MPI.Comm] = None
        self._node_rank: Optional[int] = None
        self._main_address: Optional[str] = None
        self._main_port: Optional[int] = None
        # Values queried from the world communicator are cached on the instance instead of with
        # `functools.lru_cache` on the methods: an `lru_cache`-decorated method keys its cache on
        # `self`, keeping the instance alive for the cache's lifetime (ruff B019), and with
        # `maxsize=1` the cache would also thrash whenever more than one instance exists.
        self._world_size: Optional[int] = None
        self._global_rank: Optional[int] = None
        self._local_rank: Optional[int] = None

    @property
    def creates_processes_externally(self) -> bool:
        # The MPI launcher (e.g. `mpirun`) spawns the processes, not Lightning
        return True

    @property
    def main_address(self) -> str:
        """The hostname of the rank-0 process, shared with all other ranks via broadcast."""
        if self._main_address is None:
            self._main_address = self._get_main_address()
        return self._main_address

    @property
    def main_port(self) -> int:
        """A free port determined on rank 0 and shared with all other ranks via broadcast."""
        if self._main_port is None:
            self._main_port = self._get_main_port()
        return self._main_port

    @staticmethod
    def detect() -> bool:
        """Returns ``True`` if the `mpi4py` package is installed and MPI returns a world size greater than 1."""
        if not _MPI4PY_AVAILABLE:
            return False

        from mpi4py import MPI

        return MPI.COMM_WORLD.Get_size() > 1

    def world_size(self) -> int:
        """The total number of processes launched through MPI."""
        if self._world_size is None:
            self._world_size = self._comm_world.Get_size()
        return self._world_size

    def set_world_size(self, size: int) -> None:
        # The world size is dictated by MPI and cannot be overridden
        log.debug("MPIEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

    def global_rank(self) -> int:
        """The rank of this process across all nodes, as reported by MPI."""
        if self._global_rank is None:
            self._global_rank = self._comm_world.Get_rank()
        return self._global_rank

    def set_global_rank(self, rank: int) -> None:
        # The global rank is dictated by MPI and cannot be overridden
        log.debug("MPIEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

    def local_rank(self) -> int:
        """The rank of this process among the processes running on the same node."""
        if self._local_rank is None:
            if self._comm_local is None:
                self._init_comm_local()
            assert self._comm_local is not None
            self._local_rank = self._comm_local.Get_rank()
        return self._local_rank

    def node_rank(self) -> int:
        """The index of the node this process runs on, based on the sorted set of unique hostnames."""
        if self._node_rank is None:
            self._init_comm_local()
        assert self._node_rank is not None
        return self._node_rank

    def _get_main_address(self) -> str:
        # Every rank receives the hostname of rank 0
        return self._comm_world.bcast(socket.gethostname(), root=0)

    def _get_main_port(self) -> int:
        # Every rank receives the port chosen on rank 0; the ports picked on the other ranks are
        # discarded by the broadcast
        return self._comm_world.bcast(find_free_network_port(), root=0)

    def _init_comm_local(self) -> None:
        """Determine the node rank and create a per-node communicator by grouping equal hostnames."""
        hostname = socket.gethostname()
        all_hostnames = self._comm_world.gather(hostname, root=0)
        # sort all the hostnames, and find unique ones
        unique_hosts = np.unique(all_hostnames)
        unique_hosts = self._comm_world.bcast(unique_hosts, root=0)
        # find the integer for this host in the list of hosts:
        # `np.where` returns a tuple of index arrays; index the first match explicitly instead of
        # relying on the deprecated implicit conversion of a size-1 array to an int
        self._node_rank = int(np.where(unique_hosts == hostname)[0][0])
        self._comm_local = self._comm_world.Split(color=self._node_rank)

View File

@ -16,6 +16,7 @@ from lightning.fabric.plugins.environments import ( # noqa: F401
KubeflowEnvironment,
LightningEnvironment,
LSFEnvironment,
MPIEnvironment,
SLURMEnvironment,
TorchElasticEnvironment,
XLAEnvironment,

View File

@ -25,6 +25,7 @@ from lightning.fabric.plugins.environments import (
KubeflowEnvironment,
LightningEnvironment,
LSFEnvironment,
MPIEnvironment,
SLURMEnvironment,
TorchElasticEnvironment,
)
@ -449,6 +450,7 @@ class AcceleratorConnector:
TorchElasticEnvironment,
KubeflowEnvironment,
LSFEnvironment,
MPIEnvironment,
):
if env_type.detect():
return env_type()
@ -499,6 +501,7 @@ class AcceleratorConnector:
or KubeflowEnvironment.detect()
or SLURMEnvironment.detect()
or LSFEnvironment.detect()
or MPIEnvironment.detect()
):
strategy_flag = "ddp"
if strategy_flag == "dp" and self._accelerator_flag == "cpu":

View File

@ -0,0 +1,131 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from unittest import mock
from unittest.mock import MagicMock
import numpy as np
import pytest
import lightning.fabric.plugins.environments.mpi
from lightning.fabric.plugins.environments import MPIEnvironment
def test_dependencies(monkeypatch):
    """The MPI environment must refuse to initialize when the `mpi4py` package is missing."""
    mpi_module = lightning.fabric.plugins.environments.mpi

    # without mpi4py, construction fails early
    monkeypatch.setattr(mpi_module, "_MPI4PY_AVAILABLE", False)
    with pytest.raises(ModuleNotFoundError):
        MPIEnvironment()

    # with mpi4py (faked), construction succeeds
    monkeypatch.setattr(mpi_module, "_MPI4PY_AVAILABLE", True)
    with mock.patch.dict("sys.modules", {"mpi4py": MagicMock()}):
        MPIEnvironment()
def test_detect(monkeypatch):
    """Detection requires both an installed `mpi4py` and an MPI world size larger than one."""
    mpi_module = lightning.fabric.plugins.environments.mpi

    # without mpi4py, detection is always negative
    monkeypatch.setattr(mpi_module, "_MPI4PY_AVAILABLE", False)
    assert not MPIEnvironment.detect()

    # with mpi4py (faked), the outcome depends on the reported world size
    monkeypatch.setattr(mpi_module, "_MPI4PY_AVAILABLE", True)
    fake_mpi4py = MagicMock()
    with mock.patch.dict("sys.modules", {"mpi4py": fake_mpi4py}):
        for size, expected in ((0, False), (1, False), (2, True)):
            fake_mpi4py.MPI.COMM_WORLD.Get_size.return_value = size
            assert MPIEnvironment.detect() == expected
@mock.patch.dict(os.environ, {}, clear=True)
def test_default_attributes(monkeypatch):
    """A freshly constructed environment starts with unresolved addresses and ranks."""
    # pretend mpi4py is available
    monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", True)
    with mock.patch.dict("sys.modules", {"mpi4py": MagicMock()}):
        env = MPIEnvironment()

    # lazily-resolved values start out unset
    for attribute in ("_node_rank", "_main_address", "_main_port"):
        assert getattr(env, attribute) is None
    # the MPI launcher creates the processes, not Lightning
    assert env.creates_processes_externally
def test_init_local_comm(monkeypatch):
    """The node rank and local rank are derived from the hostnames of all participating nodes."""
    # pretend mpi4py is available
    monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", True)
    fake_mpi4py = MagicMock()
    fake_gethostname = MagicMock()
    fake_mpi4py.MPI.COMM_WORLD.Get_size.return_value = 4

    with mock.patch.dict("sys.modules", {"mpi4py": fake_mpi4py}), mock.patch("socket.gethostname", fake_gethostname):
        env = MPIEnvironment()

        # first entry in the broadcast hostname list -> node rank 0
        fake_gethostname.return_value = "host1"
        env._comm_world.bcast.return_value = np.array(["host1", "host2"])
        assert env.node_rank() == 0

        # second entry in the broadcast hostname list -> node rank 1
        env._node_rank = None
        fake_gethostname.return_value = "host2"
        env._comm_world.bcast.return_value = np.array(["host1", "host2"])
        assert env.node_rank() == 1

        # the local rank comes from the node-local communicator created above
        assert env._comm_local is not None
        env._comm_local.Get_rank.return_value = 33
        assert env.local_rank() == 33
def test_world_comm(monkeypatch):
    """World size and global rank are read straight off the world communicator."""
    # pretend mpi4py is available
    monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", True)
    fake_mpi4py = MagicMock()
    with mock.patch.dict("sys.modules", {"mpi4py": fake_mpi4py}):
        env = MPIEnvironment()

        env._comm_world.Get_size.return_value = 8
        env._comm_world.Get_rank.return_value = 3
        assert env.world_size() == 8
        assert env.global_rank() == 3
def test_setters(monkeypatch, caplog):
    """The rank and world-size setters are no-ops that only emit a debug log message."""
    # pretend mpi4py is available
    monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", True)
    with mock.patch.dict("sys.modules", {"mpi4py": MagicMock()}):
        env = MPIEnvironment()

    logger_name = "lightning.fabric.plugins.environments"

    with caplog.at_level(logging.DEBUG, logger=logger_name):
        env.set_global_rank(100)
    assert "setting global rank is not allowed" in caplog.text

    caplog.clear()

    with caplog.at_level(logging.DEBUG, logger=logger_name):
        env.set_world_size(100)
    assert "setting world size is not allowed" in caplog.text