[pre-commit.ci] pre-commit suggestions (#17983)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Jirka B <j.borovec+github@gmail.com>
This commit is contained in:
pre-commit-ci[bot] 2023-08-08 16:26:06 +02:00 committed by GitHub
parent 8f29bb561b
commit 834bd61164
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 37 additions and 44 deletions

View File

@ -196,13 +196,12 @@ def load_readme_description(path_dir: str, homepage: str, version: str) -> str:
skip_begin = r"<!-- following section will be skipped from PyPI description -->"
skip_end = r"<!-- end skipping PyPI description -->"
# todo: wrap content as commented description
text = re.sub(rf"{skip_begin}.+?{skip_end}", "<!-- -->", text, flags=re.IGNORECASE + re.DOTALL)
return re.sub(rf"{skip_begin}.+?{skip_end}", "<!-- -->", text, flags=re.IGNORECASE + re.DOTALL)
# # https://github.com/Borda/pytorch-lightning/releases/download/1.1.0a6/codecov_badge.png
# github_release_url = os.path.join(homepage, "releases", "download", version)
# # download badge and replace url with local file
# text = _parse_for_badge(text, github_release_url)
return text
def distribute_version(src_folder: str, ver_file: str = "version.info") -> None:

View File

@ -50,20 +50,20 @@ repos:
- id: detect-private-key
- repo: https://github.com/asottile/pyupgrade
rev: v3.3.1
rev: v3.8.0
hooks:
- id: pyupgrade
args: [--py38-plus]
name: Upgrade code
- repo: https://github.com/PyCQA/docformatter
rev: v1.6.3
rev: v1.7.3
hooks:
- id: docformatter
args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120]
- repo: https://github.com/asottile/yesqa
rev: v1.4.0
rev: v1.5.0
hooks:
- id: yesqa
name: Unused noqa
@ -89,7 +89,7 @@ repos:
exclude: docs/source-app
- repo: https://github.com/asottile/blacken-docs
rev: 1.13.0
rev: 1.14.0
hooks:
- id: blacken-docs
args: [--line-length=120]
@ -111,8 +111,8 @@ repos:
README.md
)$
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.0.262'
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.0.276'
hooks:
- id: ruff
args: ["--fix"]

View File

@ -50,8 +50,7 @@ class Backbone(torch.nn.Module):
def forward(self, x):
x = x.view(x.size(0), -1)
x = torch.relu(self.l1(x))
x = torch.relu(self.l2(x))
return x
return torch.relu(self.l2(x))
class LitClassifier(LightningModule):

View File

@ -215,9 +215,7 @@ class TransferLearningModel(LightningModule):
x = x.squeeze(-1).squeeze(-1)
# 2. Classifier (returns logits):
x = self.fc(x)
return x
return self.fc(x)
def loss(self, logits, labels):
return self.loss_func(input=logits, target=labels)

View File

@ -64,8 +64,7 @@ class Generator(nn.Module):
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), *self.img_shape)
return img
return img.view(img.size(0), *self.img_shape)
class Discriminator(nn.Module):

View File

@ -33,8 +33,7 @@ class LitClassifier(LightningModule):
def forward(self, x):
x = x.view(x.size(0), -1)
x = torch.relu(self.l1(x))
x = torch.relu(self.l2(x))
return x
return torch.relu(self.l2(x))
def training_step(self, batch, batch_idx):
x, y = batch

View File

@ -80,7 +80,7 @@ def _extract_commands_from_file(file_name: str) -> CommandLines:
def _execute_app_commands(cl: CommandLines) -> None:
"""open a subprocess shell to execute app commands.
"""Open a subprocess shell to execute app commands.
The calling app environment is used in the current environment the code is running in
"""

View File

@ -49,7 +49,7 @@ logger = logging.getLogger(__name__)
@dataclass
class StateEntry:
"""dataclass used to keep track the latest state shared through the app REST API."""
"""Dataclass used to keep track the latest state shared through the app REST API."""
app_state: Mapping = field(default_factory=dict)
served_state: Mapping = field(default_factory=dict)
@ -76,32 +76,32 @@ class StateStore(ABC):
@abstractmethod
def get_app_state(self, k: str) -> Mapping:
"""returns a stored appstate for an input key 'k'."""
"""Returns a stored appstate for an input key 'k'."""
pass
@abstractmethod
def get_served_state(self, k: str) -> Mapping:
"""returns a last served app state for an input key 'k'."""
"""Returns a last served app state for an input key 'k'."""
pass
@abstractmethod
def get_served_session_id(self, k: str) -> str:
"""returns session id for state of a key 'k'."""
"""Returns session id for state of a key 'k'."""
pass
@abstractmethod
def set_app_state(self, k: str, v: Mapping):
"""sets the app state for state of a key 'k'."""
"""Sets the app state for state of a key 'k'."""
pass
@abstractmethod
def set_served_state(self, k: str, v: Mapping):
"""sets the served state for state of a key 'k'."""
"""Sets the served state for state of a key 'k'."""
pass
@abstractmethod
def set_served_session_id(self, k: str, v: str):
"""sets the session id for state of a key 'k'."""
"""Sets the session id for state of a key 'k'."""
pass

View File

@ -71,7 +71,7 @@ class Auth:
return True
def save(self, token: str = "", user_id: str = "", api_key: str = "", username: str = "") -> None:
"""save credentials to disk."""
"""Save credentials to disk."""
self.secrets_file.parent.mkdir(exist_ok=True, parents=True)
with self.secrets_file.open("w") as f:
json.dump(
@ -98,7 +98,7 @@ class Auth:
@property
def auth_header(self) -> Optional[str]:
"""authentication header used by lightning-cloud client."""
"""Authentication header used by lightning-cloud client."""
if self.api_key:
token = f"{self.user_id}:{self.api_key}"
return f"Basic {base64.b64encode(token.encode('ascii')).decode('ascii')}" # E501
@ -108,7 +108,7 @@ class Auth:
)
def _run_server(self) -> None:
"""start a server to complete authentication."""
"""Start a server to complete authentication."""
AuthServer().login_with_browser(self)
def authenticate(self) -> Optional[str]:

View File

@ -29,7 +29,7 @@ def _duplicate_checker(js):
def string2dict(text):
"""string2dict parses a JSON string into a dictionary, ensuring no keys are duplicated by accident."""
"""String2dict parses a JSON string into a dictionary, ensuring no keys are duplicated by accident."""
if not isinstance(text, str):
text = text.decode("utf-8")
try:

View File

@ -298,7 +298,7 @@ class _TrainingEpochLoop(loops._Loop):
return not accumulation_done and strategy_accumulates_on_final_batch
def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -> None:
"""updates the lr schedulers based on the given interval."""
"""Updates the lr schedulers based on the given interval."""
if interval == "step" and self._should_accumulate():
return
self._update_learning_rates(interval=interval, update_plateau_schedulers=update_plateau_schedulers)

View File

@ -499,9 +499,8 @@ def target():
async def async_request(url: str, data: InputRequestModel):
async with aiohttp.ClientSession() as session:
async with session.post(url, json=data.dict()) as result:
return await result.json()
async with aiohttp.ClientSession() as session, session.post(url, json=data.dict()) as result:
return await result.json()
@pytest.mark.xfail(strict=False, reason="No idea why... need to be fixed") # fixme

View File

@ -38,7 +38,7 @@ class CustomInfDataloader:
class CustomNotImplementedErrorDataloader(CustomInfDataloader):
def __len__(self):
"""raise NotImplementedError."""
"""Raise NotImplementedError."""
raise NotImplementedError
def __next__(self):

View File

@ -147,7 +147,7 @@ def test_tensorboard_log_hparams_and_metrics(tmpdir):
@pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 32)])
def test_tensorboard_log_graph(tmpdir, example_input_array):
"""test that log graph works with both model.example_input_array and if array is passed externally."""
"""Test that log graph works with both model.example_input_array and if array is passed externally."""
# TODO(fabric): Test both nn.Module and LightningModule
# TODO(fabric): Assert _apply_batch_transfer_handler is calling the batch transfer hooks
model = BoringModel()
@ -170,7 +170,7 @@ def test_tensorboard_log_graph(tmpdir, example_input_array):
@pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason=str(_TENSORBOARD_AVAILABLE))
def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir):
"""test that log graph throws warning if model.example_input_array is None."""
"""Test that log graph throws warning if model.example_input_array is None."""
model = BoringModel()
model.example_input_array = None
logger = TensorBoardLogger(tmpdir, log_graph=True)

View File

@ -38,7 +38,7 @@ class CustomInfDataloader:
class CustomNotImplementedErrorDataloader(CustomInfDataloader):
def __len__(self):
"""raise NotImplementedError."""
"""Raise NotImplementedError."""
raise NotImplementedError
def __next__(self):

View File

@ -210,7 +210,7 @@ def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir):
@pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 32)])
def test_tensorboard_log_graph(tmpdir, example_input_array):
"""test that log graph works with both model.example_input_array and if array is passed externally."""
"""Test that log graph works with both model.example_input_array and if array is passed externally."""
model = BoringModel()
if example_input_array is not None:
model.example_input_array = None
@ -221,7 +221,7 @@ def test_tensorboard_log_graph(tmpdir, example_input_array):
@pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason=str(_TENSORBOARD_AVAILABLE))
def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir):
"""test that log graph throws warning if model.example_input_array is None."""
"""Test that log graph throws warning if model.example_input_array is None."""
model = BoringModel()
model.example_input_array = None
logger = TensorBoardLogger(tmpdir, log_graph=True)

View File

@ -40,7 +40,7 @@ def _get_python_cprofile_total_duration(profile):
def _sleep_generator(durations):
"""the profile_iterable method needs an iterable in which we can ensure that we're properly timing how long it
"""The profile_iterable method needs an iterable in which we can ensure that we're properly timing how long it
takes to call __next__"""
for duration in durations:
time.sleep(duration)
@ -277,7 +277,7 @@ def test_advanced_profiler_durations(advanced_profiler, action: str, expected: l
@pytest.mark.flaky(reruns=3)
def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
"""ensure that the profiler doesn't introduce too much overhead during training."""
"""Ensure that the profiler doesn't introduce too much overhead during training."""
for _ in range(n_iter):
with advanced_profiler.profile("no-op"):
pass
@ -289,7 +289,7 @@ def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
def test_advanced_profiler_describe(tmpdir, advanced_profiler):
"""ensure the profiler won't fail when reporting the summary."""
"""Ensure the profiler won't fail when reporting the summary."""
# record at least one event
with advanced_profiler.profile("test"):
pass

View File

@ -266,7 +266,7 @@ def test_fx_validator_integration(tmpdir):
@pytest.mark.parametrize("add_dataloader_idx", [False, True])
def test_auto_add_dataloader_idx(tmpdir, add_dataloader_idx):
"""test that auto_add_dataloader_idx argument works."""
"""Test that auto_add_dataloader_idx argument works."""
class TestModel(BoringModel):
def val_dataloader(self):

View File

@ -245,7 +245,7 @@ def test_gradient_accumulation_scheduling_last_batch(tmpdir, accumulate_grad_bat
def test_loading_meta_tags(tmpdir):
"""test for backward compatibility to meta_tags.csv."""
"""Test for backward compatibility to meta_tags.csv."""
hparams = {
"batch_size": 32,
"learning_rate": 0.001 * 8,