diff --git a/src/lightning/data/backends.py b/src/lightning/data/backends.py index 9cdc764185..1d4dbcccd4 100644 --- a/src/lightning/data/backends.py +++ b/src/lightning/data/backends.py @@ -47,7 +47,9 @@ class S3DatasetBackend: if os.getenv("AWS_ACCESS_KEY") and os.getenv("AWS_SECRET_ACCESS_KEY"): return {"access_key": os.getenv("AWS_ACCESS_KEY"), "secret_key": os.getenv("AWS_SECRET_ACCESS_KEY")} - return self.get_aws_credentials() + aws_creds = self.get_aws_credentials() + + return {"access_key": aws_creds.access_key, "secret_key": aws_creds.secret_key, "token": aws_creds.token} def handle_error(self, exc: Exception) -> None: from botocore.exceptions import NoCredentialsError diff --git a/tests/tests_data/test_backends.py b/tests/tests_data/test_backends.py new file mode 100644 index 0000000000..d15f666f43 --- /dev/null +++ b/tests/tests_data/test_backends.py @@ -0,0 +1,33 @@ +import os +from collections import namedtuple +from typing import Mapping +from unittest import mock + + +def test_s3_dataset_backend_credentials_env_vars(): + from lightning.data.backends import S3DatasetBackend + + os.environ["AWS_ACCESS_KEY"] = "123" + os.environ["AWS_SECRET_ACCESS_KEY"] = "abc" + + assert S3DatasetBackend().credentials() == {"access_key": "123", "secret_key": "abc"} + os.environ.pop("AWS_ACCESS_KEY") + os.environ.pop("AWS_SECRET_ACCESS_KEY") + + +_Credentials = namedtuple("RefreshableCredentials", ("access_key", "secret_key", "token")) + + +@mock.patch("botocore.credentials.InstanceMetadataProvider.load", return_value=_Credentials("abc", "def", "ghi")) +def test_s3_dataset_backend_credentials_iam(patch1): + from lightning.data.backends import S3DatasetBackend + + credentials = S3DatasetBackend().credentials() + assert isinstance(credentials, Mapping) + assert credentials == {"access_key": "abc", "secret_key": "def", "token": "ghi"} + + +def test_local_dataset_backend_credentials(): + from lightning.data.backends import LocalDatasetBackend + + assert LocalDatasetBackend().credentials() == {}