Resolve boto3 unable to locate credentials (#19472)
This commit is contained in:
parent
d61f6fecd4
commit
71f44775c9
|
@ -7,6 +7,8 @@ from lightning.data.constants import _BOTO3_AVAILABLE
|
|||
if _BOTO3_AVAILABLE:
|
||||
import boto3
|
||||
import botocore
|
||||
from botocore.credentials import InstanceMetadataProvider
|
||||
from botocore.utils import InstanceMetadataFetcher
|
||||
|
||||
|
||||
class S3Client:
|
||||
|
@ -18,20 +20,34 @@ class S3Client:
|
|||
self._has_cloud_space_id: bool = "LIGHTNING_CLOUD_SPACE_ID" in os.environ
|
||||
self._client: Optional[Any] = None
|
||||
|
||||
def _create_client(self) -> None:
    """Build the boto3 S3 client and store it on ``self._client``.

    When both ``AWS_SHARED_CREDENTIALS_FILE`` and ``AWS_CONFIG_FILE`` are set to
    ``/.credentials/.aws_credentials``, boto3's default credential resolution is
    used. Otherwise credentials are requested through ``InstanceMetadataProvider``
    and passed to the client explicitly. If that lookup yields no credentials,
    the default resolution is used as a fallback instead of failing with an
    ``AttributeError`` on ``None``.
    """
    # The same retry policy applies whichever way the credentials are obtained.
    retry_config = botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})

    has_shared_credentials_file = (
        os.getenv("AWS_SHARED_CREDENTIALS_FILE")
        == os.getenv("AWS_CONFIG_FILE")
        == "/.credentials/.aws_credentials"
    )

    credentials = None
    if not has_shared_credentials_file:
        # NOTE(review): with timeout=3600 this lookup may block for a long time on
        # machines without an instance metadata endpoint -- confirm this is intended.
        provider = InstanceMetadataProvider(iam_role_fetcher=InstanceMetadataFetcher(timeout=3600, num_attempts=5))
        credentials = provider.load()

    if credentials is None:
        # Either the shared credentials file is configured, or the metadata
        # lookup found nothing: let boto3 resolve credentials on its own.
        self._client = boto3.client("s3", config=retry_config)
        return

    self._client = boto3.client(
        "s3",
        aws_access_key_id=credentials.access_key,
        aws_secret_access_key=credentials.secret_key,
        aws_session_token=credentials.token,
        config=retry_config,
    )
|
||||
|
||||
@property
def client(self) -> Any:
    """Return the S3 client, creating or refreshing it when needed.

    Without a cloud space id the client is created once, on first access, and
    then reused. With a cloud space id the client is rebuilt whenever no build
    time has been recorded yet or more than ``self._refetch_interval`` seconds
    have passed since the last build, so that fresh credentials are picked up.

    All construction goes through ``self._create_client()``; the client is no
    longer built a first time with ``boto3.client`` only to be overwritten.
    """
    if not self._has_cloud_space_id:
        if self._client is None:
            self._create_client()
        return self._client

    # Re-generate credentials for EC2
    if self._last_time is None or (time() - self._last_time) > self._refetch_interval:
        self._create_client()
        self._last_time = time()

    return self._client
|
||||
|
|
|
@ -57,7 +57,7 @@ class S3Downloader(Downloader):
|
|||
return
|
||||
|
||||
try:
|
||||
with FileLock(local_filepath + ".lock", timeout=1 if obj.path.endswith(_INDEX_FILENAME) else 0):
|
||||
with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0):
|
||||
if self._s5cmd_available:
|
||||
proc = subprocess.Popen(
|
||||
f"s5cmd cp {remote_filepath} {local_filepath}",
|
||||
|
|
|
@ -13,6 +13,12 @@ def test_s3_client_without_cloud_space_id(monkeypatch):
|
|||
botocore = mock.MagicMock()
|
||||
monkeypatch.setattr(client, "botocore", botocore)
|
||||
|
||||
instance_metadata_provider = mock.MagicMock()
|
||||
monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider)
|
||||
|
||||
instance_metadata_fetcher = mock.MagicMock()
|
||||
monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher)
|
||||
|
||||
s3 = client.S3Client(1)
|
||||
assert s3.client
|
||||
assert s3.client
|
||||
|
@ -24,7 +30,8 @@ def test_s3_client_without_cloud_space_id(monkeypatch):
|
|||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows")
|
||||
def test_s3_client_with_cloud_space_id(monkeypatch):
|
||||
@pytest.mark.parametrize("use_shared_credentials", [False, True])
|
||||
def test_s3_client_with_cloud_space_id(use_shared_credentials, monkeypatch):
|
||||
boto3 = mock.MagicMock()
|
||||
monkeypatch.setattr(client, "boto3", boto3)
|
||||
|
||||
|
@ -33,6 +40,16 @@ def test_s3_client_with_cloud_space_id(monkeypatch):
|
|||
|
||||
monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy")
|
||||
|
||||
if use_shared_credentials:
|
||||
monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", "/.credentials/.aws_credentials")
|
||||
monkeypatch.setenv("AWS_CONFIG_FILE", "/.credentials/.aws_credentials")
|
||||
|
||||
instance_metadata_provider = mock.MagicMock()
|
||||
monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider)
|
||||
|
||||
instance_metadata_fetcher = mock.MagicMock()
|
||||
monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher)
|
||||
|
||||
s3 = client.S3Client(1)
|
||||
assert s3.client
|
||||
assert s3.client
|
||||
|
@ -45,3 +62,5 @@ def test_s3_client_with_cloud_space_id(monkeypatch):
|
|||
assert s3.client
|
||||
assert s3.client
|
||||
assert len(boto3.client._mock_mock_calls) == 9
|
||||
|
||||
# Parenthesise the conditional: without it the statement parses as
# `assert ((count == 0) if use_shared_credentials else 3)`, which always passes
# in the non-shared case because the constant 3 is truthy.
assert instance_metadata_provider._mock_call_count == (0 if use_shared_credentials else 3)
|
||||
|
|
Loading…
Reference in New Issue