Create cache dir if it doesn't exist (#18955)
This commit is contained in:
parent
b30f1a995d
commit
c4af18b2c5
|
@ -28,22 +28,6 @@ if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50:
|
|||
from lightning_cloud.resolver import _resolve_dir
|
||||
|
||||
|
||||
def _try_create_cache_dir(create: bool = False) -> Optional[str]:
|
||||
# Get the ids from env variables
|
||||
cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None)
|
||||
project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)
|
||||
|
||||
if cluster_id is None or project_id is None:
|
||||
return None
|
||||
|
||||
cache_dir = os.path.join("/cache/chunks")
|
||||
|
||||
if create:
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
return cache_dir
|
||||
|
||||
|
||||
class StreamingDataset(IterableDataset):
|
||||
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class."""
|
||||
|
||||
|
@ -58,9 +42,7 @@ class StreamingDataset(IterableDataset):
|
|||
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
|
||||
|
||||
Arguments:
|
||||
name: The name of the optimised dataset.
|
||||
version: The version of the dataset to use.
|
||||
cache_dir: The cache dir where the data would be stored.
|
||||
input_dir: Path to the folder where the input data is stored.
|
||||
item_loader: The logic to load an item from a chunk.
|
||||
shuffle: Whether to shuffle the data.
|
||||
drop_last: If `True`, drops the last items to ensure that
|
||||
|
@ -173,3 +155,11 @@ class StreamingDataset(IterableDataset):
|
|||
self.index += 1
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _try_create_cache_dir() -> Optional[str]:
|
||||
if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ:
|
||||
return None
|
||||
cache_dir = os.path.join("/cache/chunks")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
return cache_dir
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
@ -47,6 +48,15 @@ def test_streaming_dataset(tmpdir, monkeypatch):
|
|||
assert len(dataloader) == 408
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {"LIGHTNING_CLUSTER_ID": "123", "LIGHTNING_CLOUD_PROJECT_ID": "456"})
|
||||
@mock.patch("lightning.data.streaming.dataset.os.makedirs")
|
||||
def test_create_cache_dir_in_lightning_cloud(makedirs_mock, tmpdir):
|
||||
# Locally, we can't actually write to the root filesystem with user privileges, so we need to mock the call
|
||||
with pytest.raises(FileNotFoundError, match="`/cache/chunks` doesn't exist"):
|
||||
StreamingDataset(tmpdir)
|
||||
makedirs_mock.assert_called_once_with("/cache/chunks", exist_ok=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("drop_last", [False, True])
|
||||
def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir):
|
||||
seed_everything(42)
|
||||
|
|
Loading…
Reference in New Issue