From 3a8609755ca58683a75121911f92b16abf615030 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Sun, 5 Nov 2023 13:06:02 +0000 Subject: [PATCH] lightning.data: Fix some bugs with optimize (#18949) Co-authored-by: thomas --- src/lightning/data/streaming/data_processor.py | 11 ++++++----- tests/tests_data/streaming/test_data_processor.py | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/lightning/data/streaming/data_processor.py b/src/lightning/data/streaming/data_processor.py index a106de2712..3f44cdf8a8 100644 --- a/src/lightning/data/streaming/data_processor.py +++ b/src/lightning/data/streaming/data_processor.py @@ -87,11 +87,11 @@ def _get_cache_data_dir(name: Optional[str] = None) -> str: return os.path.join(cache_dir, name.lstrip("/")) -def _wait_for_file_to_exist(s3: Any, obj: parse.ParseResult, sleep_time: int = 2) -> Any: +def _wait_for_file_to_exist(s3: S3Client, obj: parse.ParseResult, sleep_time: int = 2) -> Any: """This function check.""" while True: try: - return s3.head_object(Bucket=obj.netloc, Key=obj.path.lstrip("/")) + return s3.client.head_object(Bucket=obj.netloc, Key=obj.path.lstrip("/")) except botocore.exceptions.ClientError as e: if "the HeadObject operation: Not Found" in str(e): sleep(sleep_time) @@ -659,7 +659,7 @@ class DataChunkRecipe(DataRecipe): obj = parse.urlparse(remote_filepath) _wait_for_file_to_exist(s3, obj) with open(node_index_filepath, "wb") as f: - s3.download_fileobj(obj.netloc, obj.path.lstrip("/"), f) + s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f) elif os.path.isdir(output_dir.path): copyfile(remote_filepath, node_index_filepath) @@ -799,15 +799,16 @@ class DataProcessor: break num_nodes = _get_num_nodes() + node_rank = _get_node_rank() # TODO: Understand why it hangs. if num_nodes == 1: for w in self.workers: w.join(0) print("Workers are finished.") - result = data_recipe._done(num_items, self.delete_cached_files, self.output_dir) + result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir) - if num_nodes == _get_node_rank() + 1: + if num_nodes == node_rank + 1: _create_dataset( input_dir=self.input_dir.path, storage_dir=self.output_dir.path, diff --git a/tests/tests_data/streaming/test_data_processor.py b/tests/tests_data/streaming/test_data_processor.py index 93c2ba2eb5..15dac2c268 100644 --- a/tests/tests_data/streaming/test_data_processor.py +++ b/tests/tests_data/streaming/test_data_processor.py @@ -204,7 +204,7 @@ def test_wait_for_file_to_exist(): raise botocore.exceptions.ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject") return - s3.head_object = fn + s3.client.head_object = fn _wait_for_file_to_exist(s3, obj, sleep_time=0.01) @@ -213,7 +213,7 @@ def test_wait_for_file_to_exist(): def fn(*_, **__): raise ValueError("HERE") - s3.head_object = fn + s3.client.head_object = fn with pytest.raises(ValueError, match="HERE"): _wait_for_file_to_exist(s3, obj, sleep_time=0.01)