From e79ac21415171d99e29036d5ed9b1a972447a16c Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 7 Nov 2023 00:01:37 +0000 Subject: [PATCH] Add the input_dir in the cache_dir to avoid overlapping downloads (#18960) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- src/lightning/data/streaming/dataset.py | 8 +++++--- tests/tests_data/streaming/test_client.py | 3 +++ tests/tests_data/streaming/test_dataset.py | 18 ++++++++++++++---- 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src/lightning/data/streaming/dataset.py b/src/lightning/data/streaming/dataset.py index 8bc75f1249..c014ef346c 100644 --- a/src/lightning/data/streaming/dataset.py +++ b/src/lightning/data/streaming/dataset.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import hashlib import os from typing import Any, List, Optional, Union @@ -57,7 +58,7 @@ class StreamingDataset(IterableDataset): input_dir = _resolve_dir(input_dir) # Override the provided input_path - cache_dir = _try_create_cache_dir() + cache_dir = _try_create_cache_dir(input_dir.path) if cache_dir: input_dir.path = cache_dir @@ -157,9 +158,10 @@ class StreamingDataset(IterableDataset): return data -def _try_create_cache_dir() -> Optional[str]: +def _try_create_cache_dir(input_dir: str) -> Optional[str]: if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ: return None - cache_dir = os.path.join("/cache/chunks") + hash_object = hashlib.md5(input_dir.encode()) + cache_dir = os.path.join(f"/cache/chunks/{hash_object.hexdigest()}") os.makedirs(cache_dir, exist_ok=True) return cache_dir diff --git a/tests/tests_data/streaming/test_client.py b/tests/tests_data/streaming/test_client.py index c4697c70a0..0b18a2ae98 100644 --- a/tests/tests_data/streaming/test_client.py +++ b/tests/tests_data/streaming/test_client.py @@ -1,6 +1,8 @@ +import sys from time import sleep, time from unittest import mock +import pytest from lightning.data.streaming import client @@ -21,6 +23,7 @@ def test_s3_client_without_cloud_space_id(monkeypatch): boto3.client.assert_called_once() +@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows") def test_s3_client_with_cloud_space_id(monkeypatch): boto3 = mock.MagicMock() monkeypatch.setattr(client, "boto3", boto3) diff --git a/tests/tests_data/streaming/test_dataset.py b/tests/tests_data/streaming/test_dataset.py index 4dd2af823c..15ed713351 100644 --- a/tests/tests_data/streaming/test_dataset.py +++ b/tests/tests_data/streaming/test_dataset.py @@ -20,7 +20,7 @@ from lightning import seed_everything from lightning.data.datasets.env import _DistributedEnv from lightning.data.streaming import Cache from lightning.data.streaming.dataloader import StreamingDataLoader -from lightning.data.streaming.dataset import StreamingDataset +from lightning.data.streaming.dataset import StreamingDataset, _try_create_cache_dir from lightning.data.streaming.item_loader import TokensLoader from lightning.data.streaming.shuffle import FullShuffle, NoShuffle from lightning.pytorch.demos.boring_classes import RandomDataset @@ -52,9 +52,9 @@ def test_streaming_dataset(tmpdir, monkeypatch): @mock.patch("lightning.data.streaming.dataset.os.makedirs") def test_create_cache_dir_in_lightning_cloud(makedirs_mock, tmpdir): # Locally, we can't actually write to the root filesystem with user privileges, so we need to mock the call - with pytest.raises(FileNotFoundError, match="`/cache/chunks` doesn't exist"): - StreamingDataset(tmpdir) - makedirs_mock.assert_called_once_with("/cache/chunks", exist_ok=True) + with pytest.raises(FileNotFoundError, match="`/cache/chunks/275876e34cf609db118f3d84b799a790` doesn't exist"): + StreamingDataset("dummy") + makedirs_mock.assert_called_once_with("/cache/chunks/275876e34cf609db118f3d84b799a790", exist_ok=True) @pytest.mark.parametrize("drop_last", [False, True]) @@ -238,3 +238,13 @@ def test_streaming_dataset_deepcopy(tmpdir, monkeypatch): batches.append(batch) assert len(batches) == 10 + + +@mock.patch.dict(os.environ, {"LIGHTNING_CLUSTER_ID": "123", "LIGHTNING_CLOUD_PROJECT_ID": "456"}) +@mock.patch("lightning.data.streaming.dataset.os.makedirs") +def test_try_create_cache_dir(makedirs, monkeypatch): + cache_dir_1 = _try_create_cache_dir("") + cache_dir_2 = _try_create_cache_dir("ssdf") + assert cache_dir_1 != cache_dir_2 + assert cache_dir_1 == "/cache/chunks/d41d8cd98f00b204e9800998ecf8427e" + assert len(makedirs._mock_mock_calls) == 2