Add kubeflow cluster environment (#7300)

* Add kubeflow cluster environment
* Add KubeflowEnvironment to docs
* Add KubeflowEnvironment to the changelog
* break up a long line
* Add method to detect kubeflow environment
* Select Kubeflow environment when available
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Run pre-commit
* task_idx == 0

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Parent: 6e6e29af49
Commit: f4f51e0dcf
CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added `KubeflowEnvironment` for use with the `PyTorchJob` operator in Kubeflow

 - Added LightningCLI support for config files on object stores ([#7521](https://github.com/PyTorchLightning/pytorch-lightning/pull/7521))
docs (Cluster Environments)
@@ -125,6 +125,7 @@ Cluster Environments

     ClusterEnvironment
     LightningEnvironment
     TorchElasticEnvironment
+    KubeflowEnvironment
     SLURMEnvironment
@@ -151,4 +151,5 @@ Cluster Environments

     ClusterEnvironment
     LightningEnvironment
     TorchElasticEnvironment
+    KubeflowEnvironment
     SLURMEnvironment
pytorch_lightning/plugins/environments/__init__.py
@@ -12,6 +12,7 @@

 # See the License for the specific language governing permissions and
 # limitations under the License.
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment  # noqa: F401
+from pytorch_lightning.plugins.environments.kubeflow_environment import KubeflowEnvironment  # noqa: F401
 from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment  # noqa: F401
 from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment  # noqa: F401
 from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment  # noqa: F401
pytorch_lightning/plugins/environments/kubeflow_environment.py (new file)
@@ -0,0 +1,63 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+
+from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
+
+log = logging.getLogger(__name__)
+
+
+class KubeflowEnvironment(ClusterEnvironment):
+    """
+    Environment for distributed training using the
+    `PyTorchJob <https://www.kubeflow.org/docs/components/training/pytorch/>`_
+    operator from `Kubeflow <https://www.kubeflow.org>`_
+    """
+
+    @staticmethod
+    def is_using_kubeflow() -> bool:
+        """ Returns ``True`` if the current process was launched using Kubeflow PyTorchJob. """
+        required_env_vars = ("KUBERNETES_PORT", "MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK")
+        # torchelastic sets these. Make sure we're not in torchelastic
+        excluded_env_vars = ("GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE")
+        return (all(v in os.environ for v in required_env_vars) and not any(v in os.environ for v in excluded_env_vars))
+
+    def creates_children(self) -> bool:
+        return True
+
+    def master_address(self) -> str:
+        return os.environ['MASTER_ADDR']
+
+    def master_port(self) -> int:
+        return int(os.environ['MASTER_PORT'])
+
+    def world_size(self) -> int:
+        return int(os.environ['WORLD_SIZE'])
+
+    def set_world_size(self, size: int) -> None:
+        log.debug("KubeflowEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
+
+    def global_rank(self) -> int:
+        return int(os.environ["RANK"])
+
+    def set_global_rank(self, rank: int) -> None:
+        log.debug("KubeflowEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
+
+    def local_rank(self) -> int:
+        return 0
+
+    def node_rank(self) -> int:
+        return self.global_rank()
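Since the new environment is both auto-detected and user-selectable, here is a minimal usage sketch (illustrative only, not part of the commit; it assumes the Trainer's `plugins` argument accepts a cluster environment instance, as it does in this release line):

    # Explicitly selecting the Kubeflow environment instead of relying on
    # auto-detection; the accelerator/gpus values are placeholders.
    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins.environments import KubeflowEnvironment

    trainer = Trainer(
        accelerator="ddp",
        gpus=1,
        plugins=[KubeflowEnvironment()],
    )

In the common case nothing needs to be passed: each PyTorchJob replica runs exactly one training process, which is why the environment reports `creates_children() == True` (processes are created externally), a constant `local_rank()` of 0, and a `node_rank()` equal to the global rank.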
pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -46,6 +46,7 @@ from pytorch_lightning.plugins import (
 )
 from pytorch_lightning.plugins.environments import (
     ClusterEnvironment,
+    KubeflowEnvironment,
     LightningEnvironment,
     SLURMEnvironment,
     TorchElasticEnvironment,
@@ -397,10 +398,12 @@ class AcceleratorConnector(object):
         elif self.use_ddp:
             use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
             use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()
+            use_kubeflow_ddp = self.use_ddp and KubeflowEnvironment.is_using_kubeflow()
             use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
             use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
             use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN
             use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()
+            use_ddp_cpu_kubeflow = use_ddp_cpu_spawn and KubeflowEnvironment.is_using_kubeflow()
             use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
             use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED
             use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN
@@ -416,7 +419,10 @@ class AcceleratorConnector(object):
                 ddp_plugin_cls = DDPShardedPlugin
             elif use_ddp_sharded_spawn:
                 ddp_plugin_cls = DDPSpawnShardedPlugin
-            elif use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp:
+            elif (
+                use_ddp_cpu_slurm or use_slurm_ddp or use_ddp_cpu_torch_elastic or use_torchelastic_ddp
+                or use_kubeflow_ddp or use_ddp_cpu_kubeflow
+            ):
                 ddp_plugin_cls = DDPPlugin
             elif use_ddp_spawn or use_ddp_cpu_spawn:
                 ddp_plugin_cls = DDPSpawnPlugin
@@ -488,6 +494,8 @@ class AcceleratorConnector(object):
             env = SLURMEnvironment()
         elif TorchElasticEnvironment.is_using_torchelastic():
             env = TorchElasticEnvironment()
+        elif KubeflowEnvironment.is_using_kubeflow():
+            env = KubeflowEnvironment()
         else:
             env = LightningEnvironment()
         return env
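With this change the automatic resolution order becomes: SLURM, then torchelastic, then Kubeflow, then the `LightningEnvironment` fallback. A small standalone sketch of what the detection sees on a PyTorchJob worker (all variable values below are made up for illustration):

    import os
    from unittest import mock

    from pytorch_lightning.plugins.environments import KubeflowEnvironment

    # Variables of the kind a PyTorchJob worker pod would see (values assumed).
    pytorchjob_env = {
        "KUBERNETES_PORT": "tcp://10.0.0.1:443",
        "MASTER_ADDR": "pytorchjob-master-0",
        "MASTER_PORT": "23456",
        "WORLD_SIZE": "4",
        "RANK": "2",
    }

    with mock.patch.dict(os.environ, pytorchjob_env, clear=True):
        assert KubeflowEnvironment.is_using_kubeflow()

    # torchelastic also sets MASTER_ADDR/MASTER_PORT/WORLD_SIZE/RANK, so the
    # presence of GROUP_RANK (or LOCAL_RANK / LOCAL_WORLD_SIZE) vetoes Kubeflow.
    with mock.patch.dict(os.environ, {**pytorchjob_env, "GROUP_RANK": "0"}, clear=True):
        assert not KubeflowEnvironment.is_using_kubeflow()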
tests/accelerators/test_accelerator_connector.py
@@ -35,7 +35,12 @@ from pytorch_lightning.plugins import (
     PrecisionPlugin,
     SingleDevicePlugin,
 )
-from pytorch_lightning.plugins.environments import LightningEnvironment, SLURMEnvironment, TorchElasticEnvironment
+from pytorch_lightning.plugins.environments import (
+    KubeflowEnvironment,
+    LightningEnvironment,
+    SLURMEnvironment,
+    TorchElasticEnvironment,
+)
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
@@ -272,6 +277,80 @@ def test_accelerator_choice_ddp_cpu_te(device_count_mock, setup_distributed_mock
         trainer.fit(model)


+@RunIf(min_gpus=1)
+@mock.patch.dict(
+    os.environ, {
+        "CUDA_VISIBLE_DEVICES": "0",
+        "KUBERNETES_PORT": "tcp://127.0.0.1:443",
+        "MASTER_ADDR": "1.2.3.4",
+        "MASTER_PORT": "500",
+        "WORLD_SIZE": "20",
+        "RANK": "1",
+    }
+)
+@mock.patch('torch.cuda.device_count', return_value=1)
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp_kubeflow(device_count_mock, setup_distributed_mock):
+
+    class CB(Callback):
+
+        def on_fit_start(self, trainer, pl_module):
+            assert trainer.use_ddp
+            assert isinstance(trainer.accelerator, GPUAccelerator)
+            assert isinstance(trainer.training_type_plugin, DDPPlugin)
+            assert isinstance(trainer.training_type_plugin.cluster_environment, KubeflowEnvironment)
+            assert trainer.training_type_plugin.cluster_environment.local_rank() == 0
+            assert trainer.training_type_plugin.task_idx == 0
+            raise SystemExit()
+
+    model = BoringModel()
+    trainer = Trainer(
+        fast_dev_run=True,
+        accelerator='ddp',
+        gpus=1,
+        callbacks=[CB()],
+    )
+
+    with pytest.raises(SystemExit):
+        trainer.fit(model)
+
+
+@mock.patch.dict(
+    os.environ, {
+        "KUBERNETES_PORT": "tcp://127.0.0.1:443",
+        "MASTER_ADDR": "1.2.3.4",
+        "MASTER_PORT": "500",
+        "WORLD_SIZE": "20",
+        "RANK": "1",
+    }
+)
+@mock.patch('torch.cuda.device_count', return_value=0)
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp_cpu_kubeflow(device_count_mock, setup_distributed_mock):
+
+    class CB(Callback):
+
+        def on_fit_start(self, trainer, pl_module):
+            assert trainer.use_ddp
+            assert isinstance(trainer.accelerator, CPUAccelerator)
+            assert isinstance(trainer.training_type_plugin, DDPPlugin)
+            assert isinstance(trainer.training_type_plugin.cluster_environment, KubeflowEnvironment)
+            assert trainer.training_type_plugin.cluster_environment.local_rank() == 0
+            assert trainer.training_type_plugin.task_idx == 0
+            raise SystemExit()
+
+    model = BoringModel()
+    trainer = Trainer(
+        fast_dev_run=True,
+        accelerator='ddp_cpu',
+        num_processes=1,
+        callbacks=[CB()],
+    )
+
+    with pytest.raises(SystemExit):
+        trainer.fit(model)
+
+
 @mock.patch.dict(
     os.environ, {
         "SLURM_NTASKS": "2",
tests/plugins/environments/test_kubeflow_environment.py (new file)
@@ -0,0 +1,87 @@
+import logging
+import os
+from unittest import mock
+
+import pytest
+
+from pytorch_lightning.plugins.environments import KubeflowEnvironment
+
+
+@mock.patch.dict(os.environ, {})
+def test_default_attributes():
+    """ Test the default attributes when no environment variables are set. """
+    env = KubeflowEnvironment()
+    assert env.creates_children()
+
+    with pytest.raises(KeyError):
+        # MASTER_ADDR is required
+        env.master_address()
+    with pytest.raises(KeyError):
+        # MASTER_PORT is required
+        env.master_port()
+    with pytest.raises(KeyError):
+        # WORLD_SIZE is required
+        env.world_size()
+    with pytest.raises(KeyError):
+        # RANK is required
+        env.global_rank()
+    assert env.local_rank() == 0
+
+
+@mock.patch.dict(
+    os.environ, {
+        "KUBERNETES_PORT": "tcp://127.0.0.1:443",
+        "MASTER_ADDR": "1.2.3.4",
+        "MASTER_PORT": "500",
+        "WORLD_SIZE": "20",
+        "RANK": "1",
+    }
+)
+def test_attributes_from_environment_variables(caplog):
+    """ Test that the Kubeflow cluster environment takes the attributes from the environment variables. """
+    env = KubeflowEnvironment()
+    assert env.master_address() == "1.2.3.4"
+    assert env.master_port() == 500
+    assert env.world_size() == 20
+    assert env.global_rank() == 1
+    assert env.local_rank() == 0
+    assert env.node_rank() == 1
+    # setter should be no-op
+    with caplog.at_level(logging.DEBUG, logger="pytorch_lightning.plugins.environments"):
+        env.set_global_rank(100)
+    assert env.global_rank() == 1
+    assert "setting global rank is not allowed" in caplog.text
+
+    caplog.clear()
+
+    with caplog.at_level(logging.DEBUG, logger="pytorch_lightning.plugins.environments"):
+        env.set_world_size(100)
+    assert env.world_size() == 20
+    assert "setting world size is not allowed" in caplog.text
+
+
+@mock.patch.dict(
+    os.environ, {
+        "KUBERNETES_PORT": "tcp://127.0.0.1:443",
+        "MASTER_ADDR": "1.2.3.4",
+        "MASTER_PORT": "500",
+        "WORLD_SIZE": "20",
+        "RANK": "1",
+    }
+)
+def test_is_using_kubeflow():
+    assert KubeflowEnvironment.is_using_kubeflow()
+
+
+@mock.patch.dict(
+    os.environ, {
+        "KUBERNETES_PORT": "tcp://127.0.0.1:443",
+        "MASTER_ADDR": "1.2.3.4",
+        "MASTER_PORT": "500",
+        "WORLD_SIZE": "20",
+        "RANK": "1",
+        "GROUP_RANK": "1",
+    }
+)
+def test_is_using_kubeflow_torchelastic():
+    assert not KubeflowEnvironment.is_using_kubeflow()