Resolve boto3 "Unable to locate credentials" error (#19472)

This commit is contained in:
thomas chaton 2024-02-14 10:54:01 +00:00 committed by GitHub
parent d61f6fecd4
commit 71f44775c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 43 additions and 8 deletions

View File

@ -7,6 +7,8 @@ from lightning.data.constants import _BOTO3_AVAILABLE
if _BOTO3_AVAILABLE: if _BOTO3_AVAILABLE:
import boto3 import boto3
import botocore import botocore
from botocore.credentials import InstanceMetadataProvider
from botocore.utils import InstanceMetadataFetcher
class S3Client: class S3Client:
@ -18,20 +20,34 @@ class S3Client:
self._has_cloud_space_id: bool = "LIGHTNING_CLOUD_SPACE_ID" in os.environ self._has_cloud_space_id: bool = "LIGHTNING_CLOUD_SPACE_ID" in os.environ
self._client: Optional[Any] = None self._client: Optional[Any] = None
def _create_client(self) -> None:
    """Create the boto3 S3 client and store it on ``self._client``.

    Credential resolution:

    * If ``AWS_SHARED_CREDENTIALS_FILE`` and ``AWS_CONFIG_FILE`` both equal
      ``/.credentials/.aws_credentials``, or if ``self._has_cloud_space_id``
      is false, the client is created without explicit keys so boto3 resolves
      credentials through its default chain.
    * Otherwise credentials are requested from the EC2 instance metadata
      service and passed explicitly to the client.
    * If the metadata provider returns no credentials (``load()`` gives
      ``None``), fall back to the default chain instead of failing with an
      ``AttributeError`` on ``None``.

    Every client uses the same adaptive retry configuration.
    """
    retry_config = botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})

    has_shared_credentials_file = (
        os.getenv("AWS_SHARED_CREDENTIALS_FILE") == os.getenv("AWS_CONFIG_FILE") == "/.credentials/.aws_credentials"
    )

    credentials = None
    # Only query the instance metadata service when running with a cloud space id
    # and no shared credentials file; elsewhere the lookup cannot be expected to
    # succeed and the default boto3 credential chain is the right source.
    if self._has_cloud_space_id and not has_shared_credentials_file:
        provider = InstanceMetadataProvider(iam_role_fetcher=InstanceMetadataFetcher(timeout=3600, num_attempts=5))
        credentials = provider.load()

    if credentials is None:
        self._client = boto3.client("s3", config=retry_config)
        return

    self._client = boto3.client(
        "s3",
        aws_access_key_id=credentials.access_key,
        aws_secret_access_key=credentials.secret_key,
        aws_session_token=credentials.token,
        config=retry_config,
    )
@property @property
def client(self) -> Any: def client(self) -> Any:
if not self._has_cloud_space_id: if not self._has_cloud_space_id:
if self._client is None: if self._client is None:
self._client = boto3.client( self._create_client()
"s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})
)
return self._client return self._client
# Re-generate credentials for EC2 # Re-generate credentials for EC2
if self._last_time is None or (time() - self._last_time) > self._refetch_interval: if self._last_time is None or (time() - self._last_time) > self._refetch_interval:
self._client = boto3.client( self._create_client()
"s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})
)
self._last_time = time() self._last_time = time()
return self._client return self._client

View File

@ -57,7 +57,7 @@ class S3Downloader(Downloader):
return return
try: try:
with FileLock(local_filepath + ".lock", timeout=1 if obj.path.endswith(_INDEX_FILENAME) else 0): with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0):
if self._s5cmd_available: if self._s5cmd_available:
proc = subprocess.Popen( proc = subprocess.Popen(
f"s5cmd cp {remote_filepath} {local_filepath}", f"s5cmd cp {remote_filepath} {local_filepath}",

View File

@ -13,6 +13,12 @@ def test_s3_client_without_cloud_space_id(monkeypatch):
botocore = mock.MagicMock() botocore = mock.MagicMock()
monkeypatch.setattr(client, "botocore", botocore) monkeypatch.setattr(client, "botocore", botocore)
instance_metadata_provider = mock.MagicMock()
monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider)
instance_metadata_fetcher = mock.MagicMock()
monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher)
s3 = client.S3Client(1) s3 = client.S3Client(1)
assert s3.client assert s3.client
assert s3.client assert s3.client
@ -24,7 +30,8 @@ def test_s3_client_without_cloud_space_id(monkeypatch):
@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows") @pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows")
def test_s3_client_with_cloud_space_id(monkeypatch): @pytest.mark.parametrize("use_shared_credentials", [False, True])
def test_s3_client_with_cloud_space_id(use_shared_credentials, monkeypatch):
boto3 = mock.MagicMock() boto3 = mock.MagicMock()
monkeypatch.setattr(client, "boto3", boto3) monkeypatch.setattr(client, "boto3", boto3)
@ -33,6 +40,16 @@ def test_s3_client_with_cloud_space_id(monkeypatch):
monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy") monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy")
if use_shared_credentials:
monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", "/.credentials/.aws_credentials")
monkeypatch.setenv("AWS_CONFIG_FILE", "/.credentials/.aws_credentials")
instance_metadata_provider = mock.MagicMock()
monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider)
instance_metadata_fetcher = mock.MagicMock()
monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher)
s3 = client.S3Client(1) s3 = client.S3Client(1)
assert s3.client assert s3.client
assert s3.client assert s3.client
@ -45,3 +62,5 @@ def test_s3_client_with_cloud_space_id(monkeypatch):
assert s3.client assert s3.client
assert s3.client assert s3.client
assert len(boto3.client._mock_mock_calls) == 9 assert len(boto3.client._mock_mock_calls) == 9
assert instance_metadata_provider._mock_call_count == 0 if use_shared_credentials else 3