diff --git a/src/lightning/data/streaming/data_processor.py b/src/lightning/data/streaming/data_processor.py index 5e019b625d..b322b376d8 100644 --- a/src/lightning/data/streaming/data_processor.py +++ b/src/lightning/data/streaming/data_processor.py @@ -191,7 +191,6 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_ ) except Exception as e: print(e) - return if os.path.isdir(output_dir.path): copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath))) else: diff --git a/tests/tests_data/streaming/test_data_processor.py b/tests/tests_data/streaming/test_data_processor.py index efefb3e49e..7593fa79d7 100644 --- a/tests/tests_data/streaming/test_data_processor.py +++ b/tests/tests_data/streaming/test_data_processor.py @@ -63,6 +63,64 @@ def test_upload_fn(tmpdir): assert os.listdir(remote_output_dir) == ["a.txt"] +@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows") +def test_upload_s3_fn(tmpdir, monkeypatch): + input_dir = os.path.join(tmpdir, "input_dir") + os.makedirs(input_dir, exist_ok=True) + + cache_dir = os.path.join(tmpdir, "cache_dir") + os.makedirs(cache_dir, exist_ok=True) + + remote_output_dir = os.path.join(tmpdir, "remote_output_dir") + os.makedirs(remote_output_dir, exist_ok=True) + + filepath = os.path.join(input_dir, "a.txt") + + with open(filepath, "w") as f: + f.write("HERE") + + upload_queue = mock.MagicMock() + + paths = [filepath, None] + + def fn(*_, **__): + value = paths.pop(0) + if value is None: + return value + return value + + upload_queue.get = fn + + remove_queue = mock.MagicMock() + + s3_client = mock.MagicMock() + + called = False + + def copy_file(local_filepath, *args): + nonlocal called + called = True + from shutil import copyfile + + copyfile(local_filepath, os.path.join(remote_output_dir.path, os.path.basename(local_filepath))) + + s3_client.client.upload_file = copy_file + + monkeypatch.setattr(data_processor_module, "S3Client", mock.MagicMock(return_value=s3_client)) + + assert os.listdir(remote_output_dir) == [] + + assert not called + + _upload_fn(upload_queue, remove_queue, cache_dir, Dir(path=remote_output_dir, url="s3://url")) + + assert called + + assert len(paths) == 0 + + assert os.listdir(remote_output_dir) == ["a.txt"] + + @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows") def test_remove_target(tmpdir): input_dir = os.path.join(tmpdir, "input_dir")