Resolve bug with the uploader (#18939)

Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
This commit is contained in:
thomas chaton 2023-11-03 19:43:55 +00:00 committed by GitHub
parent f5f4d0a264
commit f9e82c68f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 58 additions and 1 deletions

View File

@ -191,7 +191,6 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
)
except Exception as e:
print(e)
return
if os.path.isdir(output_dir.path):
copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
else:

View File

@ -63,6 +63,64 @@ def test_upload_fn(tmpdir):
assert os.listdir(remote_output_dir) == ["a.txt"]
@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
def test_upload_s3_fn(tmpdir, monkeypatch):
    """Exercise ``_upload_fn`` with an ``s3://`` output URL and a mocked ``S3Client``.

    A single file is fed through a fake upload queue. The mocked client's
    ``upload_file`` copies the file into a local directory standing in for
    the bucket, so the test can check both that the S3 upload hook was
    invoked and that the file ended up in the output directory.
    """
    from shutil import copyfile

    input_dir = os.path.join(tmpdir, "input_dir")
    cache_dir = os.path.join(tmpdir, "cache_dir")
    remote_output_dir = os.path.join(tmpdir, "remote_output_dir")
    for directory in (input_dir, cache_dir, remote_output_dir):
        os.makedirs(directory, exist_ok=True)

    filepath = os.path.join(input_dir, "a.txt")
    with open(filepath, "w") as f:
        f.write("HERE")

    # Fake queue: hands out the file path first, then ``None``.
    # NOTE(review): ``None`` is presumably the signal that makes ``_upload_fn``
    # stop polling -- its loop is not visible here; the final asserts only
    # check that both entries were consumed.
    paths = [filepath, None]

    def fn(*_, **__):
        return paths.pop(0)

    upload_queue = mock.MagicMock()
    upload_queue.get = fn
    remove_queue = mock.MagicMock()

    called = False

    def copy_file(local_filepath, *args):
        """Stand-in for ``upload_file``: record the call and copy the file locally."""
        nonlocal called
        called = True
        # ``remote_output_dir`` is a plain string path. The previous
        # ``remote_output_dir.path`` raised AttributeError, which the broad
        # ``except Exception`` around the upload call silently swallowed.
        copyfile(local_filepath, os.path.join(remote_output_dir, os.path.basename(local_filepath)))

    s3_client = mock.MagicMock()
    s3_client.client.upload_file = copy_file
    monkeypatch.setattr(data_processor_module, "S3Client", mock.MagicMock(return_value=s3_client))

    # Preconditions: nothing uploaded yet.
    assert os.listdir(remote_output_dir) == []
    assert not called

    _upload_fn(upload_queue, remove_queue, cache_dir, Dir(path=remote_output_dir, url="s3://url"))

    assert called
    assert not paths
    assert os.listdir(remote_output_dir) == ["a.txt"]
@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
def test_remove_target(tmpdir):
input_dir = os.path.join(tmpdir, "input_dir")