Add kubeflow cluster environment (#7300)
* Add kubeflow cluster environment * Add KubeflowEnvironment to docs * Add KubeflowEnvironment to the changelog * break up a long line * Add method to detect kubeflow environment * Select Kubeflow environment when available * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Run pre-commit * task_idx == 0 Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
6e6e29af49
commit
f4f51e0dcf
|
@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Added
|
||||
|
||||
- Added `KubeflowEnvironment` for use with the `PyTorchJob` operator in Kubeflow
|
||||
|
||||
- Added LightningCLI support for config files on object stores ([#7521](https://github.com/PyTorchLightning/pytorch-lightning/pull/7521))
|
||||
|
||||
|
|
|
@ -125,6 +125,7 @@ Cluster Environments
|
|||
ClusterEnvironment
|
||||
LightningEnvironment
|
||||
TorchElasticEnvironment
|
||||
KubeflowEnvironment
|
||||
SLURMEnvironment
|
||||
|
||||
|
||||
|
|
|
@ -151,4 +151,5 @@ Cluster Environments
|
|||
ClusterEnvironment
|
||||
LightningEnvironment
|
||||
TorchElasticEnvironment
|
||||
KubeflowEnvironment
|
||||
SLURMEnvironment
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment # noqa: F401
|
||||
from pytorch_lightning.plugins.environments.kubeflow_environment import KubeflowEnvironment # noqa: F401
|
||||
from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment # noqa: F401
|
||||
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment # noqa: F401
|
||||
from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment # noqa: F401
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KubeflowEnvironment(ClusterEnvironment):
|
||||
"""
|
||||
Environment for distributed training using the
|
||||
`PyTorchJob <https://www.kubeflow.org/docs/components/training/pytorch/>`_
|
||||
operator from `Kubeflow <https://www.kubeflow.org>`_
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def is_using_kubeflow() -> bool:
|
||||
""" Returns ``True`` if the current process was launched using Kubeflow PyTorchJob. """
|
||||
required_env_vars = ("KUBERNETES_PORT", "MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK")
|
||||
# torchelastic sets these. Make sure we're not in torchelastic
|
||||
excluded_env_vars = ("GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE")
|
||||
return (all(v in os.environ for v in required_env_vars) and not any(v in os.environ for v in excluded_env_vars))
|
||||
|
||||
def creates_children(self) -> bool:
|
||||
return True
|
||||
|
||||
def master_address(self) -> str:
|
||||
return os.environ['MASTER_ADDR']
|
||||
|
||||
def master_port(self) -> int:
|
||||
return int(os.environ['MASTER_PORT'])
|
||||
|
||||
def world_size(self) -> int:
|
||||
return int(os.environ['WORLD_SIZE'])
|
||||
|
||||
def set_world_size(self, size: int) -> None:
|
||||
log.debug("KubeflowEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
|
||||
|
||||
def global_rank(self) -> int:
|
||||
return int(os.environ["RANK"])
|
||||
|
||||
def set_global_rank(self, rank: int) -> None:
|
||||
log.debug("KubeflowEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
|
||||
|
||||
def local_rank(self) -> int:
|
||||
return 0
|
||||
|
||||
def node_rank(self) -> int:
|
||||
return self.global_rank()
|
|
@ -46,6 +46,7 @@ from pytorch_lightning.plugins import (
|
|||
)
|
||||
from pytorch_lightning.plugins.environments import (
|
||||
ClusterEnvironment,
|
||||
KubeflowEnvironment,
|
||||
LightningEnvironment,
|
||||
SLURMEnvironment,
|
||||
TorchElasticEnvironment,
|
||||
|
@ -397,10 +398,12 @@ class AcceleratorConnector(object):
|
|||
elif self.use_ddp:
|
||||
use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
|
||||
use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()
|
||||
use_kubeflow_ddp = self.use_ddp and KubeflowEnvironment.is_using_kubeflow()
|
||||
use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
|
||||
use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
|
||||
use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN
|
||||
use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()
|
||||
use_ddp_cpu_kubeflow = use_ddp_cpu_spawn and KubeflowEnvironment.is_using_kubeflow()
|
||||
use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
|
||||
use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED
|
||||
use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN
|
||||
|
@ -416,7 +419,10 @@ class AcceleratorConnector(object):
|
|||
ddp_plugin_cls = DDPShardedPlugin
|
||||
elif use_ddp_sharded_spawn:
|
||||
ddp_plugin_cls = DDPSpawnShardedPlugin
|
||||
elif use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp:
|
||||
elif (
|
||||
use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp
|
||||
or use_kubeflow_ddp or use_ddp_cpu_kubeflow
|
||||
):
|
||||
ddp_plugin_cls = DDPPlugin
|
||||
elif use_ddp_spawn or use_ddp_cpu_spawn:
|
||||
ddp_plugin_cls = DDPSpawnPlugin
|
||||
|
@ -488,6 +494,8 @@ class AcceleratorConnector(object):
|
|||
env = SLURMEnvironment()
|
||||
elif TorchElasticEnvironment.is_using_torchelastic():
|
||||
env = TorchElasticEnvironment()
|
||||
elif KubeflowEnvironment.is_using_kubeflow():
|
||||
env = KubeflowEnvironment()
|
||||
else:
|
||||
env = LightningEnvironment()
|
||||
return env
|
||||
|
|
|
@ -35,7 +35,12 @@ from pytorch_lightning.plugins import (
|
|||
PrecisionPlugin,
|
||||
SingleDevicePlugin,
|
||||
)
|
||||
from pytorch_lightning.plugins.environments import LightningEnvironment, SLURMEnvironment, TorchElasticEnvironment
|
||||
from pytorch_lightning.plugins.environments import (
|
||||
KubeflowEnvironment,
|
||||
LightningEnvironment,
|
||||
SLURMEnvironment,
|
||||
TorchElasticEnvironment,
|
||||
)
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.helpers.boring_model import BoringModel
|
||||
from tests.helpers.runif import RunIf
|
||||
|
@ -272,6 +277,80 @@ def test_accelerator_choice_ddp_cpu_te(device_count_mock, setup_distributed_mock
|
|||
trainer.fit(model)
|
||||
|
||||
|
||||
@RunIf(min_gpus=1)
|
||||
@mock.patch.dict(
|
||||
os.environ, {
|
||||
"CUDA_VISIBLE_DEVICES": "0",
|
||||
"KUBERNETES_PORT": "tcp://127.0.0.1:443",
|
||||
"MASTER_ADDR": "1.2.3.4",
|
||||
"MASTER_PORT": "500",
|
||||
"WORLD_SIZE": "20",
|
||||
"RANK": "1",
|
||||
}
|
||||
)
|
||||
@mock.patch('torch.cuda.device_count', return_value=1)
|
||||
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
|
||||
def test_accelerator_choice_ddp_kubeflow(device_count_mock, setup_distributed_mock):
|
||||
|
||||
class CB(Callback):
|
||||
|
||||
def on_fit_start(self, trainer, pl_module):
|
||||
assert trainer.use_ddp
|
||||
assert isinstance(trainer.accelerator, GPUAccelerator)
|
||||
assert isinstance(trainer.training_type_plugin, DDPPlugin)
|
||||
assert isinstance(trainer.training_type_plugin.cluster_environment, KubeflowEnvironment)
|
||||
assert trainer.training_type_plugin.cluster_environment.local_rank() == 0
|
||||
assert trainer.training_type_plugin.task_idx == 0
|
||||
raise SystemExit()
|
||||
|
||||
model = BoringModel()
|
||||
trainer = Trainer(
|
||||
fast_dev_run=True,
|
||||
accelerator='ddp',
|
||||
gpus=1,
|
||||
callbacks=[CB()],
|
||||
)
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ, {
|
||||
"KUBERNETES_PORT": "tcp://127.0.0.1:443",
|
||||
"MASTER_ADDR": "1.2.3.4",
|
||||
"MASTER_PORT": "500",
|
||||
"WORLD_SIZE": "20",
|
||||
"RANK": "1",
|
||||
}
|
||||
)
|
||||
@mock.patch('torch.cuda.device_count', return_value=0)
|
||||
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
|
||||
def test_accelerator_choice_ddp_cpu_kubeflow(device_count_mock, setup_distributed_mock):
|
||||
|
||||
class CB(Callback):
|
||||
|
||||
def on_fit_start(self, trainer, pl_module):
|
||||
assert trainer.use_ddp
|
||||
assert isinstance(trainer.accelerator, CPUAccelerator)
|
||||
assert isinstance(trainer.training_type_plugin, DDPPlugin)
|
||||
assert isinstance(trainer.training_type_plugin.cluster_environment, KubeflowEnvironment)
|
||||
assert trainer.training_type_plugin.cluster_environment.local_rank() == 0
|
||||
assert trainer.training_type_plugin.task_idx == 0
|
||||
raise SystemExit()
|
||||
|
||||
model = BoringModel()
|
||||
trainer = Trainer(
|
||||
fast_dev_run=True,
|
||||
accelerator='ddp_cpu',
|
||||
num_processes=1,
|
||||
callbacks=[CB()],
|
||||
)
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ, {
|
||||
"SLURM_NTASKS": "2",
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
import logging
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_lightning.plugins.environments import KubeflowEnvironment
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {})
|
||||
def test_default_attributes():
|
||||
""" Test the default attributes when no environment variables are set. """
|
||||
env = KubeflowEnvironment()
|
||||
assert env.creates_children()
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
# MASTER_ADDR is required
|
||||
env.master_address()
|
||||
with pytest.raises(KeyError):
|
||||
# MASTER_PORT is required
|
||||
env.master_port()
|
||||
with pytest.raises(KeyError):
|
||||
# WORLD_SIZE is required
|
||||
env.world_size()
|
||||
with pytest.raises(KeyError):
|
||||
# RANK is required
|
||||
env.global_rank()
|
||||
assert env.local_rank() == 0
|
||||
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ, {
|
||||
"KUBERNETES_PORT": "tcp://127.0.0.1:443",
|
||||
"MASTER_ADDR": "1.2.3.4",
|
||||
"MASTER_PORT": "500",
|
||||
"WORLD_SIZE": "20",
|
||||
"RANK": "1",
|
||||
}
|
||||
)
|
||||
def test_attributes_from_environment_variables(caplog):
|
||||
""" Test that the torchelastic cluster environment takes the attributes from the environment variables. """
|
||||
env = KubeflowEnvironment()
|
||||
assert env.master_address() == "1.2.3.4"
|
||||
assert env.master_port() == 500
|
||||
assert env.world_size() == 20
|
||||
assert env.global_rank() == 1
|
||||
assert env.local_rank() == 0
|
||||
assert env.node_rank() == 1
|
||||
# setter should be no-op
|
||||
with caplog.at_level(logging.DEBUG, logger="pytorch_lightning.plugins.environments"):
|
||||
env.set_global_rank(100)
|
||||
assert env.global_rank() == 1
|
||||
assert "setting global rank is not allowed" in caplog.text
|
||||
|
||||
caplog.clear()
|
||||
|
||||
with caplog.at_level(logging.DEBUG, logger="pytorch_lightning.plugins.environments"):
|
||||
env.set_world_size(100)
|
||||
assert env.world_size() == 20
|
||||
assert "setting world size is not allowed" in caplog.text
|
||||
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ, {
|
||||
"KUBERNETES_PORT": "tcp://127.0.0.1:443",
|
||||
"MASTER_ADDR": "1.2.3.4",
|
||||
"MASTER_PORT": "500",
|
||||
"WORLD_SIZE": "20",
|
||||
"RANK": "1",
|
||||
}
|
||||
)
|
||||
def test_is_using_kubeflow():
|
||||
assert KubeflowEnvironment.is_using_kubeflow()
|
||||
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ, {
|
||||
"KUBERNETES_PORT": "tcp://127.0.0.1:443",
|
||||
"MASTER_ADDR": "1.2.3.4",
|
||||
"MASTER_PORT": "500",
|
||||
"WORLD_SIZE": "20",
|
||||
"RANK": "1",
|
||||
"GROUP_RANK": "1",
|
||||
}
|
||||
)
|
||||
def test_is_using_kubeflow_torchelastic():
|
||||
assert not KubeflowEnvironment.is_using_kubeflow()
|
Loading…
Reference in New Issue