Create cache dir if it doesn't exist (#18955)

This commit is contained in:
Adrian Wälchli 2023-11-06 17:02:05 +01:00 committed by GitHub
parent b30f1a995d
commit c4af18b2c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 19 deletions

View File

@ -28,22 +28,6 @@ if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50:
from lightning_cloud.resolver import _resolve_dir
def _try_create_cache_dir(create: bool = False) -> Optional[str]:
    """Return the chunk cache directory if the Lightning ids are set in the environment.

    Both the ``LIGHTNING_CLUSTER_ID`` and ``LIGHTNING_CLOUD_PROJECT_ID`` environment
    variables must be defined; otherwise no cache directory is used.

    Arguments:
        create: Whether to create the directory (including parents) when it is missing.

    Returns:
        ``"/cache/chunks"`` when both environment variables are set, ``None`` otherwise.

    """
    required_env_vars = ("LIGHTNING_CLUSTER_ID", "LIGHTNING_CLOUD_PROJECT_ID")
    # Bail out as soon as one of the ids is missing from the environment.
    if any(os.getenv(env_var) is None for env_var in required_env_vars):
        return None

    chunks_dir = "/cache/chunks"
    if create:
        # `exist_ok=True` makes repeated calls harmless.
        os.makedirs(chunks_dir, exist_ok=True)
    return chunks_dir
class StreamingDataset(IterableDataset):
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class."""
@ -58,9 +42,7 @@ class StreamingDataset(IterableDataset):
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
Arguments:
name: The name of the optimised dataset.
version: The version of the dataset to use.
cache_dir: The cache dir where the data would be stored.
input_dir: Path to the folder where the input data is stored.
item_loader: The logic to load an item from a chunk.
shuffle: Whether to shuffle the data.
drop_last: If `True`, drops the last items to ensure that
@ -173,3 +155,11 @@ class StreamingDataset(IterableDataset):
self.index += 1
return data
def _try_create_cache_dir() -> Optional[str]:
    """Create the chunk cache directory if the Lightning ids are set in the environment.

    Returns:
        ``"/cache/chunks"`` (created if it did not exist yet) when both the
        ``LIGHTNING_CLUSTER_ID`` and ``LIGHTNING_CLOUD_PROJECT_ID`` environment
        variables are defined, ``None`` otherwise.

    """
    has_cluster_id = "LIGHTNING_CLUSTER_ID" in os.environ
    has_project_id = "LIGHTNING_CLOUD_PROJECT_ID" in os.environ
    if not (has_cluster_id and has_project_id):
        return None

    chunks_dir = "/cache/chunks"
    # `exist_ok=True` makes repeated calls harmless.
    os.makedirs(chunks_dir, exist_ok=True)
    return chunks_dir

View File

@ -12,6 +12,7 @@
# limitations under the License.
import os
from unittest import mock
import pytest
import torch
@ -47,6 +48,15 @@ def test_streaming_dataset(tmpdir, monkeypatch):
assert len(dataloader) == 408
@mock.patch.dict(os.environ, {"LIGHTNING_CLUSTER_ID": "123", "LIGHTNING_CLOUD_PROJECT_ID": "456"})
@mock.patch("lightning.data.streaming.dataset.os.makedirs")
def test_create_cache_dir_in_lightning_cloud(makedirs_mock, tmpdir):
    """Check that the cache dir creation is attempted when the Lightning ids are set.

    Locally, we can't actually write to the root filesystem with user privileges, so
    `os.makedirs` is mocked. The directory therefore never exists and constructing the
    dataset is expected to fail with a `FileNotFoundError`.

    """
    expected_cache_dir = "/cache/chunks"
    expected_error = "`/cache/chunks` doesn't exist"

    with pytest.raises(FileNotFoundError, match=expected_error):
        StreamingDataset(tmpdir)

    makedirs_mock.assert_called_once_with(expected_cache_dir, exist_ok=True)
@pytest.mark.parametrize("drop_last", [False, True])
def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir):
seed_everything(42)