Add Kubeflow cluster environment (#7300)

* Add Kubeflow cluster environment

* Add KubeflowEnvironment to docs

* Add KubeflowEnvironment to the changelog

* break up a long line

* Add method to detect the Kubeflow environment

* Select Kubeflow environment when available

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Run pre-commit

* task_idx == 0

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Nic Eggert authored 2021-05-17 03:05:24 -05:00, committed by GitHub
parent 6e6e29af49
commit f4f51e0dcf
8 changed files with 243 additions and 2 deletions
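
Reviewer note: a minimal usage sketch, not part of this diff. It assumes the 1.3-era API in which a ClusterEnvironment instance may be passed through the Trainer's `plugins` argument; inside a PyTorchJob worker pod the same environment is also picked up automatically via the detection added below.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import KubeflowEnvironment

# Explicit selection (optional): auto-detection via is_using_kubeflow() would
# normally pick KubeflowEnvironment on its own inside a PyTorchJob worker pod.
trainer = Trainer(accelerator="ddp", gpus=1, plugins=[KubeflowEnvironment()])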


@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added `KubeflowEnvironment` for use with the `PyTorchJob` operator in Kubeflow ([#7300](https://github.com/PyTorchLightning/pytorch-lightning/pull/7300))
- Added LightningCLI support for config files on object stores ([#7521](https://github.com/PyTorchLightning/pytorch-lightning/pull/7521))


@@ -125,6 +125,7 @@ Cluster Environments
ClusterEnvironment
LightningEnvironment
TorchElasticEnvironment
KubeflowEnvironment
SLURMEnvironment


@@ -151,4 +151,5 @@ Cluster Environments
ClusterEnvironment
LightningEnvironment
TorchElasticEnvironment
KubeflowEnvironment
SLURMEnvironment


@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.kubeflow_environment import KubeflowEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment # noqa: F401
from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment # noqa: F401


@@ -0,0 +1,63 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment

log = logging.getLogger(__name__)


class KubeflowEnvironment(ClusterEnvironment):
    """
    Environment for distributed training using the
    `PyTorchJob <https://www.kubeflow.org/docs/components/training/pytorch/>`_
    operator from `Kubeflow <https://www.kubeflow.org>`_
    """

    @staticmethod
    def is_using_kubeflow() -> bool:
        """ Returns ``True`` if the current process was launched using Kubeflow PyTorchJob. """
        required_env_vars = ("KUBERNETES_PORT", "MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK")
        # torchelastic sets these. Make sure we're not in torchelastic
        excluded_env_vars = ("GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE")
        return (all(v in os.environ for v in required_env_vars) and not any(v in os.environ for v in excluded_env_vars))

    def creates_children(self) -> bool:
        return True

    def master_address(self) -> str:
        return os.environ['MASTER_ADDR']

    def master_port(self) -> int:
        return int(os.environ['MASTER_PORT'])

    def world_size(self) -> int:
        return int(os.environ['WORLD_SIZE'])

    def set_world_size(self, size: int) -> None:
        log.debug("KubeflowEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

    def global_rank(self) -> int:
        return int(os.environ["RANK"])

    def set_global_rank(self, rank: int) -> None:
        log.debug("KubeflowEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

    def local_rank(self) -> int:
        # PyTorchJob launches one training process per pod, so the local rank is always 0
        return 0

    def node_rank(self) -> int:
        return self.global_rank()
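
As a quick illustration of the class above (a sketch, not part of the diff; the variable values are made up but resemble what the PyTorchJob operator injects into a worker pod):

import os

from pytorch_lightning.plugins.environments import KubeflowEnvironment

# Hypothetical PyTorchJob-style environment variables
os.environ.update({
    "KUBERNETES_PORT": "tcp://10.0.0.1:443",
    "MASTER_ADDR": "pytorchjob-example-master-0",
    "MASTER_PORT": "23456",
    "WORLD_SIZE": "4",
    "RANK": "2",
})

assert KubeflowEnvironment.is_using_kubeflow()  # no torchelastic variables set

env = KubeflowEnvironment()
assert env.master_address() == "pytorchjob-example-master-0"
assert env.master_port() == 23456
assert env.world_size() == 4
assert env.global_rank() == 2 and env.node_rank() == 2
assert env.local_rank() == 0  # always 0: one process per pod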


@@ -46,6 +46,7 @@ from pytorch_lightning.plugins import (
)
from pytorch_lightning.plugins.environments import (
    ClusterEnvironment,
    KubeflowEnvironment,
    LightningEnvironment,
    SLURMEnvironment,
    TorchElasticEnvironment,
@@ -397,10 +398,12 @@ class AcceleratorConnector(object):
        elif self.use_ddp:
            use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
            use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()
            use_kubeflow_ddp = self.use_ddp and KubeflowEnvironment.is_using_kubeflow()
            use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
            use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
            use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN
            use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()
            use_ddp_cpu_kubeflow = use_ddp_cpu_spawn and KubeflowEnvironment.is_using_kubeflow()
            use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
            use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED
            use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN
@@ -416,7 +419,10 @@ class AcceleratorConnector(object):
                ddp_plugin_cls = DDPShardedPlugin
            elif use_ddp_sharded_spawn:
                ddp_plugin_cls = DDPSpawnShardedPlugin
-            elif use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp:
+            elif (
+                use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp
+                or use_kubeflow_ddp or use_ddp_cpu_kubeflow
+            ):
                ddp_plugin_cls = DDPPlugin
            elif use_ddp_spawn or use_ddp_cpu_spawn:
                ddp_plugin_cls = DDPSpawnPlugin
@@ -488,6 +494,8 @@ class AcceleratorConnector(object):
            env = SLURMEnvironment()
        elif TorchElasticEnvironment.is_using_torchelastic():
            env = TorchElasticEnvironment()
        elif KubeflowEnvironment.is_using_kubeflow():
            env = KubeflowEnvironment()
        else:
            env = LightningEnvironment()
        return env
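
End to end, the hunk above falls through SLURM and torchelastic before trying Kubeflow, defaulting to LightningEnvironment. A quick check (a sketch, not part of the diff; the attribute path mirrors the assertions in the tests below):

from pytorch_lightning import Trainer

# With only the PyTorchJob variables set (no SLURM or torchelastic ones),
# select_cluster_environment() should return a KubeflowEnvironment.
trainer = Trainer(accelerator="ddp_cpu", num_processes=1, fast_dev_run=True)
print(type(trainer.training_type_plugin.cluster_environment).__name__)
# expected output: KubeflowEnvironment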


@@ -35,7 +35,12 @@ from pytorch_lightning.plugins import (
    PrecisionPlugin,
    SingleDevicePlugin,
)
-from pytorch_lightning.plugins.environments import LightningEnvironment, SLURMEnvironment, TorchElasticEnvironment
+from pytorch_lightning.plugins.environments import (
+    KubeflowEnvironment,
+    LightningEnvironment,
+    SLURMEnvironment,
+    TorchElasticEnvironment,
+)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
@@ -272,6 +277,80 @@ def test_accelerator_choice_ddp_cpu_te(device_count_mock, setup_distributed_mock
    trainer.fit(model)


@RunIf(min_gpus=1)
@mock.patch.dict(
    os.environ, {
        "CUDA_VISIBLE_DEVICES": "0",
        "KUBERNETES_PORT": "tcp://127.0.0.1:443",
        "MASTER_ADDR": "1.2.3.4",
        "MASTER_PORT": "500",
        "WORLD_SIZE": "20",
        "RANK": "1",
    }
)
@mock.patch('torch.cuda.device_count', return_value=1)
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp_kubeflow(device_count_mock, setup_distributed_mock):

    class CB(Callback):

        def on_fit_start(self, trainer, pl_module):
            assert trainer.use_ddp
            assert isinstance(trainer.accelerator, GPUAccelerator)
            assert isinstance(trainer.training_type_plugin, DDPPlugin)
            assert isinstance(trainer.training_type_plugin.cluster_environment, KubeflowEnvironment)
            assert trainer.training_type_plugin.cluster_environment.local_rank() == 0
            assert trainer.training_type_plugin.task_idx == 0
            raise SystemExit()

    model = BoringModel()
    trainer = Trainer(
        fast_dev_run=True,
        accelerator='ddp',
        gpus=1,
        callbacks=[CB()],
    )

    with pytest.raises(SystemExit):
        trainer.fit(model)


@mock.patch.dict(
    os.environ, {
        "KUBERNETES_PORT": "tcp://127.0.0.1:443",
        "MASTER_ADDR": "1.2.3.4",
        "MASTER_PORT": "500",
        "WORLD_SIZE": "20",
        "RANK": "1",
    }
)
@mock.patch('torch.cuda.device_count', return_value=0)
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp_cpu_kubeflow(device_count_mock, setup_distributed_mock):

    class CB(Callback):

        def on_fit_start(self, trainer, pl_module):
            assert trainer.use_ddp
            assert isinstance(trainer.accelerator, CPUAccelerator)
            assert isinstance(trainer.training_type_plugin, DDPPlugin)
            assert isinstance(trainer.training_type_plugin.cluster_environment, KubeflowEnvironment)
            assert trainer.training_type_plugin.cluster_environment.local_rank() == 0
            assert trainer.training_type_plugin.task_idx == 0
            raise SystemExit()

    model = BoringModel()
    trainer = Trainer(
        fast_dev_run=True,
        accelerator='ddp_cpu',
        num_processes=1,
        callbacks=[CB()],
    )

    with pytest.raises(SystemExit):
        trainer.fit(model)


@mock.patch.dict(
    os.environ, {
        "SLURM_NTASKS": "2",


@@ -0,0 +1,87 @@
import logging
import os
from unittest import mock

import pytest

from pytorch_lightning.plugins.environments import KubeflowEnvironment


@mock.patch.dict(os.environ, {})
def test_default_attributes():
    """ Test the default attributes when no environment variables are set. """
    env = KubeflowEnvironment()
    assert env.creates_children()

    with pytest.raises(KeyError):
        # MASTER_ADDR is required
        env.master_address()
    with pytest.raises(KeyError):
        # MASTER_PORT is required
        env.master_port()
    with pytest.raises(KeyError):
        # WORLD_SIZE is required
        env.world_size()
    with pytest.raises(KeyError):
        # RANK is required
        env.global_rank()
    assert env.local_rank() == 0


@mock.patch.dict(
    os.environ, {
        "KUBERNETES_PORT": "tcp://127.0.0.1:443",
        "MASTER_ADDR": "1.2.3.4",
        "MASTER_PORT": "500",
        "WORLD_SIZE": "20",
        "RANK": "1",
    }
)
def test_attributes_from_environment_variables(caplog):
    """ Test that the Kubeflow cluster environment takes the attributes from the environment variables. """
    env = KubeflowEnvironment()
    assert env.master_address() == "1.2.3.4"
    assert env.master_port() == 500
    assert env.world_size() == 20
    assert env.global_rank() == 1
    assert env.local_rank() == 0
    assert env.node_rank() == 1

    # setters should be no-ops
    with caplog.at_level(logging.DEBUG, logger="pytorch_lightning.plugins.environments"):
        env.set_global_rank(100)
    assert env.global_rank() == 1
    assert "setting global rank is not allowed" in caplog.text

    caplog.clear()

    with caplog.at_level(logging.DEBUG, logger="pytorch_lightning.plugins.environments"):
        env.set_world_size(100)
    assert env.world_size() == 20
    assert "setting world size is not allowed" in caplog.text


@mock.patch.dict(
    os.environ, {
        "KUBERNETES_PORT": "tcp://127.0.0.1:443",
        "MASTER_ADDR": "1.2.3.4",
        "MASTER_PORT": "500",
        "WORLD_SIZE": "20",
        "RANK": "1",
    }
)
def test_is_using_kubeflow():
    assert KubeflowEnvironment.is_using_kubeflow()


@mock.patch.dict(
    os.environ, {
        "KUBERNETES_PORT": "tcp://127.0.0.1:443",
        "MASTER_ADDR": "1.2.3.4",
        "MASTER_PORT": "500",
        "WORLD_SIZE": "20",
        "RANK": "1",
        "GROUP_RANK": "1",
    }
)
def test_is_using_kubeflow_torchelastic():
    assert not KubeflowEnvironment.is_using_kubeflow()