Add the input_dir in the cache_dir to avoid overlapping downloads (#18960)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
thomas chaton 2023-11-07 00:01:37 +00:00 committed by GitHub
parent 195a3bf5b5
commit e79ac21415
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 7 deletions

View File

@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import os
from typing import Any, List, Optional, Union
@ -57,7 +58,7 @@ class StreamingDataset(IterableDataset):
input_dir = _resolve_dir(input_dir)
# Override the provided input_path
cache_dir = _try_create_cache_dir()
cache_dir = _try_create_cache_dir(input_dir.path)
if cache_dir:
input_dir.path = cache_dir
@ -157,9 +158,10 @@ class StreamingDataset(IterableDataset):
return data
def _try_create_cache_dir() -> Optional[str]:
def _try_create_cache_dir(input_dir: str) -> Optional[str]:
if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ:
return None
cache_dir = os.path.join("/cache/chunks")
hash_object = hashlib.md5(input_dir.encode())
cache_dir = os.path.join(f"/cache/chunks/{hash_object.hexdigest()}")
os.makedirs(cache_dir, exist_ok=True)
return cache_dir

View File

@ -1,6 +1,8 @@
import sys
from time import sleep, time
from unittest import mock
import pytest
from lightning.data.streaming import client
@ -21,6 +23,7 @@ def test_s3_client_without_cloud_space_id(monkeypatch):
boto3.client.assert_called_once()
@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows")
def test_s3_client_with_cloud_space_id(monkeypatch):
boto3 = mock.MagicMock()
monkeypatch.setattr(client, "boto3", boto3)

View File

@ -20,7 +20,7 @@ from lightning import seed_everything
from lightning.data.datasets.env import _DistributedEnv
from lightning.data.streaming import Cache
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.dataset import StreamingDataset
from lightning.data.streaming.dataset import StreamingDataset, _try_create_cache_dir
from lightning.data.streaming.item_loader import TokensLoader
from lightning.data.streaming.shuffle import FullShuffle, NoShuffle
from lightning.pytorch.demos.boring_classes import RandomDataset
@ -52,9 +52,9 @@ def test_streaming_dataset(tmpdir, monkeypatch):
@mock.patch("lightning.data.streaming.dataset.os.makedirs")
def test_create_cache_dir_in_lightning_cloud(makedirs_mock, tmpdir):
    """The dataset redirects its cache into the hashed per-input cloud dir."""
    # Locally, we can't actually write to the root filesystem with user
    # privileges, so we need to mock the makedirs call.
    expected = "/cache/chunks/275876e34cf609db118f3d84b799a790"
    with pytest.raises(FileNotFoundError, match=f"`{expected}` doesn't exist"):
        StreamingDataset("dummy")
    makedirs_mock.assert_called_once_with(expected, exist_ok=True)
@pytest.mark.parametrize("drop_last", [False, True])
@ -238,3 +238,13 @@ def test_streaming_dataset_deepcopy(tmpdir, monkeypatch):
batches.append(batch)
assert len(batches) == 10
@mock.patch.dict(os.environ, {"LIGHTNING_CLUSTER_ID": "123", "LIGHTNING_CLOUD_PROJECT_ID": "456"})
@mock.patch("lightning.data.streaming.dataset.os.makedirs")
def test_try_create_cache_dir(makedirs, monkeypatch):
    """Distinct inputs hash to distinct cache dirs, each created on demand."""
    cache_dir_1 = _try_create_cache_dir("")
    cache_dir_2 = _try_create_cache_dir("ssdf")
    # Different inputs must map to different directories so concurrent
    # downloads of different datasets never clobber each other's chunks.
    assert cache_dir_1 != cache_dir_2
    # md5("").hexdigest() == "d41d8cd98f00b204e9800998ecf8427e"
    assert cache_dir_1 == "/cache/chunks/d41d8cd98f00b204e9800998ecf8427e"
    # Use the public Mock API instead of the private `_mock_mock_calls` attribute.
    assert makedirs.call_count == 2