Expose public and private IP in LightningWork (#17742)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Adrian Wälchli 2023-06-06 13:35:08 +02:00 committed by GitHub
parent a901571fdf
commit d23c772f3c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 53 additions and 18 deletions

View File

@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Allow customizing `gradio` components with Lightning colors ([#17054](https://github.com/Lightning-AI/lightning/pull/17054))
- Added the property `LightningWork.public_ip` that exposes the public IP of the `LightningWork` instance ([#17742](https://github.com/Lightning-AI/lightning/pull/17742))
### Changed
@ -29,7 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed
-
- Fixed `LightningWork.internal_ip`, which was mistakenly exposing the public IP; it now exposes the private/internal IP address ([#17742](https://github.com/Lightning-AI/lightning/pull/17742))
## [2.0.1.post0] - 2023-04-11

View File

@ -231,9 +231,10 @@ class Database(LightningWork):
use_localhost = "LIGHTNING_APP_STATE_URL" not in os.environ
if use_localhost:
return self.url
if self.internal_ip != "":
return f"http://{self.internal_ip}:{self.port}"
return self.internal_ip
ip_addr = self.public_ip or self.internal_ip
if ip_addr != "":
return f"http://{ip_addr}:{self.port}"
return ip_addr
def on_exit(self):
self._exit_event.set()

View File

@ -180,9 +180,9 @@ class _LoadBalancer(LightningWork):
raise ValueError("cold_start_proxy must be of type ColdStartProxy or str")
def get_internal_url(self) -> str:
if not self._internal_ip:
raise ValueError("Internal IP not set")
return f"http://{self._internal_ip}:{self._port}"
if not self._public_ip:
raise ValueError("Public IP not set")
return f"http://{self._public_ip}:{self._port}"
async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]], server_url: str):
request_data: List[_LoadBalancer._input_type] = [b[1] for b in batch]
@ -386,7 +386,7 @@ class _LoadBalancer(LightningWork):
"""
old_server_urls = set(self.servers)
current_server_urls = {
f"http://{server._internal_ip}:{server.port}" for server in server_works if server._internal_ip
f"http://{server._public_ip}:{server.port}" for server in server_works if server._internal_ip
}
# doing nothing if no server work has been added/removed

View File

@ -60,6 +60,7 @@ class LightningWork:
"_url",
"_restarting",
"_internal_ip",
"_public_ip",
)
_run_executor_cls: Type[WorkRunExecutor] = WorkRunExecutor
@ -138,6 +139,7 @@ class LightningWork:
"_url",
"_future_url",
"_internal_ip",
"_public_ip",
"_restarting",
"_cloud_compute",
"_display_name",
@ -148,6 +150,7 @@ class LightningWork:
self._url: str = ""
self._future_url: str = "" # The cache URL is meant to defer resolving the url values.
self._internal_ip: str = ""
self._public_ip: str = ""
# setattr_replacement is used by the multiprocessing runtime to send the latest changes to the main coordinator
self._setattr_replacement: Optional[Callable[[str, Any], None]] = None
self._name: str = ""
@ -212,6 +215,15 @@ class LightningWork:
"""
return self._internal_ip
@property
def public_ip(self) -> str:
"""The public ip address of this LightningWork, reachable from the internet.
By default, this attribute returns the empty string and the ip address will only be returned once the work runs.
Locally, this address is undefined (empty string) and in the cloud it will be determined by the cluster.
"""
return self._public_ip
def _on_init_end(self) -> None:
self._local_build_config.on_work_init(self)
self._cloud_build_config.on_work_init(self, self._cloud_compute)

View File

@ -494,7 +494,8 @@ class WorkRunner:
# Set this here after the state observer is initialized, since it needs to record it as a change and send
# it back to the flow
default_internal_ip = "127.0.0.1" if constants.LIGHTNING_CLOUDSPACE_HOST is None else "0.0.0.0" # noqa: S104
self.work._internal_ip = os.environ.get("LIGHTNING_NODE_IP", default_internal_ip)
self.work._internal_ip = os.environ.get("LIGHTNING_NODE_PRIVATE_IP", default_internal_ip)
self.work._public_ip = os.environ.get("LIGHTNING_NODE_IP", "")
# 8. Patch the setattr method of the work. This needs to be done after step 4, so we don't
# send delta while calling `set_state`.

View File

@ -119,6 +119,7 @@ def test_simple_app(tmpdir):
"_url": "",
"_future_url": "",
"_internal_ip": "",
"_public_ip": "",
"_paths": {},
"_port": None,
"_restarting": False,
@ -136,6 +137,7 @@ def test_simple_app(tmpdir):
"_url": "",
"_future_url": "",
"_internal_ip": "",
"_public_ip": "",
"_paths": {},
"_port": None,
"_restarting": False,
@ -982,7 +984,7 @@ class SizeFlow(LightningFlow):
def test_state_size_constant_growth():
app = LightningApp(SizeFlow())
MultiProcessRuntime(app, start_server=False).dispatch()
assert app.root._state_sizes[0] <= 7965
assert app.root._state_sizes[0] <= 8304
assert app.root._state_sizes[20] <= 26550

View File

@ -324,6 +324,7 @@ def test_lightning_flow_and_work():
"_paths": {},
"_restarting": False,
"_internal_ip": "",
"_public_ip": "",
"_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",
@ -349,6 +350,7 @@ def test_lightning_flow_and_work():
"_paths": {},
"_restarting": False,
"_internal_ip": "",
"_public_ip": "",
"_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",
@ -388,6 +390,7 @@ def test_lightning_flow_and_work():
"_paths": {},
"_restarting": False,
"_internal_ip": "",
"_public_ip": "",
"_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",
@ -413,6 +416,7 @@ def test_lightning_flow_and_work():
"_paths": {},
"_restarting": False,
"_internal_ip": "",
"_public_ip": "",
"_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",

View File

@ -46,6 +46,7 @@ def test_dict():
"_restarting": False,
"_display_name": "",
"_internal_ip": "",
"_public_ip": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "cpu-small",
@ -80,6 +81,7 @@ def test_dict():
"_restarting": False,
"_display_name": "",
"_internal_ip": "",
"_public_ip": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "cpu-small",
@ -114,6 +116,7 @@ def test_dict():
"_restarting": False,
"_display_name": "",
"_internal_ip": "",
"_public_ip": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "cpu-small",
@ -199,6 +202,7 @@ def test_list():
"_paths": {},
"_restarting": False,
"_internal_ip": "",
"_public_ip": "",
"_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",
@ -233,6 +237,7 @@ def test_list():
"_paths": {},
"_restarting": False,
"_internal_ip": "",
"_public_ip": "",
"_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",
@ -262,6 +267,7 @@ def test_list():
"_paths": {},
"_restarting": False,
"_internal_ip": "",
"_public_ip": "",
"_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",

View File

@ -641,16 +641,21 @@ def test_state_observer():
@pytest.mark.parametrize(
("patch_constants", "environment", "expected_ip_addr"),
("patch_constants", "environment", "expected_public_ip", "expected_private_ip"),
[
({}, {}, "127.0.0.1"),
({"LIGHTNING_CLOUDSPACE_HOST": "any"}, {}, "0.0.0.0"), # noqa: S104
({}, {"LIGHTNING_NODE_IP": "10.10.10.5"}, "10.10.10.5"),
({}, {}, "", "127.0.0.1"),
({"LIGHTNING_CLOUDSPACE_HOST": "any"}, {}, "", "0.0.0.0"), # noqa: S104
(
{},
{"LIGHTNING_NODE_IP": "85.44.2.25", "LIGHTNING_NODE_PRIVATE_IP": "10.10.10.5"},
"85.44.2.25",
"10.10.10.5",
),
],
indirect=["patch_constants"],
)
def test_work_runner_sets_internal_ip(patch_constants, environment, expected_ip_addr):
"""Test that the WorkRunner updates the internal ip address as soon as the Work starts running."""
def test_work_runner_sets_public_and_private_ip(patch_constants, environment, expected_public_ip, expected_private_ip):
"""Test that the WorkRunner updates the public and private address as soon as the Work starts running."""
class Work(LightningWork):
def run(self):
@ -690,11 +695,13 @@ def test_work_runner_sets_internal_ip(patch_constants, environment, expected_ip_
with mock.patch.dict(os.environ, environment, clear=True):
work_runner.setup()
# The internal ip address only becomes available once the hardware is up / the work is running.
# The public ip address only becomes available once the hardware is up / the work is running.
assert work.public_ip == ""
assert work.internal_ip == ""
with contextlib.suppress(Empty):
work_runner.run_once()
assert work.internal_ip == expected_ip_addr
assert work.public_ip == expected_public_ip
assert work.internal_ip == expected_private_ip
class WorkBi(LightningWork):