diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst index b96911afa7..4ad07accc0 100644 --- a/docs/source-pytorch/api_references.rst +++ b/docs/source-pytorch/api_references.rst @@ -139,6 +139,7 @@ environments KubeflowEnvironment LightningEnvironment LSFEnvironment + MPIEnvironment SLURMEnvironment TorchElasticEnvironment XLAEnvironment diff --git a/docs/source-pytorch/fabric/api/api_reference.rst b/docs/source-pytorch/fabric/api/api_reference.rst index 680b15df65..649d3182f1 100644 --- a/docs/source-pytorch/fabric/api/api_reference.rst +++ b/docs/source-pytorch/fabric/api/api_reference.rst @@ -89,6 +89,7 @@ Environments ~kubeflow.KubeflowEnvironment ~lightning.LightningEnvironment ~lsf.LSFEnvironment + ~mpi.MPIEnvironment ~slurm.SLURMEnvironment ~torchelastic.TorchElasticEnvironment ~xla.XLAEnvironment diff --git a/docs/source-pytorch/fabric/fundamentals/launch.rst b/docs/source-pytorch/fabric/fundamentals/launch.rst index e3e4762745..9a49d9b050 100644 --- a/docs/source-pytorch/fabric/fundamentals/launch.rst +++ b/docs/source-pytorch/fabric/fundamentals/launch.rst @@ -180,12 +180,20 @@ Choose from the following options based on your expertise level and available in .. displayitem:: :header: Bare Bones Cluster - :description: Train across machines on a network. + :description: Train across machines on a network using `torchrun`. :col_css: col-md-4 :button_link: ../guide/multi_node/barebones.html :height: 160 :tag: advanced +.. displayitem:: + :header: Other Cluster Environments + :description: MPI, LSF, Kubeflow + :col_css: col-md-4 + :button_link: ../guide/multi_node/other.html + :height: 160 + :tag: advanced + .. 
raw:: html diff --git a/docs/source-pytorch/fabric/guide/multi_node/other.rst b/docs/source-pytorch/fabric/guide/multi_node/other.rst new file mode 100644 index 0000000000..6bddf04f1a --- /dev/null +++ b/docs/source-pytorch/fabric/guide/multi_node/other.rst @@ -0,0 +1,66 @@ +:orphan: + +########################## +Other Cluster Environments +########################## + +**Audience**: Users who want to run on a cluster that launches the training script via MPI, LSF, Kubeflow, etc. + +Lightning automates the details behind training on the most common cluster environments. +While :doc:`SLURM <./slurm>` is the most popular choice for on-prem clusters, there are other systems that Lightning can detect automatically. + +Don't have access to an enterprise cluster? Try the :doc:`Lightning cloud <./cloud>`. + + +---- + + +*** +MPI +*** + +`MPI (Message Passing Interface) <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ is a communication system for parallel computing. +There are many implementations available, the most popular among them are `OpenMPI <https://www.open-mpi.org/>`_ and `MPICH <https://www.mpich.org/>`_. +To support all these, Lightning relies on the `mpi4py package <https://github.com/mpi4py/mpi4py>`_: + +.. code-block:: bash + + pip install mpi4py + +If the package is installed and the Python script gets launched by MPI, Fabric will automatically detect it and parse the process information from the environment. +There is nothing you have to change in your code: + +.. code-block:: python + + fabric = Fabric(...) # automatically detects MPI + print(fabric.world_size) # world size provided by MPI + print(fabric.global_rank) # rank provided by MPI + ... + +If you want to bypass the automatic detection, you can explicitly set the MPI environment as a plugin: + +.. code-block:: python + + from lightning.fabric.plugins.environments import MPIEnvironment + + fabric = Fabric(..., plugins=[MPIEnvironment()]) + + +---- + + +*** +LSF +*** + +Coming soon. + + +---- + + +******** +Kubeflow +******** + +Coming soon.
diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index acdff8d50d..0def7a49ce 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -36,6 +36,7 @@ from lightning.fabric.plugins.environments import ( KubeflowEnvironment, LightningEnvironment, LSFEnvironment, + MPIEnvironment, SLURMEnvironment, TorchElasticEnvironment, ) @@ -374,6 +375,7 @@ class _Connector: TorchElasticEnvironment, KubeflowEnvironment, LSFEnvironment, + MPIEnvironment, ): if env_type.detect(): return env_type() @@ -414,6 +416,7 @@ class _Connector: or KubeflowEnvironment.detect() or SLURMEnvironment.detect() or LSFEnvironment.detect() + or MPIEnvironment.detect() ): strategy_flag = "ddp" if strategy_flag == "dp" and self._accelerator_flag == "cpu": diff --git a/src/lightning/fabric/plugins/environments/__init__.py b/src/lightning/fabric/plugins/environments/__init__.py index ae381f5929..e740602c4b 100644 --- a/src/lightning/fabric/plugins/environments/__init__.py +++ b/src/lightning/fabric/plugins/environments/__init__.py @@ -15,6 +15,7 @@ from lightning.fabric.plugins.environments.cluster_environment import ClusterEnv from lightning.fabric.plugins.environments.kubeflow import KubeflowEnvironment # noqa: F401 from lightning.fabric.plugins.environments.lightning import LightningEnvironment # noqa: F401 from lightning.fabric.plugins.environments.lsf import LSFEnvironment # noqa: F401 +from lightning.fabric.plugins.environments.mpi import MPIEnvironment # noqa: F401 from lightning.fabric.plugins.environments.slurm import SLURMEnvironment # noqa: F401 from lightning.fabric.plugins.environments.torchelastic import TorchElasticEnvironment # noqa: F401 from lightning.fabric.plugins.environments.xla import XLAEnvironment # noqa: F401 diff --git a/src/lightning/fabric/plugins/environments/mpi.py b/src/lightning/fabric/plugins/environments/mpi.py new file mode 100644 index 0000000000..abe441e2f2 --- /dev/null +++ 
b/src/lightning/fabric/plugins/environments/mpi.py @@ -0,0 +1,116 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import socket +from functools import lru_cache +from typing import Optional + +import numpy as np +from lightning_utilities.core.imports import RequirementCache + +from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment +from lightning.fabric.plugins.environments.lightning import find_free_network_port + +log = logging.getLogger(__name__) + +_MPI4PY_AVAILABLE = RequirementCache("mpi4py") + + +class MPIEnvironment(ClusterEnvironment): + """An environment for running on clusters with processes created through MPI. + + Requires the installation of the `mpi4py` package. 
See also: https://github.com/mpi4py/mpi4py + """ + + def __init__(self) -> None: + if not _MPI4PY_AVAILABLE: + raise ModuleNotFoundError(str(_MPI4PY_AVAILABLE)) + + from mpi4py import MPI + + self._comm_world = MPI.COMM_WORLD + self._comm_local: Optional[MPI.Comm] = None + self._node_rank: Optional[int] = None + self._main_address: Optional[str] = None + self._main_port: Optional[int] = None + + @property + def creates_processes_externally(self) -> bool: + return True + + @property + def main_address(self) -> str: + if self._main_address is None: + self._main_address = self._get_main_address() + return self._main_address + + @property + def main_port(self) -> int: + if self._main_port is None: + self._main_port = self._get_main_port() + return self._main_port + + @staticmethod + def detect() -> bool: + """Returns ``True`` if the `mpi4py` package is installed and MPI returns a world size greater than 1.""" + if not _MPI4PY_AVAILABLE: + return False + + from mpi4py import MPI + + return MPI.COMM_WORLD.Get_size() > 1 + + @lru_cache(1) + def world_size(self) -> int: + return self._comm_world.Get_size() + + def set_world_size(self, size: int) -> None: + log.debug("MPIEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.") + + @lru_cache(1) + def global_rank(self) -> int: + return self._comm_world.Get_rank() + + def set_global_rank(self, rank: int) -> None: + log.debug("MPIEnvironment.set_global_rank was called, but setting global rank is not allowed. 
Ignored.") + + @lru_cache(1) + def local_rank(self) -> int: + if self._comm_local is None: + self._init_comm_local() + assert self._comm_local is not None + return self._comm_local.Get_rank() + + def node_rank(self) -> int: + if self._node_rank is None: + self._init_comm_local() + assert self._node_rank is not None + return self._node_rank + + def _get_main_address(self) -> str: + return self._comm_world.bcast(socket.gethostname(), root=0) + + def _get_main_port(self) -> int: + return self._comm_world.bcast(find_free_network_port(), root=0) + + def _init_comm_local(self) -> None: + hostname = socket.gethostname() + all_hostnames = self._comm_world.gather(hostname, root=0) + # sort all the hostnames, and find unique ones + unique_hosts = np.unique(all_hostnames) + unique_hosts = self._comm_world.bcast(unique_hosts, root=0) + # find the integer for this host in the list of hosts: + self._node_rank = int(np.where(unique_hosts == hostname)[0]) + self._comm_local = self._comm_world.Split(color=self._node_rank) diff --git a/src/lightning/pytorch/plugins/environments/__init__.py b/src/lightning/pytorch/plugins/environments/__init__.py index 38ad63f754..a8d2cf6919 100644 --- a/src/lightning/pytorch/plugins/environments/__init__.py +++ b/src/lightning/pytorch/plugins/environments/__init__.py @@ -16,6 +16,7 @@ from lightning.fabric.plugins.environments import ( # noqa: F401 KubeflowEnvironment, LightningEnvironment, LSFEnvironment, + MPIEnvironment, SLURMEnvironment, TorchElasticEnvironment, XLAEnvironment, diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index a9910294d5..ff8852eab6 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -25,6 +25,7 @@ from lightning.fabric.plugins.environments import ( KubeflowEnvironment, LightningEnvironment, LSFEnvironment, + MPIEnvironment, 
SLURMEnvironment, TorchElasticEnvironment, ) @@ -449,6 +450,7 @@ class AcceleratorConnector: TorchElasticEnvironment, KubeflowEnvironment, LSFEnvironment, + MPIEnvironment, ): if env_type.detect(): return env_type() @@ -499,6 +501,7 @@ class AcceleratorConnector: or KubeflowEnvironment.detect() or SLURMEnvironment.detect() or LSFEnvironment.detect() + or MPIEnvironment.detect() ): strategy_flag = "ddp" if strategy_flag == "dp" and self._accelerator_flag == "cpu": diff --git a/tests/tests_fabric/plugins/environments/test_mpi.py b/tests/tests_fabric/plugins/environments/test_mpi.py new file mode 100644 index 0000000000..6231324e6e --- /dev/null +++ b/tests/tests_fabric/plugins/environments/test_mpi.py @@ -0,0 +1,131 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +import os +from unittest import mock +from unittest.mock import MagicMock + +import numpy as np +import pytest + +import lightning.fabric.plugins.environments.mpi +from lightning.fabric.plugins.environments import MPIEnvironment + + +def test_dependencies(monkeypatch): + """Test that the MPI environment requires the `mpi4py` package.""" + monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", False) + with pytest.raises(ModuleNotFoundError): + MPIEnvironment() + + # pretend mpi4py is available + monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", True) + with mock.patch.dict("sys.modules", {"mpi4py": MagicMock()}): + MPIEnvironment() + + +def test_detect(monkeypatch): + """Test the detection of an MPI environment configuration.""" + monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", False) + assert not MPIEnvironment.detect() + + # pretend mpi4py is available + monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", True) + mpi4py_mock = MagicMock() + + with mock.patch.dict("sys.modules", {"mpi4py": mpi4py_mock}): + mpi4py_mock.MPI.COMM_WORLD.Get_size.return_value = 0 + assert not MPIEnvironment.detect() + + mpi4py_mock.MPI.COMM_WORLD.Get_size.return_value = 1 + assert not MPIEnvironment.detect() + + mpi4py_mock.MPI.COMM_WORLD.Get_size.return_value = 2 + assert MPIEnvironment.detect() + + +@mock.patch.dict(os.environ, {}, clear=True) +def test_default_attributes(monkeypatch): + """Test the default attributes when no environment variables are set.""" + + # pretend mpi4py is available + monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", True) + mpi4py_mock = MagicMock() + with mock.patch.dict("sys.modules", {"mpi4py": mpi4py_mock}): + env = MPIEnvironment() + + assert env._node_rank is None + assert env._main_address is None + assert env._main_port is None + assert env.creates_processes_externally 
+ + +def test_init_local_comm(monkeypatch): + """Test that it can determine the node rank and local rank based on the hostnames of all participating + nodes.""" + # pretend mpi4py is available + monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", True) + mpi4py_mock = MagicMock() + hostname_mock = MagicMock() + + mpi4py_mock.MPI.COMM_WORLD.Get_size.return_value = 4 + with mock.patch.dict("sys.modules", {"mpi4py": mpi4py_mock}), mock.patch("socket.gethostname", hostname_mock): + env = MPIEnvironment() + + hostname_mock.return_value = "host1" + env._comm_world.bcast.return_value = np.array(["host1", "host2"]) + assert env.node_rank() == 0 + + env._node_rank = None + hostname_mock.return_value = "host2" + env._comm_world.bcast.return_value = np.array(["host1", "host2"]) + assert env.node_rank() == 1 + + assert env._comm_local is not None + env._comm_local.Get_rank.return_value = 33 + assert env.local_rank() == 33 + + +def test_world_comm(monkeypatch): + # pretend mpi4py is available + monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", True) + mpi4py_mock = MagicMock() + + with mock.patch.dict("sys.modules", {"mpi4py": mpi4py_mock}): + env = MPIEnvironment() + + env._comm_world.Get_size.return_value = 8 + assert env.world_size() == 8 + env._comm_world.Get_rank.return_value = 3 + assert env.global_rank() == 3 + + +def test_setters(monkeypatch, caplog): + # pretend mpi4py is available + monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", True) + mpi4py_mock = MagicMock() + + with mock.patch.dict("sys.modules", {"mpi4py": mpi4py_mock}): + env = MPIEnvironment() + + # setter should be no-op + with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"): + env.set_global_rank(100) + assert "setting global rank is not allowed" in caplog.text + + caplog.clear() + + with caplog.at_level(logging.DEBUG, logger="lightning.fabric.plugins.environments"): + 
env.set_world_size(100) + assert "setting world size is not allowed" in caplog.text