Ignore `num_nodes` when running MultiNode components locally (#15806)

This commit is contained in:
Adrian Wälchli 2022-11-24 18:21:32 +01:00 committed by GitHub
parent e150d083c2
commit a970f090a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 68 additions and 4 deletions

View File

@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `lightning add ssh-key` CLI command has been transitioned to `lightning create ssh-key` with the same calling signature ([#15761](https://github.com/Lightning-AI/lightning/pull/15761))
- `lightning remove ssh-key` CLI command has been transitioned to `lightning delete ssh-key` with the same calling signature ([#15761](https://github.com/Lightning-AI/lightning/pull/15761))
- The `MultiNode` components now warn the user when running with `num_nodes > 1` locally ([#15806](https://github.com/Lightning-AI/lightning/pull/15806))
### Deprecated

View File

@ -1,8 +1,10 @@
import warnings
from typing import Any, Type
from lightning_app import structures
from lightning_app.core.flow import LightningFlow
from lightning_app.core.work import LightningWork
from lightning_app.utilities.cloud import is_running_in_cloud
from lightning_app.utilities.packaging.cloud_compute import CloudCompute
@ -45,12 +47,21 @@ class MultiNode(LightningFlow):
Arguments:
work_cls: The work to be executed
num_nodes: Number of nodes.
cloud_compute: The cloud compute object used in the cloud.
num_nodes: Number of nodes. Gets ignored when running locally. Launch the app with --cloud to run on
multiple cloud machines.
cloud_compute: The cloud compute object used in the cloud. The value provided here gets ignored when
running locally.
work_args: Arguments to be provided to the work on instantiation.
work_kwargs: Keywords arguments to be provided to the work on instantiation.
"""
super().__init__()
if num_nodes > 1 and not is_running_in_cloud():
num_nodes = 1
warnings.warn(
f"You set {type(self).__name__}(num_nodes={num_nodes}, ...)` but this app is running locally."
" We assume you are debugging and will ignore the `num_nodes` argument."
" To run on multiple nodes in the cloud, launch your app with `--cloud`."
)
self.ws = structures.List(
*[
work_cls(

View File

@ -0,0 +1,19 @@
from re import escape
import pytest
from tests_app.helpers.utils import no_warning_call
from lightning_app import CloudCompute, LightningWork
from lightning_app.components import MultiNode
def test_multi_node_warn_running_locally():
class Work(LightningWork):
def run(self):
pass
with pytest.warns(UserWarning, match=escape("You set MultiNode(num_nodes=1, ...)` but ")):
MultiNode(Work, num_nodes=2, cloud_compute=CloudCompute("gpu"))
with no_warning_call(UserWarning, match=escape("You set MultiNode(num_nodes=1, ...)` but ")):
MultiNode(Work, num_nodes=1, cloud_compute=CloudCompute("gpu"))

View File

View File

@ -0,0 +1,30 @@
import re
from contextlib import contextmanager
from typing import Optional, Type
import pytest
@contextmanager
def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Optional[str] = None):
# TODO: Replace with `lightning_utilities.test.warning.no_warning_call`
# https://github.com/Lightning-AI/utilities/issues/57
with pytest.warns(None) as record:
yield
if match is None:
try:
w = record.pop(expected_warning)
except AssertionError:
# no warning raised
return
else:
for w in record.list:
if w.category is expected_warning and re.compile(match).search(w.message.args[0]):
break
else:
return
msg = "A warning" if expected_warning is None else f"`{expected_warning.__name__}`"
raise AssertionError(f"{msg} was raised: {w}")

View File

@ -1,5 +1,6 @@
import os
import sys
from unittest import mock
import pytest
from tests_examples_app.public import _PATH_EXAMPLES
@ -17,7 +18,8 @@ class LightningTestMultiNodeApp(LightningTestApp):
@pytest.mark.skip(reason="flaky")
def test_multi_node_example(monkeypatch):
@mock.patch("lightning_app.components.multi_node.base.is_running_in_cloud", return_value=True)
def test_multi_node_example(_, monkeypatch):
monkeypatch.chdir(os.path.join(_PATH_EXAMPLES, "app_multi_node"))
command_line = [
"app.py",
@ -50,7 +52,8 @@ class LightningTestMultiNodeWorksApp(LightningTestApp):
],
)
@pytest.mark.skipif(sys.platform == "win32", reason="flaky")
def test_multi_node_examples(app_name, monkeypatch):
@mock.patch("lightning_app.components.multi_node.base.is_running_in_cloud", return_value=True)
def test_multi_node_examples(_, app_name, monkeypatch):
monkeypatch.chdir(os.path.join(_PATH_EXAMPLES, "app_multi_node"))
command_line = [
app_name,