lightning.data: Fix some bugs with optimize (#18949)

Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
This commit is contained in:
thomas chaton 2023-11-05 13:06:02 +00:00 committed by GitHub
parent 0e7a3b0b5f
commit 3a8609755c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 7 deletions

View File

@ -87,11 +87,11 @@ def _get_cache_data_dir(name: Optional[str] = None) -> str:
return os.path.join(cache_dir, name.lstrip("/"))
def _wait_for_file_to_exist(s3: Any, obj: parse.ParseResult, sleep_time: int = 2) -> Any:
def _wait_for_file_to_exist(s3: S3Client, obj: parse.ParseResult, sleep_time: int = 2) -> Any:
"""This function check."""
while True:
try:
return s3.head_object(Bucket=obj.netloc, Key=obj.path.lstrip("/"))
return s3.client.head_object(Bucket=obj.netloc, Key=obj.path.lstrip("/"))
except botocore.exceptions.ClientError as e:
if "the HeadObject operation: Not Found" in str(e):
sleep(sleep_time)
@ -659,7 +659,7 @@ class DataChunkRecipe(DataRecipe):
obj = parse.urlparse(remote_filepath)
_wait_for_file_to_exist(s3, obj)
with open(node_index_filepath, "wb") as f:
s3.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)
s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)
elif os.path.isdir(output_dir.path):
copyfile(remote_filepath, node_index_filepath)
@ -799,15 +799,16 @@ class DataProcessor:
break
num_nodes = _get_num_nodes()
node_rank = _get_node_rank()
# TODO: Understand why it hangs.
if num_nodes == 1:
for w in self.workers:
w.join(0)
print("Workers are finished.")
result = data_recipe._done(num_items, self.delete_cached_files, self.output_dir)
result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)
if num_nodes == _get_node_rank() + 1:
if num_nodes == node_rank + 1:
_create_dataset(
input_dir=self.input_dir.path,
storage_dir=self.output_dir.path,

View File

@ -204,7 +204,7 @@ def test_wait_for_file_to_exist():
raise botocore.exceptions.ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")
return
s3.head_object = fn
s3.client.head_object = fn
_wait_for_file_to_exist(s3, obj, sleep_time=0.01)
@ -213,7 +213,7 @@ def test_wait_for_file_to_exist():
def fn(*_, **__):
raise ValueError("HERE")
s3.head_object = fn
s3.client.head_object = fn
with pytest.raises(ValueError, match="HERE"):
_wait_for_file_to_exist(s3, obj, sleep_time=0.01)