From 018a3082698fd36543ba90d13c9f74baa6549eab Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 31 Oct 2023 02:16:02 +0100
Subject: [PATCH] Enable RUF018 rule for walrus assignments in asserts (#18886)

---
 .pre-commit-config.yaml                              | 12 ++++++------
 pyproject.toml                                       |  2 ++
 src/lightning/app/cli/connect/app.py                 |  2 +-
 src/lightning/app/source_code/tar.py                 |  2 +-
 src/lightning/pytorch/callbacks/lr_monitor.py        |  4 ++--
 src/lightning/pytorch/cli.py                         |  4 ++--
 tests/tests_app/core/test_lightning_api.py           |  2 +-
 tests/tests_app/storage/test_copier.py               |  2 +-
 tests/tests_app/utilities/test_cli_helpers.py        |  6 +++---
 tests/tests_app/utilities/test_exceptions.py         |  4 ++--
 .../plugins/collectives/test_torch_collective.py     |  2 ++
 tests/tests_fabric/utilities/test_load.py            |  2 +-
 tests/tests_fabric/utilities/test_warnings.py        | 12 ++++++------
 tests/tests_pytorch/profilers/test_profiler.py       |  4 ++--
 .../tests_pytorch/trainer/properties/test_loggers.py |  2 +-
 15 files changed, 33 insertions(+), 29 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3e259ecc8d..abd4238348 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -83,6 +83,12 @@ repos:
           - flake8-simplify
           - flake8-return

+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: "v0.1.3"
+    hooks:
+      - id: ruff
+        args: ["--fix", "--preview"]
+
   - repo: https://github.com/psf/black
     rev: 23.9.1
     hooks:
@@ -120,9 +126,3 @@ repos:
       - id: prettier
         # https://prettier.io/docs/en/options.html#print-width
         args: ["--print-width=120"]
-
-  - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: "v0.0.292"
-    hooks:
-      - id: ruff
-        args: ["--fix"]
diff --git a/pyproject.toml b/pyproject.toml
index 7fcb08a439..b1a0dbd9f8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -53,6 +53,7 @@ select = [
     "E", "W",  # see: https://pypi.org/project/pycodestyle
     "F",  # see: https://pypi.org/project/pyflakes
     "S",  # see: https://pypi.org/project/flake8-bandit
+    "RUF018",  # see: https://docs.astral.sh/ruff/rules/assignment-in-assert
 ]
 extend-select = [
     "I",  # see: isort
@@ -64,6 +65,7 @@ extend-select = [
 ignore = [
     "E731",  # Do not assign a lambda expression, use a def
     "S108",
+    "E203",  # conflicts with black
 ]
 # Exclude a variety of commonly ignored directories.
 exclude = [
diff --git a/src/lightning/app/cli/connect/app.py b/src/lightning/app/cli/connect/app.py
index ebad9b1297..c2e2adc385 100644
--- a/src/lightning/app/cli/connect/app.py
+++ b/src/lightning/app/cli/connect/app.py
@@ -112,7 +112,7 @@ def connect_app(app_name_or_id: str):

     for command_name, metadata in retriever.api_commands.items():
         if "cls_path" in metadata:
-            target_file = os.path.join(commands_folder, f"{command_name.replace(' ','_')}.py")
+            target_file = os.path.join(commands_folder, f"{command_name.replace(' ', '_')}.py")
             _download_command(
                 command_name,
                 metadata["cls_path"],
diff --git a/src/lightning/app/source_code/tar.py b/src/lightning/app/source_code/tar.py
index c3aca1ae31..5b92201df2 100644
--- a/src/lightning/app/source_code/tar.py
+++ b/src/lightning/app/source_code/tar.py
@@ -97,7 +97,7 @@ def _get_split_size(
     max_size = max_split_count * (1 << 31)  # max size per part limited by Requests or urllib as shown in ref above
     if total_size > max_size:
         raise click.ClickException(
-            f"The size of the datastore to be uploaded is bigger than our {max_size/(1 << 40):.2f} TBytes limit"
+            f"The size of the datastore to be uploaded is bigger than our {max_size / (1 << 40):.2f} TBytes limit"
         )

     split_size = minimum_split_size
diff --git a/src/lightning/pytorch/callbacks/lr_monitor.py b/src/lightning/pytorch/callbacks/lr_monitor.py
index 71618c949e..4dbd2caaf5 100644
--- a/src/lightning/pytorch/callbacks/lr_monitor.py
+++ b/src/lightning/pytorch/callbacks/lr_monitor.py
@@ -272,8 +272,8 @@ class LearningRateMonitor(Callback):
     def _add_suffix(self, name: str, param_groups: List[Dict], param_group_index: int, use_names: bool = True) -> str:
         if len(param_groups) > 1:
             if not use_names:
-                return f"{name}/pg{param_group_index+1}"
-            pg_name = param_groups[param_group_index].get("name", f"pg{param_group_index+1}")
+                return f"{name}/pg{param_group_index + 1}"
+            pg_name = param_groups[param_group_index].get("name", f"pg{param_group_index + 1}")
             return f"{name}/{pg_name}"
         if use_names:
             pg_name = param_groups[param_group_index].get("name")
diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py
index d36fca49c0..f41889d94a 100644
--- a/src/lightning/pytorch/cli.py
+++ b/src/lightning/pytorch/cli.py
@@ -627,8 +627,8 @@ class LightningCLI:
         if len(optimizers) > 1 or len(lr_schedulers) > 1:
             raise MisconfigurationException(
                 f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model` expects at most one optimizer "
-                f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers+lr_schedulers}. In this case the user "
-                "is expected to link the argument groups and implement `configure_optimizers`, see "
+                f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers + lr_schedulers}. In this case the "
+                "user is expected to link the argument groups and implement `configure_optimizers`, see "
                 "https://lightning.ai/docs/pytorch/stable/common/lightning_cli.html"
                 "#optimizers-and-learning-rate-schedulers"
             )
diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py
index 8e5d3050f2..ed87c66de8 100644
--- a/tests/tests_app/core/test_lightning_api.py
+++ b/tests/tests_app/core/test_lightning_api.py
@@ -537,7 +537,7 @@ def test_configure_api():
     asyncio.set_event_loop(loop)
     results = loop.run_until_complete(asyncio.gather(*coros))
     response_time = time() - t0
-    print(f"RPS: {N/response_time}")
+    print(f"RPS: {N / response_time}")
     assert response_time < 10
     assert len(results) == N
     assert all(r.get("detail", None) == ("HERE" if i % 5 == 0 else None) for i, r in enumerate(results))
diff --git a/tests/tests_app/storage/test_copier.py b/tests/tests_app/storage/test_copier.py
index 2c47d83948..14d3f965ce 100644
--- a/tests/tests_app/storage/test_copier.py
+++ b/tests/tests_app/storage/test_copier.py
@@ -63,7 +63,7 @@ def test_copier_handles_exception(stat_mock, dir_mock, monkeypatch):
     copy_request_queue.put(request)
     copier.run_once()
     response = copy_response_queue.get()
-    assert type(response.exception) == OSError
+    assert type(response.exception) is OSError
     assert response.exception.args[0] == "Something went wrong"


diff --git a/tests/tests_app/utilities/test_cli_helpers.py b/tests/tests_app/utilities/test_cli_helpers.py
index a20b12a506..54ba37ac07 100644
--- a/tests/tests_app/utilities/test_cli_helpers.py
+++ b/tests/tests_app/utilities/test_cli_helpers.py
@@ -51,11 +51,11 @@ def test_arrow_time_callback():
     assert _arrow_time_callback(Mock(), Mock(), "2022-08-23 12:34:00.000") == arrow.Arrow(2022, 8, 23, 12, 34)

     # Just check humanized format is parsed
-    assert type(_arrow_time_callback(Mock(), Mock(), "48 hours ago")) == arrow.Arrow
+    assert type(_arrow_time_callback(Mock(), Mock(), "48 hours ago")) is arrow.Arrow

-    assert type(_arrow_time_callback(Mock(), Mock(), "60 minutes ago")) == arrow.Arrow
+    assert type(_arrow_time_callback(Mock(), Mock(), "60 minutes ago")) is arrow.Arrow

-    assert type(_arrow_time_callback(Mock(), Mock(), "120 seconds ago")) == arrow.Arrow
+    assert type(_arrow_time_callback(Mock(), Mock(), "120 seconds ago")) is arrow.Arrow

     # Check raising errors
     with pytest.raises(Exception, match="cannot parse time Mon"):
diff --git a/tests/tests_app/utilities/test_exceptions.py b/tests/tests_app/utilities/test_exceptions.py
index 70ed837cf0..96ac20deb9 100644
--- a/tests/tests_app/utilities/test_exceptions.py
+++ b/tests/tests_app/utilities/test_exceptions.py
@@ -59,7 +59,7 @@ class Test_ApiExceptionHandler:
         mock_subcommand.invoke.assert_called

         assert result.exit_code == 1
-        assert type(result.exception) == ClickException
+        assert type(result.exception) is ClickException
         assert api_error_msg == str(result.exception)

     def test_original_thrown_if_cannot_decode_body(self, mock_api_handled_group, mock_subcommand):
@@ -81,4 +81,4 @@ class Test_ApiExceptionHandler:
         mock_subcommand.invoke.assert_called

         assert result.exit_code == 1
-        assert type(result.exception) == ApiException
+        assert type(result.exception) is ApiException
diff --git a/tests/tests_fabric/plugins/collectives/test_torch_collective.py b/tests/tests_fabric/plugins/collectives/test_torch_collective.py
index 2f0edbcc1c..ba72286ac0 100644
--- a/tests/tests_fabric/plugins/collectives/test_torch_collective.py
+++ b/tests/tests_fabric/plugins/collectives/test_torch_collective.py
@@ -269,6 +269,7 @@ def _test_two_groups(strategy, left_collective, right_collective):


 @skip_distributed_unavailable
+@pytest.mark.flaky(reruns=5)
 @RunIf(skip_windows=True)  # unhandled timeouts
 @pytest.mark.xfail(raises=TimeoutError, strict=False)
 def test_two_groups():
@@ -286,6 +287,7 @@ def _test_default_process_group(strategy, *collectives):


 @skip_distributed_unavailable
+@pytest.mark.flaky(reruns=5)
 @RunIf(skip_windows=True)  # unhandled timeouts
 def test_default_process_group():
     collective_launch(_test_default_process_group, [torch.device("cpu")] * 3, num_groups=2)
diff --git a/tests/tests_fabric/utilities/test_load.py b/tests/tests_fabric/utilities/test_load.py
index b687984ea9..3dccf2258e 100644
--- a/tests/tests_fabric/utilities/test_load.py
+++ b/tests/tests_fabric/utilities/test_load.py
@@ -29,7 +29,7 @@ def test_lazy_load_module(tmp_path):
         model1.load_state_dict(checkpoint)

     assert isinstance(checkpoint["weight"], _NotYetLoadedTensor)
-    assert type(model0.weight.data) == torch.Tensor
+    assert type(model0.weight.data) is torch.Tensor
     assert torch.equal(model0.weight, model1.weight)
     assert torch.equal(model0.bias, model1.bias)
diff --git a/tests/tests_fabric/utilities/test_warnings.py b/tests/tests_fabric/utilities/test_warnings.py
index fcf69be1a4..1f34bf6e2a 100644
--- a/tests/tests_fabric/utilities/test_warnings.py
+++ b/tests/tests_fabric/utilities/test_warnings.py
@@ -62,12 +62,12 @@ if __name__ == "__main__":
     output = stderr.getvalue()
     expected_lines = [
         f"test_warnings.py:{base_line}: test1",
-        f"test_warnings.py:{base_line+1}: test2",
-        f"test_warnings.py:{base_line+3}: test3",
-        f"test_warnings.py:{base_line+4}: test4",
-        f"test_warnings.py:{base_line+6}: test5",
-        f"test_warnings.py:{base_line+9}: test6",
-        f"test_warnings.py:{base_line+10}: test7",
+        f"test_warnings.py:{base_line + 1}: test2",
+        f"test_warnings.py:{base_line + 3}: test3",
+        f"test_warnings.py:{base_line + 4}: test4",
+        f"test_warnings.py:{base_line + 6}: test5",
+        f"test_warnings.py:{base_line + 9}: test6",
+        f"test_warnings.py:{base_line + 10}: test7",
     ]

     for ln in expected_lines:
diff --git a/tests/tests_pytorch/profilers/test_profiler.py b/tests/tests_pytorch/profilers/test_profiler.py
index c23bf567eb..56d82734dc 100644
--- a/tests/tests_pytorch/profilers/test_profiler.py
+++ b/tests/tests_pytorch/profilers/test_profiler.py
@@ -215,7 +215,7 @@ def test_simple_profiler_summary(tmpdir, extended):
             f" {'Total time (s)':<15}\t| {'Percentage %':<15}\t|"
         )
         output_string_len = len(header_string.expandtabs())
-        sep_lines = f"{sep}{'-'* output_string_len}"
+        sep_lines = f"{sep}{'-' * output_string_len}"
         expected_text = (
             f"Profiler Report{sep}"
             f"{sep_lines}"
@@ -236,7 +236,7 @@ def test_simple_profiler_summary(tmpdir, extended):
             f"{sep}| {'Action':<{max_action_len}s}\t| {'Mean duration (s)':<15}\t| {'Total time (s)':<15}\t|"
         )
         output_string_len = len(header_string.expandtabs())
-        sep_lines = f"{sep}{'-'* output_string_len}"
+        sep_lines = f"{sep}{'-' * output_string_len}"
         expected_text = (
             f"Profiler Report{sep}"
             f"{sep_lines}"
diff --git a/tests/tests_pytorch/trainer/properties/test_loggers.py b/tests/tests_pytorch/trainer/properties/test_loggers.py
index 721e167270..7fbd9197b8 100644
--- a/tests/tests_pytorch/trainer/properties/test_loggers.py
+++ b/tests/tests_pytorch/trainer/properties/test_loggers.py
@@ -48,7 +48,7 @@ def test_trainer_loggers_setters():
     logger2 = CustomLogger()

     trainer = Trainer()
-    assert type(trainer.logger) == TensorBoardLogger
+    assert type(trainer.logger) is TensorBoardLogger
     assert trainer.loggers == [trainer.logger]

     # Test setters for trainer.logger
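
For background on the rule this patch enables: RUF018 (assignment-in-assert) flags walrus (named-expression) assignments inside `assert` statements, because running Python with `-O` strips asserts entirely, so any assignment made inside one silently disappears. A minimal sketch of the pattern the rule catches and a safe rewrite, using hypothetical `load_config()` and `use()` helpers that are not part of this patch:

    # Flagged by RUF018: under `python -O` the whole assert is removed,
    # so `config` is never bound and the next line raises NameError.
    assert (config := load_config()) is not None
    use(config)

    # Safe rewrite: perform the assignment outside the assert.
    config = load_config()
    assert config is not None
    use(config)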