diff --git a/src/lightning/data/streaming/client.py b/src/lightning/data/streaming/client.py index 4a57254885..913180ff58 100644 --- a/src/lightning/data/streaming/client.py +++ b/src/lightning/data/streaming/client.py @@ -7,6 +7,8 @@ from lightning.data.constants import _BOTO3_AVAILABLE if _BOTO3_AVAILABLE: import boto3 import botocore + from botocore.credentials import InstanceMetadataProvider + from botocore.utils import InstanceMetadataFetcher class S3Client: @@ -18,20 +20,34 @@ class S3Client: self._has_cloud_space_id: bool = "LIGHTNING_CLOUD_SPACE_ID" in os.environ self._client: Optional[Any] = None + def _create_client(self) -> None: + has_shared_credentials_file = os.getenv("AWS_SHARED_CREDENTIALS_FILE") == os.getenv("AWS_CONFIG_FILE") == "/.credentials/.aws_credentials" # noqa: E501 + + if has_shared_credentials_file or not self._has_cloud_space_id: + self._client = boto3.client( + "s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}) + ) + else: + provider = InstanceMetadataProvider(iam_role_fetcher=InstanceMetadataFetcher(timeout=3600, num_attempts=5)) + credentials = provider.load() + self._client = boto3.client( + "s3", + aws_access_key_id=credentials.access_key, + aws_secret_access_key=credentials.secret_key, + aws_session_token=credentials.token, + config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}), + ) + @property def client(self) -> Any: if not self._has_cloud_space_id: if self._client is None: - self._client = boto3.client( - "s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}) - ) + self._create_client() return self._client # Re-generate credentials for EC2 if self._last_time is None or (time() - self._last_time) > self._refetch_interval: - self._client = boto3.client( - "s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}) - ) + self._create_client() self._last_time = time() return self._client diff --git 
a/src/lightning/data/streaming/downloader.py b/src/lightning/data/streaming/downloader.py index 03d7b93020..c3e0d17044 100644 --- a/src/lightning/data/streaming/downloader.py +++ b/src/lightning/data/streaming/downloader.py @@ -57,7 +57,7 @@ class S3Downloader(Downloader): return try: - with FileLock(local_filepath + ".lock", timeout=1 if obj.path.endswith(_INDEX_FILENAME) else 0): + with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0): if self._s5cmd_available: proc = subprocess.Popen( f"s5cmd cp {remote_filepath} {local_filepath}", diff --git a/tests/tests_data/streaming/test_client.py b/tests/tests_data/streaming/test_client.py index e4d9d80cbd..d1425fd9c2 100644 --- a/tests/tests_data/streaming/test_client.py +++ b/tests/tests_data/streaming/test_client.py @@ -13,6 +13,12 @@ def test_s3_client_without_cloud_space_id(monkeypatch): botocore = mock.MagicMock() monkeypatch.setattr(client, "botocore", botocore) + instance_metadata_provider = mock.MagicMock() + monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider) + + instance_metadata_fetcher = mock.MagicMock() + monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher) + s3 = client.S3Client(1) assert s3.client assert s3.client @@ -24,7 +30,8 @@ def test_s3_client_without_cloud_space_id(monkeypatch): @pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows") -def test_s3_client_with_cloud_space_id(monkeypatch): +@pytest.mark.parametrize("use_shared_credentials", [False, True]) +def test_s3_client_with_cloud_space_id(use_shared_credentials, monkeypatch): boto3 = mock.MagicMock() monkeypatch.setattr(client, "boto3", boto3) @@ -33,6 +40,16 @@ def test_s3_client_with_cloud_space_id(monkeypatch): monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy") + if use_shared_credentials: + monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", "/.credentials/.aws_credentials") + monkeypatch.setenv("AWS_CONFIG_FILE", 
"/.credentials/.aws_credentials") + + instance_metadata_provider = mock.MagicMock() + monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider) + + instance_metadata_fetcher = mock.MagicMock() + monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher) + s3 = client.S3Client(1) assert s3.client assert s3.client @@ -45,3 +62,5 @@ def test_s3_client_with_cloud_space_id(monkeypatch): assert s3.client assert s3.client assert len(boto3.client._mock_mock_calls) == 9 + + assert instance_metadata_provider._mock_call_count == (0 if use_shared_credentials else 3)