diff --git a/src/lightning/data/streaming/dataset.py b/src/lightning/data/streaming/dataset.py index 47f960b1a9..8bc75f1249 100644 --- a/src/lightning/data/streaming/dataset.py +++ b/src/lightning/data/streaming/dataset.py @@ -28,22 +28,6 @@ if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50: from lightning_cloud.resolver import _resolve_dir -def _try_create_cache_dir(create: bool = False) -> Optional[str]: - # Get the ids from env variables - cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None) - project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None) - - if cluster_id is None or project_id is None: - return None - - cache_dir = os.path.join("/cache/chunks") - - if create: - os.makedirs(cache_dir, exist_ok=True) - - return cache_dir - - class StreamingDataset(IterableDataset): """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.""" @@ -58,9 +42,7 @@ class StreamingDataset(IterableDataset): """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class. Arguments: - name: The name of the optimised dataset. - version: The version of the dataset to use. - cache_dir: The cache dir where the data would be stored. + input_dir: Path to the folder where the input data is stored. item_loader: The logic to load an item from a chunk. shuffle: Whether to shuffle the data. drop_last: If `True`, drops the last items to ensure that @@ -173,3 +155,11 @@ class StreamingDataset(IterableDataset): self.index += 1 return data + + +def _try_create_cache_dir() -> Optional[str]: + if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ: + return None + cache_dir = os.path.join("/cache/chunks") + os.makedirs(cache_dir, exist_ok=True) + return cache_dir diff --git a/tests/tests_data/streaming/test_dataset.py b/tests/tests_data/streaming/test_dataset.py index 3ce2df7064..4dd2af823c 100644 --- a/tests/tests_data/streaming/test_dataset.py +++ b/tests/tests_data/streaming/test_dataset.py @@ -12,6 +12,7 @@ # limitations under the License. import os +from unittest import mock import pytest import torch @@ -47,6 +48,15 @@ def test_streaming_dataset(tmpdir, monkeypatch): assert len(dataloader) == 408 +@mock.patch.dict(os.environ, {"LIGHTNING_CLUSTER_ID": "123", "LIGHTNING_CLOUD_PROJECT_ID": "456"}) +@mock.patch("lightning.data.streaming.dataset.os.makedirs") +def test_create_cache_dir_in_lightning_cloud(makedirs_mock, tmpdir): + # Locally, we can't actually write to the root filesystem with user privileges, so we need to mock the call + with pytest.raises(FileNotFoundError, match="`/cache/chunks` doesn't exist"): + StreamingDataset(tmpdir) + makedirs_mock.assert_called_once_with("/cache/chunks", exist_ok=True) + + @pytest.mark.parametrize("drop_last", [False, True]) def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir): seed_everything(42)