Ignore `num_nodes` when running MultiNode components locally (#15806)
This commit is contained in:
parent
e150d083c2
commit
a970f090a0
|
@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- `lightning add ssh-key` CLI command has been transitioned to `lightning create ssh-key` with the same calling signature ([#15761](https://github.com/Lightning-AI/lightning/pull/15761))
|
||||
- `lightning remove ssh-key` CLI command has been transitioned to `lightning delete ssh-key` with the same calling signature ([#15761](https://github.com/Lightning-AI/lightning/pull/15761))
|
||||
- The `MultiNode` components now warn the user when running with `num_nodes > 1` locally ([#15806](https://github.com/Lightning-AI/lightning/pull/15806))
|
||||
|
||||
|
||||
### Deprecated
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import warnings
|
||||
from typing import Any, Type
|
||||
|
||||
from lightning_app import structures
|
||||
from lightning_app.core.flow import LightningFlow
|
||||
from lightning_app.core.work import LightningWork
|
||||
from lightning_app.utilities.cloud import is_running_in_cloud
|
||||
from lightning_app.utilities.packaging.cloud_compute import CloudCompute
|
||||
|
||||
|
||||
|
@ -45,12 +47,21 @@ class MultiNode(LightningFlow):
|
|||
|
||||
Arguments:
|
||||
work_cls: The work to be executed
|
||||
num_nodes: Number of nodes.
|
||||
cloud_compute: The cloud compute object used in the cloud.
|
||||
num_nodes: Number of nodes. Gets ignored when running locally. Launch the app with --cloud to run on
|
||||
multiple cloud machines.
|
||||
cloud_compute: The cloud compute object used in the cloud. The value provided here gets ignored when
|
||||
running locally.
|
||||
work_args: Arguments to be provided to the work on instantiation.
|
||||
work_kwargs: Keywords arguments to be provided to the work on instantiation.
|
||||
"""
|
||||
super().__init__()
|
||||
if num_nodes > 1 and not is_running_in_cloud():
|
||||
num_nodes = 1
|
||||
warnings.warn(
|
||||
f"You set {type(self).__name__}(num_nodes={num_nodes}, ...)` but this app is running locally."
|
||||
" We assume you are debugging and will ignore the `num_nodes` argument."
|
||||
" To run on multiple nodes in the cloud, launch your app with `--cloud`."
|
||||
)
|
||||
self.ws = structures.List(
|
||||
*[
|
||||
work_cls(
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
from re import escape
|
||||
|
||||
import pytest
|
||||
from tests_app.helpers.utils import no_warning_call
|
||||
|
||||
from lightning_app import CloudCompute, LightningWork
|
||||
from lightning_app.components import MultiNode
|
||||
|
||||
|
||||
def test_multi_node_warn_running_locally():
|
||||
class Work(LightningWork):
|
||||
def run(self):
|
||||
pass
|
||||
|
||||
with pytest.warns(UserWarning, match=escape("You set MultiNode(num_nodes=1, ...)` but ")):
|
||||
MultiNode(Work, num_nodes=2, cloud_compute=CloudCompute("gpu"))
|
||||
|
||||
with no_warning_call(UserWarning, match=escape("You set MultiNode(num_nodes=1, ...)` but ")):
|
||||
MultiNode(Work, num_nodes=1, cloud_compute=CloudCompute("gpu"))
|
|
@ -0,0 +1,30 @@
|
|||
import re
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Type
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@contextmanager
|
||||
def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Optional[str] = None):
|
||||
# TODO: Replace with `lightning_utilities.test.warning.no_warning_call`
|
||||
# https://github.com/Lightning-AI/utilities/issues/57
|
||||
|
||||
with pytest.warns(None) as record:
|
||||
yield
|
||||
|
||||
if match is None:
|
||||
try:
|
||||
w = record.pop(expected_warning)
|
||||
except AssertionError:
|
||||
# no warning raised
|
||||
return
|
||||
else:
|
||||
for w in record.list:
|
||||
if w.category is expected_warning and re.compile(match).search(w.message.args[0]):
|
||||
break
|
||||
else:
|
||||
return
|
||||
|
||||
msg = "A warning" if expected_warning is None else f"`{expected_warning.__name__}`"
|
||||
raise AssertionError(f"{msg} was raised: {w}")
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import sys
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from tests_examples_app.public import _PATH_EXAMPLES
|
||||
|
@ -17,7 +18,8 @@ class LightningTestMultiNodeApp(LightningTestApp):
|
|||
|
||||
|
||||
@pytest.mark.skip(reason="flaky")
|
||||
def test_multi_node_example(monkeypatch):
|
||||
@mock.patch("lightning_app.components.multi_node.base.is_running_in_cloud", return_value=True)
|
||||
def test_multi_node_example(_, monkeypatch):
|
||||
monkeypatch.chdir(os.path.join(_PATH_EXAMPLES, "app_multi_node"))
|
||||
command_line = [
|
||||
"app.py",
|
||||
|
@ -50,7 +52,8 @@ class LightningTestMultiNodeWorksApp(LightningTestApp):
|
|||
],
|
||||
)
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="flaky")
|
||||
def test_multi_node_examples(app_name, monkeypatch):
|
||||
@mock.patch("lightning_app.components.multi_node.base.is_running_in_cloud", return_value=True)
|
||||
def test_multi_node_examples(_, app_name, monkeypatch):
|
||||
monkeypatch.chdir(os.path.join(_PATH_EXAMPLES, "app_multi_node"))
|
||||
command_line = [
|
||||
app_name,
|
||||
|
|
Loading…
Reference in New Issue