Resolve boto3 unable to local credentials (#19472)
This commit is contained in:
parent
d61f6fecd4
commit
71f44775c9
|
@ -7,6 +7,8 @@ from lightning.data.constants import _BOTO3_AVAILABLE
|
||||||
if _BOTO3_AVAILABLE:
|
if _BOTO3_AVAILABLE:
|
||||||
import boto3
|
import boto3
|
||||||
import botocore
|
import botocore
|
||||||
|
from botocore.credentials import InstanceMetadataProvider
|
||||||
|
from botocore.utils import InstanceMetadataFetcher
|
||||||
|
|
||||||
|
|
||||||
class S3Client:
|
class S3Client:
|
||||||
|
@ -18,20 +20,34 @@ class S3Client:
|
||||||
self._has_cloud_space_id: bool = "LIGHTNING_CLOUD_SPACE_ID" in os.environ
|
self._has_cloud_space_id: bool = "LIGHTNING_CLOUD_SPACE_ID" in os.environ
|
||||||
self._client: Optional[Any] = None
|
self._client: Optional[Any] = None
|
||||||
|
|
||||||
|
def _create_client(self) -> None:
|
||||||
|
has_shared_credentials_file = os.getenv("AWS_SHARED_CREDENTIALS_FILE") == os.getenv("AWS_CONFIG_FILE") == "/.credentials/.aws_credentials" # noqa: E501
|
||||||
|
|
||||||
|
if has_shared_credentials_file:
|
||||||
|
self._client = boto3.client(
|
||||||
|
"s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
provider = InstanceMetadataProvider(iam_role_fetcher=InstanceMetadataFetcher(timeout=3600, num_attempts=5))
|
||||||
|
credentials = provider.load()
|
||||||
|
self._client = boto3.client(
|
||||||
|
"s3",
|
||||||
|
aws_access_key_id=credentials.access_key,
|
||||||
|
aws_secret_access_key=credentials.secret_key,
|
||||||
|
aws_session_token=credentials.token,
|
||||||
|
config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}),
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def client(self) -> Any:
|
def client(self) -> Any:
|
||||||
if not self._has_cloud_space_id:
|
if not self._has_cloud_space_id:
|
||||||
if self._client is None:
|
if self._client is None:
|
||||||
self._client = boto3.client(
|
self._create_client()
|
||||||
"s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})
|
|
||||||
)
|
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
# Re-generate credentials for EC2
|
# Re-generate credentials for EC2
|
||||||
if self._last_time is None or (time() - self._last_time) > self._refetch_interval:
|
if self._last_time is None or (time() - self._last_time) > self._refetch_interval:
|
||||||
self._client = boto3.client(
|
self._create_client()
|
||||||
"s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})
|
|
||||||
)
|
|
||||||
self._last_time = time()
|
self._last_time = time()
|
||||||
|
|
||||||
return self._client
|
return self._client
|
||||||
|
|
|
@ -57,7 +57,7 @@ class S3Downloader(Downloader):
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with FileLock(local_filepath + ".lock", timeout=1 if obj.path.endswith(_INDEX_FILENAME) else 0):
|
with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0):
|
||||||
if self._s5cmd_available:
|
if self._s5cmd_available:
|
||||||
proc = subprocess.Popen(
|
proc = subprocess.Popen(
|
||||||
f"s5cmd cp {remote_filepath} {local_filepath}",
|
f"s5cmd cp {remote_filepath} {local_filepath}",
|
||||||
|
|
|
@ -13,6 +13,12 @@ def test_s3_client_without_cloud_space_id(monkeypatch):
|
||||||
botocore = mock.MagicMock()
|
botocore = mock.MagicMock()
|
||||||
monkeypatch.setattr(client, "botocore", botocore)
|
monkeypatch.setattr(client, "botocore", botocore)
|
||||||
|
|
||||||
|
instance_metadata_provider = mock.MagicMock()
|
||||||
|
monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider)
|
||||||
|
|
||||||
|
instance_metadata_fetcher = mock.MagicMock()
|
||||||
|
monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher)
|
||||||
|
|
||||||
s3 = client.S3Client(1)
|
s3 = client.S3Client(1)
|
||||||
assert s3.client
|
assert s3.client
|
||||||
assert s3.client
|
assert s3.client
|
||||||
|
@ -24,7 +30,8 @@ def test_s3_client_without_cloud_space_id(monkeypatch):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows")
|
@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows")
|
||||||
def test_s3_client_with_cloud_space_id(monkeypatch):
|
@pytest.mark.parametrize("use_shared_credentials", [False, True])
|
||||||
|
def test_s3_client_with_cloud_space_id(use_shared_credentials, monkeypatch):
|
||||||
boto3 = mock.MagicMock()
|
boto3 = mock.MagicMock()
|
||||||
monkeypatch.setattr(client, "boto3", boto3)
|
monkeypatch.setattr(client, "boto3", boto3)
|
||||||
|
|
||||||
|
@ -33,6 +40,16 @@ def test_s3_client_with_cloud_space_id(monkeypatch):
|
||||||
|
|
||||||
monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy")
|
monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy")
|
||||||
|
|
||||||
|
if use_shared_credentials:
|
||||||
|
monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", "/.credentials/.aws_credentials")
|
||||||
|
monkeypatch.setenv("AWS_CONFIG_FILE", "/.credentials/.aws_credentials")
|
||||||
|
|
||||||
|
instance_metadata_provider = mock.MagicMock()
|
||||||
|
monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider)
|
||||||
|
|
||||||
|
instance_metadata_fetcher = mock.MagicMock()
|
||||||
|
monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher)
|
||||||
|
|
||||||
s3 = client.S3Client(1)
|
s3 = client.S3Client(1)
|
||||||
assert s3.client
|
assert s3.client
|
||||||
assert s3.client
|
assert s3.client
|
||||||
|
@ -45,3 +62,5 @@ def test_s3_client_with_cloud_space_id(monkeypatch):
|
||||||
assert s3.client
|
assert s3.client
|
||||||
assert s3.client
|
assert s3.client
|
||||||
assert len(boto3.client._mock_mock_calls) == 9
|
assert len(boto3.client._mock_mock_calls) == 9
|
||||||
|
|
||||||
|
assert instance_metadata_provider._mock_call_count == 0 if use_shared_credentials else 3
|
||||||
|
|
Loading…
Reference in New Issue