Enable RUF018 rule for walrus assignments in asserts (#18886)
This commit is contained in:
parent
e0ba4d46e1
commit
018a308269
|
@ -83,6 +83,12 @@ repos:
|
|||
- flake8-simplify
|
||||
- flake8-return
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: "v0.1.3"
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: ["--fix", "--preview"]
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.9.1
|
||||
hooks:
|
||||
|
@ -120,9 +126,3 @@ repos:
|
|||
- id: prettier
|
||||
# https://prettier.io/docs/en/options.html#print-width
|
||||
args: ["--print-width=120"]
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: "v0.0.292"
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: ["--fix"]
|
||||
|
|
|
@ -53,6 +53,7 @@ select = [
|
|||
"E", "W", # see: https://pypi.org/project/pycodestyle
|
||||
"F", # see: https://pypi.org/project/pyflakes
|
||||
"S", # see: https://pypi.org/project/flake8-bandit
|
||||
"RUF018", # see: https://docs.astral.sh/ruff/rules/assignment-in-assert
|
||||
]
|
||||
extend-select = [
|
||||
"I", # see: isort
|
||||
|
@ -64,6 +65,7 @@ extend-select = [
|
|||
ignore = [
|
||||
"E731", # Do not assign a lambda expression, use a def
|
||||
"S108",
|
||||
"E203", # conflicts with black
|
||||
]
|
||||
# Exclude a variety of commonly ignored directories.
|
||||
exclude = [
|
||||
|
|
|
@ -627,8 +627,8 @@ class LightningCLI:
|
|||
if len(optimizers) > 1 or len(lr_schedulers) > 1:
|
||||
raise MisconfigurationException(
|
||||
f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model` expects at most one optimizer "
|
||||
f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers+lr_schedulers}. In this case the user "
|
||||
"is expected to link the argument groups and implement `configure_optimizers`, see "
|
||||
f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers + lr_schedulers}. In this case the "
|
||||
"user is expected to link the argument groups and implement `configure_optimizers`, see "
|
||||
"https://lightning.ai/docs/pytorch/stable/common/lightning_cli.html"
|
||||
"#optimizers-and-learning-rate-schedulers"
|
||||
)
|
||||
|
|
|
@ -63,7 +63,7 @@ def test_copier_handles_exception(stat_mock, dir_mock, monkeypatch):
|
|||
copy_request_queue.put(request)
|
||||
copier.run_once()
|
||||
response = copy_response_queue.get()
|
||||
assert type(response.exception) == OSError
|
||||
assert type(response.exception) is OSError
|
||||
assert response.exception.args[0] == "Something went wrong"
|
||||
|
||||
|
||||
|
|
|
@ -51,11 +51,11 @@ def test_arrow_time_callback():
|
|||
assert _arrow_time_callback(Mock(), Mock(), "2022-08-23 12:34:00.000") == arrow.Arrow(2022, 8, 23, 12, 34)
|
||||
|
||||
# Just check humanized format is parsed
|
||||
assert type(_arrow_time_callback(Mock(), Mock(), "48 hours ago")) == arrow.Arrow
|
||||
assert type(_arrow_time_callback(Mock(), Mock(), "48 hours ago")) is arrow.Arrow
|
||||
|
||||
assert type(_arrow_time_callback(Mock(), Mock(), "60 minutes ago")) == arrow.Arrow
|
||||
assert type(_arrow_time_callback(Mock(), Mock(), "60 minutes ago")) is arrow.Arrow
|
||||
|
||||
assert type(_arrow_time_callback(Mock(), Mock(), "120 seconds ago")) == arrow.Arrow
|
||||
assert type(_arrow_time_callback(Mock(), Mock(), "120 seconds ago")) is arrow.Arrow
|
||||
|
||||
# Check raising errors
|
||||
with pytest.raises(Exception, match="cannot parse time Mon"):
|
||||
|
|
|
@ -59,7 +59,7 @@ class Test_ApiExceptionHandler:
|
|||
|
||||
mock_subcommand.invoke.assert_called
|
||||
assert result.exit_code == 1
|
||||
assert type(result.exception) == ClickException
|
||||
assert type(result.exception) is ClickException
|
||||
assert api_error_msg == str(result.exception)
|
||||
|
||||
def test_original_thrown_if_cannot_decode_body(self, mock_api_handled_group, mock_subcommand):
|
||||
|
@ -81,4 +81,4 @@ class Test_ApiExceptionHandler:
|
|||
|
||||
mock_subcommand.invoke.assert_called
|
||||
assert result.exit_code == 1
|
||||
assert type(result.exception) == ApiException
|
||||
assert type(result.exception) is ApiException
|
||||
|
|
|
@ -269,6 +269,7 @@ def _test_two_groups(strategy, left_collective, right_collective):
|
|||
|
||||
|
||||
@skip_distributed_unavailable
|
||||
@pytest.mark.flaky(reruns=5)
|
||||
@RunIf(skip_windows=True) # unhandled timeouts
|
||||
@pytest.mark.xfail(raises=TimeoutError, strict=False)
|
||||
def test_two_groups():
|
||||
|
@ -286,6 +287,7 @@ def _test_default_process_group(strategy, *collectives):
|
|||
|
||||
|
||||
@skip_distributed_unavailable
|
||||
@pytest.mark.flaky(reruns=5)
|
||||
@RunIf(skip_windows=True) # unhandled timeouts
|
||||
def test_default_process_group():
|
||||
collective_launch(_test_default_process_group, [torch.device("cpu")] * 3, num_groups=2)
|
||||
|
|
|
@ -29,7 +29,7 @@ def test_lazy_load_module(tmp_path):
|
|||
model1.load_state_dict(checkpoint)
|
||||
|
||||
assert isinstance(checkpoint["weight"], _NotYetLoadedTensor)
|
||||
assert type(model0.weight.data) == torch.Tensor
|
||||
assert type(model0.weight.data) is torch.Tensor
|
||||
assert torch.equal(model0.weight, model1.weight)
|
||||
assert torch.equal(model0.bias, model1.bias)
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ def test_trainer_loggers_setters():
|
|||
logger2 = CustomLogger()
|
||||
|
||||
trainer = Trainer()
|
||||
assert type(trainer.logger) == TensorBoardLogger
|
||||
assert type(trainer.logger) is TensorBoardLogger
|
||||
assert trainer.loggers == [trainer.logger]
|
||||
|
||||
# Test setters for trainer.logger
|
||||
|
|
Loading…
Reference in New Issue