[pre-commit.ci] pre-commit suggestions (#17983)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Jirka B <j.borovec+github@gmail.com>
parent 8f29bb561b
commit 834bd61164
@@ -196,13 +196,12 @@ def load_readme_description(path_dir: str, homepage: str, version: str) -> str:
     skip_begin = r"<!-- following section will be skipped from PyPI description -->"
     skip_end = r"<!-- end skipping PyPI description -->"
     # todo: wrap content as commented description
-    text = re.sub(rf"{skip_begin}.+?{skip_end}", "<!-- -->", text, flags=re.IGNORECASE + re.DOTALL)
+    return re.sub(rf"{skip_begin}.+?{skip_end}", "<!-- -->", text, flags=re.IGNORECASE + re.DOTALL)
 
     # # https://github.com/Borda/pytorch-lightning/releases/download/1.1.0a6/codecov_badge.png
     # github_release_url = os.path.join(homepage, "releases", "download", version)
     # # download badge and replace url with local file
     # text = _parse_for_badge(text, github_release_url)
-    return text
 
 
 def distribute_version(src_folder: str, ver_file: str = "version.info") -> None:
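Note: the hunk above only folds the assignment into the return; since the lines between the old assignment and the old `return text` are all comments, behavior is unchanged. A minimal, self-contained sketch of what that `re.sub` call does (marker strings copied from the hunk; the sample `text` is invented for illustration):

    import re

    skip_begin = r"<!-- following section will be skipped from PyPI description -->"
    skip_end = r"<!-- end skipping PyPI description -->"
    text = f"intro\n{skip_begin}\nCI badges, coverage tables\n{skip_end}\noutro"
    # re.DOTALL lets ".+?" span newlines; the non-greedy "+?" stops at the first end marker.
    print(re.sub(rf"{skip_begin}.+?{skip_end}", "<!-- -->", text, flags=re.IGNORECASE | re.DOTALL))
    # -> intro
    #    <!-- -->
    #    outro

(The source spells the flags as `re.IGNORECASE + re.DOTALL`; adding distinct flag bits works, though `|` is the conventional spelling with the same value.)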
@@ -50,20 +50,20 @@ repos:
       - id: detect-private-key
 
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.3.1
+    rev: v3.8.0
     hooks:
       - id: pyupgrade
         args: [--py38-plus]
         name: Upgrade code
 
   - repo: https://github.com/PyCQA/docformatter
-    rev: v1.6.3
+    rev: v1.7.3
     hooks:
       - id: docformatter
         args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120]
 
   - repo: https://github.com/asottile/yesqa
-    rev: v1.4.0
+    rev: v1.5.0
     hooks:
       - id: yesqa
         name: Unused noqa
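These bumps come from `pre-commit autoupdate` run by pre-commit.ci (hence the commit title). The pyupgrade jump (v3.3.1 -> v3.8.0) is the one that rewrites code: with `--py38-plus` it modernizes legacy idioms in place. A hedged sketch of the kind of rewrites pyupgrade applies (examples invented, not taken from this diff):

    # before                                  # after pyupgrade --py38-plus
    "hello {}".format(name)                   f"hello {name}"
    set([1, 2, 3])                            {1, 2, 3}
    super(LitClassifier, self).__init__()     super().__init__()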
@@ -89,7 +89,7 @@ repos:
         exclude: docs/source-app
 
   - repo: https://github.com/asottile/blacken-docs
-    rev: 1.13.0
+    rev: 1.14.0
     hooks:
       - id: blacken-docs
         args: [--line-length=120]
@@ -111,8 +111,8 @@ repos:
          README.md
         )$
 
-  - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: 'v0.0.262'
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: 'v0.0.276'
     hooks:
       - id: ruff
         args: ["--fix"]
@@ -50,8 +50,7 @@ class Backbone(torch.nn.Module):
     def forward(self, x):
         x = x.view(x.size(0), -1)
         x = torch.relu(self.l1(x))
-        x = torch.relu(self.l2(x))
-        return x
+        return torch.relu(self.l2(x))
 
 
 class LitClassifier(LightningModule):
@@ -215,9 +215,7 @@ class TransferLearningModel(LightningModule):
         x = x.squeeze(-1).squeeze(-1)
 
         # 2. Classifier (returns logits):
-        x = self.fc(x)
-
-        return x
+        return self.fc(x)
 
     def loss(self, logits, labels):
         return self.loss_func(input=logits, target=labels)
@@ -64,8 +64,7 @@ class Generator(nn.Module):
 
     def forward(self, z):
         img = self.model(z)
-        img = img.view(img.size(0), *self.img_shape)
-        return img
+        return img.view(img.size(0), *self.img_shape)
 
 
 class Discriminator(nn.Module):
@@ -33,8 +33,7 @@ class LitClassifier(LightningModule):
     def forward(self, x):
         x = x.view(x.size(0), -1)
         x = torch.relu(self.l1(x))
-        x = torch.relu(self.l2(x))
-        return x
+        return torch.relu(self.l2(x))
 
     def training_step(self, batch, batch_idx):
         x, y = batch
@@ -80,7 +80,7 @@ def _extract_commands_from_file(file_name: str) -> CommandLines:
 
 
 def _execute_app_commands(cl: CommandLines) -> None:
-    """open a subprocess shell to execute app commands.
+    """Open a subprocess shell to execute app commands.
 
     The calling app environment is used in the current environment the code is running in
     """
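This hunk and every docstring hunk below are the docformatter bump (v1.6.3 -> v1.7.3) at work: newer releases capitalize the first word of a docstring summary, in line with PEP 257's convention that a one-line docstring is a capitalized phrase ending in a period. A minimal before/after illustration (function name invented):

    def fetch_state(k):
        """returns the state for key 'k'."""  # before

    def fetch_state(k):
        """Returns the state for key 'k'."""  # after docformatter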
@@ -49,7 +49,7 @@ logger = logging.getLogger(__name__)
 
 @dataclass
 class StateEntry:
-    """dataclass used to keep track the latest state shared through the app REST API."""
+    """Dataclass used to keep track the latest state shared through the app REST API."""
 
     app_state: Mapping = field(default_factory=dict)
     served_state: Mapping = field(default_factory=dict)
@@ -76,32 +76,32 @@ class StateStore(ABC):
 
     @abstractmethod
     def get_app_state(self, k: str) -> Mapping:
-        """returns a stored appstate for an input key 'k'."""
+        """Returns a stored appstate for an input key 'k'."""
         pass
 
     @abstractmethod
     def get_served_state(self, k: str) -> Mapping:
-        """returns a last served app state for an input key 'k'."""
+        """Returns a last served app state for an input key 'k'."""
         pass
 
     @abstractmethod
     def get_served_session_id(self, k: str) -> str:
-        """returns session id for state of a key 'k'."""
+        """Returns session id for state of a key 'k'."""
         pass
 
     @abstractmethod
     def set_app_state(self, k: str, v: Mapping):
-        """sets the app state for state of a key 'k'."""
+        """Sets the app state for state of a key 'k'."""
         pass
 
     @abstractmethod
     def set_served_state(self, k: str, v: Mapping):
-        """sets the served state for state of a key 'k'."""
+        """Sets the served state for state of a key 'k'."""
         pass
 
     @abstractmethod
     def set_served_session_id(self, k: str, v: str):
-        """sets the session id for state of a key 'k'."""
+        """Sets the session id for state of a key 'k'."""
         pass
 
 
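For orientation, the interface touched above maps a key to an app state, a served state, and a session id. A hypothetical in-memory implementation, written only to make the contract concrete — this class does not exist in the repo, and `StateStore` may declare further abstract members not visible in this hunk that a real subclass would also need:

    from typing import Mapping

    class InMemoryStateStore(StateStore):
        """Illustrative only: backs each method shown in the hunk with a plain dict."""

        def __init__(self):
            self._app = {}
            self._served = {}
            self._sessions = {}

        def get_app_state(self, k: str) -> Mapping:
            return self._app[k]

        def get_served_state(self, k: str) -> Mapping:
            return self._served[k]

        def get_served_session_id(self, k: str) -> str:
            return self._sessions[k]

        def set_app_state(self, k: str, v: Mapping):
            self._app[k] = v

        def set_served_state(self, k: str, v: Mapping):
            self._served[k] = v

        def set_served_session_id(self, k: str, v: str):
            self._sessions[k] = v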
@@ -71,7 +71,7 @@ class Auth:
         return True
 
     def save(self, token: str = "", user_id: str = "", api_key: str = "", username: str = "") -> None:
-        """save credentials to disk."""
+        """Save credentials to disk."""
         self.secrets_file.parent.mkdir(exist_ok=True, parents=True)
         with self.secrets_file.open("w") as f:
             json.dump(
@@ -98,7 +98,7 @@ class Auth:
 
     @property
     def auth_header(self) -> Optional[str]:
-        """authentication header used by lightning-cloud client."""
+        """Authentication header used by lightning-cloud client."""
         if self.api_key:
             token = f"{self.user_id}:{self.api_key}"
             return f"Basic {base64.b64encode(token.encode('ascii')).decode('ascii')}"  # E501
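The property above builds a standard HTTP Basic authorization header: base64 of `user:secret`. A self-contained sketch of the same encoding (sample credentials invented):

    import base64

    token = "user-123:api-key-abc"  # f"{user_id}:{api_key}" in the real property
    header = f"Basic {base64.b64encode(token.encode('ascii')).decode('ascii')}"
    print(header)  # Basic dXNlci0xMjM6YXBpLWtleS1hYmM=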
@@ -108,7 +108,7 @@ class Auth:
         )
 
     def _run_server(self) -> None:
-        """start a server to complete authentication."""
+        """Start a server to complete authentication."""
         AuthServer().login_with_browser(self)
 
     def authenticate(self) -> Optional[str]:
@@ -29,7 +29,7 @@ def _duplicate_checker(js):
 
 
 def string2dict(text):
-    """string2dict parses a JSON string into a dictionary, ensuring no keys are duplicated by accident."""
+    """String2dict parses a JSON string into a dictionary, ensuring no keys are duplicated by accident."""
     if not isinstance(text, str):
         text = text.decode("utf-8")
     try:
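Context for this hunk: `json.loads` silently keeps only the last value for a duplicated key, which is why `string2dict` routes parsing through a checker (`_duplicate_checker`, visible in the hunk header). A minimal sketch of the usual `object_pairs_hook` technique, assuming that is how the checker is wired — the repo's exact implementation is not shown in this diff:

    import json

    def _reject_duplicates(pairs):
        seen = {}
        for key, value in pairs:
            if key in seen:
                raise ValueError(f"duplicate key: {key!r}")
            seen[key] = value
        return seen

    json.loads('{"a": 1, "b": 2}', object_pairs_hook=_reject_duplicates)  # ok
    # json.loads('{"a": 1, "a": 2}', object_pairs_hook=_reject_duplicates)  # ValueError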
@@ -298,7 +298,7 @@ class _TrainingEpochLoop(loops._Loop):
         return not accumulation_done and strategy_accumulates_on_final_batch
 
     def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -> None:
-        """updates the lr schedulers based on the given interval."""
+        """Updates the lr schedulers based on the given interval."""
         if interval == "step" and self._should_accumulate():
             return
         self._update_learning_rates(interval=interval, update_plateau_schedulers=update_plateau_schedulers)
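Aside on the guard visible above: while gradients are still accumulating there is no `optimizer.step()`, so a "step"-interval scheduler must not advance either. A toy illustration of that pacing (numbers invented; not Lightning code):

    accumulate_grad_batches = 2
    for batch_idx in range(4):
        accumulating = (batch_idx + 1) % accumulate_grad_batches != 0
        if accumulating:
            continue  # no optimizer.step() here, so skip scheduler.step() too
        print(f"batch {batch_idx}: optimizer.step() then scheduler.step()")
    # -> fires on batches 1 and 3 only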
@@ -499,9 +499,8 @@ def target():
 
 
 async def async_request(url: str, data: InputRequestModel):
-    async with aiohttp.ClientSession() as session:
-        async with session.post(url, json=data.dict()) as result:
-            return await result.json()
+    async with aiohttp.ClientSession() as session, session.post(url, json=data.dict()) as result:
+        return await result.json()
 
 
 @pytest.mark.xfail(strict=False, reason="No idea why... need to be fixed")  # fixme
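The rewrite above is purely syntactic: `async with A() as a, B() as b:` enters the managers left to right and exits them in reverse, exactly like the nested form, and B's expression may reference `a`. A self-contained sketch (context managers invented):

    import asyncio
    from contextlib import asynccontextmanager

    @asynccontextmanager
    async def step(name):
        print(f"enter {name}")
        try:
            yield name
        finally:
            print(f"exit {name}")

    async def main():
        # equivalent to nesting: outer enters first, exits last
        async with step("outer") as a, step(f"inner-of-{a}") as b:
            print(f"body with {a} and {b}")

    asyncio.run(main())
    # enter outer / enter inner-of-outer / body ... / exit inner-of-outer / exit outer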
@@ -38,7 +38,7 @@ class CustomInfDataloader:
 
 class CustomNotImplementedErrorDataloader(CustomInfDataloader):
     def __len__(self):
-        """raise NotImplementedError."""
+        """Raise NotImplementedError."""
         raise NotImplementedError
 
     def __next__(self):
@@ -147,7 +147,7 @@ def test_tensorboard_log_hparams_and_metrics(tmpdir):
 
 @pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 32)])
 def test_tensorboard_log_graph(tmpdir, example_input_array):
-    """test that log graph works with both model.example_input_array and if array is passed externally."""
+    """Test that log graph works with both model.example_input_array and if array is passed externally."""
     # TODO(fabric): Test both nn.Module and LightningModule
     # TODO(fabric): Assert _apply_batch_transfer_handler is calling the batch transfer hooks
     model = BoringModel()
@@ -170,7 +170,7 @@ def test_tensorboard_log_graph(tmpdir, example_input_array):
 
 @pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason=str(_TENSORBOARD_AVAILABLE))
 def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir):
-    """test that log graph throws warning if model.example_input_array is None."""
+    """Test that log graph throws warning if model.example_input_array is None."""
     model = BoringModel()
     model.example_input_array = None
     logger = TensorBoardLogger(tmpdir, log_graph=True)
@@ -38,7 +38,7 @@ class CustomInfDataloader:
 
 class CustomNotImplementedErrorDataloader(CustomInfDataloader):
     def __len__(self):
-        """raise NotImplementedError."""
+        """Raise NotImplementedError."""
         raise NotImplementedError
 
     def __next__(self):
@@ -210,7 +210,7 @@ def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir):
 
 @pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 32)])
 def test_tensorboard_log_graph(tmpdir, example_input_array):
-    """test that log graph works with both model.example_input_array and if array is passed externally."""
+    """Test that log graph works with both model.example_input_array and if array is passed externally."""
     model = BoringModel()
     if example_input_array is not None:
         model.example_input_array = None
@@ -221,7 +221,7 @@ def test_tensorboard_log_graph(tmpdir, example_input_array):
 
 @pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason=str(_TENSORBOARD_AVAILABLE))
 def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir):
-    """test that log graph throws warning if model.example_input_array is None."""
+    """Test that log graph throws warning if model.example_input_array is None."""
     model = BoringModel()
     model.example_input_array = None
     logger = TensorBoardLogger(tmpdir, log_graph=True)
@@ -40,7 +40,7 @@ def _get_python_cprofile_total_duration(profile):
 
 
 def _sleep_generator(durations):
-    """the profile_iterable method needs an iterable in which we can ensure that we're properly timing how long it
+    """The profile_iterable method needs an iterable in which we can ensure that we're properly timing how long it
     takes to call __next__"""
     for duration in durations:
         time.sleep(duration)
@@ -277,7 +277,7 @@ def test_advanced_profiler_durations(advanced_profiler, action: str, expected: l
 
 @pytest.mark.flaky(reruns=3)
 def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
-    """ensure that the profiler doesn't introduce too much overhead during training."""
+    """Ensure that the profiler doesn't introduce too much overhead during training."""
     for _ in range(n_iter):
         with advanced_profiler.profile("no-op"):
             pass
@@ -289,7 +289,7 @@ def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
 
 
 def test_advanced_profiler_describe(tmpdir, advanced_profiler):
-    """ensure the profiler won't fail when reporting the summary."""
+    """Ensure the profiler won't fail when reporting the summary."""
     # record at least one event
     with advanced_profiler.profile("test"):
         pass
@@ -266,7 +266,7 @@ def test_fx_validator_integration(tmpdir):
 
 @pytest.mark.parametrize("add_dataloader_idx", [False, True])
 def test_auto_add_dataloader_idx(tmpdir, add_dataloader_idx):
-    """test that auto_add_dataloader_idx argument works."""
+    """Test that auto_add_dataloader_idx argument works."""
 
     class TestModel(BoringModel):
         def val_dataloader(self):
@@ -245,7 +245,7 @@ def test_gradient_accumulation_scheduling_last_batch(tmpdir, accumulate_grad_bat
 
 
 def test_loading_meta_tags(tmpdir):
-    """test for backward compatibility to meta_tags.csv."""
+    """Test for backward compatibility to meta_tags.csv."""
     hparams = {
         "batch_size": 32,
         "learning_rate": 0.001 * 8,