diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f99c51228a..e692aa8de2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -68,6 +68,9 @@ repos: - id: yesqa name: Unused noqa additional_dependencies: + #- pep8-naming + #- flake8-pytest-style + - flake8-bandit - flake8-simplify - repo: https://github.com/PyCQA/isort diff --git a/examples/app/dag/processing.py b/examples/app/dag/processing.py index 28a07d70bd..245377fa8c 100644 --- a/examples/app/dag/processing.py +++ b/examples/app/dag/processing.py @@ -7,7 +7,7 @@ print("Starting processing ...") scaler = MinMaxScaler() X_train, X_test, y_train, y_test = train_test_split( - df_data.values, df_target.values, test_size=0.20, random_state=random.randint(0, 42) # noqa F821 + df_data.values, df_target.values, test_size=0.20, random_state=random.randint(0, 42) ) X_train = scaler.fit_transform(X_train) X_test = scaler.transform(X_test) diff --git a/examples/fabric/build_your_own_trainer/trainer.py b/examples/fabric/build_your_own_trainer/trainer.py index a819174d77..d2c1cd1018 100644 --- a/examples/fabric/build_your_own_trainer/trainer.py +++ b/examples/fabric/build_your_own_trainer/trainer.py @@ -376,11 +376,11 @@ class MyCustomTrainer: try: monitor = possible_monitor_vals[cast(Optional[str], scheduler_cfg["monitor"])] - except KeyError as e: + except KeyError as ex: possible_keys = list(possible_monitor_vals.keys()) raise KeyError( f"monitor {scheduler_cfg['monitor']} is invalid. Possible values are {possible_keys}." - ) from e + ) from ex # rely on model hook for actual step model.lr_scheduler_step(scheduler_cfg["scheduler"], monitor) diff --git a/pyproject.toml b/pyproject.toml index 5a1e03e109..5dc67dd52a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ line-length = 120 select = [ "E", "W", # see: https://pypi.org/project/pycodestyle "F", # see: https://pypi.org/project/pyflakes + "S", # see: https://pypi.org/project/flake8-bandit ] extend-select = [ "C4", # see: https://pypi.org/project/flake8-comprehensions @@ -71,6 +72,34 @@ exclude = [ ] ignore-init-module-imports = true +[tool.ruff.per-file-ignores] +".actions/*" = ["S101", "S310"] +"setup.py" = ["S101"] +"examples/**" = [ + "S101", # Use of `assert` detected + "S113", # todo: Probable use of requests call without + "S104", # Possible binding to all interface + "F821", # Undefined name `...` + "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes + "S501", # Probable use of `requests` call with `verify=False` disabling SSL certificate checks + "S108", # Probable insecure usage of temporary file or directory: "/tmp/data/MNIST" +] +"src/**" = [ + "S101", # todo: Use of `assert` detected + "S105", "S106", "S107", # todo: Possible hardcoded password: ... + "S113", # todo: Probable use of requests call without timeout + "S301", # todo: `pickle` and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue + "S324", # todo: Probable use of insecure hash functions in `hashlib` +] +"tests/**" = [ + "S101", # Use of `assert` detected + "S105", "S106", # todo: Possible hardcoded password: ... + "S301", # `pickle` and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue + "S113", # todo: Probable use of requests call without timeout + "S311", # todo: Standard pseudo-random generators are not suitable for cryptographic purposes + "S108", # todo: Probable insecure usage of temporary file or directory: "/tmp/sys-customizations-sync" +] + [tool.ruff.mccabe] # Unlike Flake8, default to a complexity level of 10. max-complexity = 10 diff --git a/requirements/pytorch/adjust-versions.py b/requirements/pytorch/adjust-versions.py index 108de1703d..1ef74af646 100644 --- a/requirements/pytorch/adjust-versions.py +++ b/requirements/pytorch/adjust-versions.py @@ -35,7 +35,7 @@ def replace(req: str, torch_version: Optional[str] = None) -> str: import torch torch_version = torch.__version__ - assert torch_version, f"invalid torch: {torch_version}" + assert torch_version, f"invalid torch: {torch_version}" # noqa: S101 # remove comments and strip whitespace req = re.sub(rf"\s*#.*{os.linesep}", os.linesep, req).strip() diff --git a/src/lightning/app/cli/cmd_clusters.py b/src/lightning/app/cli/cmd_clusters.py index 0f943487ce..fb6a1bbafb 100644 --- a/src/lightning/app/cli/cmd_clusters.py +++ b/src/lightning/app/cli/cmd_clusters.py @@ -325,8 +325,8 @@ def _wait_for_cluster_state( break time.sleep(poll_duration_seconds) elapsed = int(time.time() - start) - except lightning_cloud.openapi.rest.ApiException as e: - if e.status == 404 and target_state == V1ClusterState.DELETED: + except lightning_cloud.openapi.rest.ApiException as ex: + if ex.status == 404 and target_state == V1ClusterState.DELETED: return raise else: diff --git a/src/lightning/app/cli/cmd_init.py b/src/lightning/app/cli/cmd_init.py index 6a415af3e8..9fb656bbb8 100644 --- a/src/lightning/app/cli/cmd_init.py +++ b/src/lightning/app/cli/cmd_init.py @@ -113,10 +113,11 @@ def _capture_valid_app_component_name(value: Optional[str] = None, resource_type raise SystemExit(m) except KeyboardInterrupt: - m = f""" + raise SystemExit( + f""" ⚡ {resource_type} init aborted! ⚡ """ - raise SystemExit(m) + ) return value diff --git a/src/lightning/app/cli/cmd_install.py b/src/lightning/app/cli/cmd_install.py index 00b306d80c..a2efd1187d 100644 --- a/src/lightning/app/cli/cmd_install.py +++ b/src/lightning/app/cli/cmd_install.py @@ -227,13 +227,14 @@ def _show_install_component_prompt(entry: Dict[str, str], component: str, org: s return git_url except KeyboardInterrupt: repo = entry["sourceUrl"] - m = f""" + raise SystemExit( + f""" ⚡ Installation aborted! ⚡ Install the component yourself by visiting: {repo} """ - raise SystemExit(m) + ) def _show_non_gallery_install_component_prompt(gh_url: str, yes_arg: bool) -> str: @@ -282,13 +283,14 @@ def _show_non_gallery_install_component_prompt(gh_url: str, yes_arg: bool) -> st return gh_url except KeyboardInterrupt: - m = f""" + raise SystemExit( + f""" ⚡ Installation aborted! ⚡ Install the component yourself by visiting: {repo_url} """ - raise SystemExit(m) + ) def _show_install_app_prompt( @@ -332,13 +334,14 @@ def _show_install_app_prompt( return source_url, git_url, folder_name, git_sha except KeyboardInterrupt: repo = entry["sourceUrl"] - m = f""" + raise SystemExit( + f""" ⚡ Installation aborted! ⚡ Install the {resource_type} yourself by visiting: {repo} """ - raise SystemExit(m) + ) def _show_non_gallery_install_app_prompt(gh_url: str, yes_arg: bool) -> Tuple[str, str]: @@ -352,15 +355,16 @@ def _show_non_gallery_install_app_prompt(gh_url: str, yes_arg: bool) -> Tuple[st folder_name = gh_url.split("/")[-1] org = re.search(r"github.com\/(.*)\/", gh_url).group(1) # type: ignore - except Exception as e: # noqa - m = """ + except Exception: + raise SystemExit( + """ Your github url is not supported. Here's the supported format: https://github.com/YourOrgName/your-repo-name Example: https://github.com/Lightning-AI/lightning """ - raise SystemExit("") + ) # yes arg does not prompt the user for permission to install anything # automatically creates env and sets up the project @@ -396,20 +400,22 @@ def _show_non_gallery_install_app_prompt(gh_url: str, yes_arg: bool) -> Tuple[st return gh_url, folder_name except KeyboardInterrupt: - m = f""" + raise SystemExit( + f""" ⚡ Installation aborted! ⚡ Install the app yourself by visiting {gh_url} """ - raise SystemExit(m) + ) def _validate_name(name: str, resource_type: str, example: str) -> Tuple[str, str]: # ensure resource identifier is properly formatted try: org, resource = name.split("/") - except Exception as e: # noqa - m = f""" + except Exception: + raise SystemExit( + f""" {resource_type} name format must have organization/{resource_type}-name Examples: @@ -418,12 +424,7 @@ def _validate_name(name: str, resource_type: str, example: str) -> Tuple[str, st You passed in: {name} """ - raise SystemExit(m) - m = f""" - ⚡ Installing Lightning {resource_type} ⚡ - {resource_type} name: {resource} - developer: {org} - """ + ) return org, resource @@ -466,13 +467,14 @@ def _resolve_resource( elif resource_type == "component": gallery_entries = data["components"] except requests.ConnectionError: - m = f""" + sys.tracebacklimit = 0 + raise SystemError( + f""" Network connection error, could not load list of available Lightning {resource_type}s. Try again when you have a network connection! """ - sys.tracebacklimit = 0 - raise SystemError(m) + ) entries = [] all_versions = [] @@ -582,15 +584,16 @@ def _install_app_from_source( logger.info(f"⚡ RUN: git clone {source_url}") try: subprocess.check_output(["git", "clone", git_url], stderr=subprocess.STDOUT) - except subprocess.CalledProcessError as e: - if "Repository not found" in str(e.output): - m = f""" + except subprocess.CalledProcessError as ex: + if "Repository not found" in str(ex.output): + raise SystemExit( + f""" Looks like the github url was not found or doesn't exist. Do you have a typo? {source_url} """ - raise SystemExit(m) + ) else: - raise Exception(e) + raise Exception(ex) # step into the repo folder os.chdir(f"{folder_name}") @@ -599,11 +602,10 @@ def _install_app_from_source( try: if git_sha: subprocess.check_output(["git", "checkout", git_sha], stderr=subprocess.STDOUT) - except subprocess.CalledProcessError as e: - if "did not match any" in str(e.output): + except subprocess.CalledProcessError as ex: + if "did not match any" in str(ex.output): raise SystemExit("Looks like the git SHA is not valid or doesn't exist in app repo.") - else: - raise Exception(e) + raise Exception(ex) # activate and install reqs # TODO: remove shell=True... but need to run command in venv diff --git a/src/lightning/app/cli/cmd_pl_init.py b/src/lightning/app/cli/cmd_pl_init.py index 57da584993..edf94a6324 100644 --- a/src/lightning/app/cli/cmd_pl_init.py +++ b/src/lightning/app/cli/cmd_pl_init.py @@ -119,14 +119,14 @@ def download_frontend(destination: Path) -> None: url = "https://storage.googleapis.com/grid-packages/pytorch-lightning-app/v0.0.0/build.tar.gz" build_dir_name = "build" with TemporaryDirectory() as download_dir: - response = urllib.request.urlopen(url) + response = urllib.request.urlopen(url) # noqa: S310 file = tarfile.open(fileobj=response, mode="r|gz") file.extractall(path=download_dir) shutil.move(str(Path(download_dir, build_dir_name)), destination) def project_file_from_template(template_dir: Path, destination_dir: Path, template_name: str, **kwargs: Any) -> None: - env = Environment(loader=FileSystemLoader(template_dir)) + env = Environment(loader=FileSystemLoader(template_dir)) # noqa: S701 template = env.get_template(template_name) rendered_template = template.render(**kwargs) with open(destination_dir / template_name, "w") as file: diff --git a/src/lightning/app/cli/cmd_ssh_keys.py b/src/lightning/app/cli/cmd_ssh_keys.py index ceb372df98..eaf3826572 100644 --- a/src/lightning/app/cli/cmd_ssh_keys.py +++ b/src/lightning/app/cli/cmd_ssh_keys.py @@ -59,7 +59,8 @@ class _SSHKeyManager: console.print(ssh_keys.as_table()) def add_key(self, public_key: str, name: Optional[str], comment: Optional[str]) -> None: - key_name = name if name is not None else "-".join(random.choice(string.ascii_lowercase) for _ in range(5)) + rnd = "-".join(random.choice(string.ascii_lowercase) for _ in range(5)) # noqa: S311 + key_name = name if name is not None else rnd self.api_client.s_sh_public_key_service_create_ssh_public_key( V1CreateSSHPublicKeyRequest( name=key_name, diff --git a/src/lightning/app/cli/commands/cp.py b/src/lightning/app/cli/commands/cp.py index bac2a55098..1d15f8ff7a 100644 --- a/src/lightning/app/cli/commands/cp.py +++ b/src/lightning/app/cli/commands/cp.py @@ -13,6 +13,7 @@ # limitations under the License. import concurrent +import contextlib import os import sys from functools import partial @@ -253,8 +254,8 @@ def _download_file(path: str, url: str, progress: Progress, task_id: TaskID) -> # Disable warning about making an insecure request urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - try: - request = requests.get(url, stream=True, verify=False) + with contextlib.suppress(ConnectionError): + request = requests.get(url, stream=True, verify=False) # noqa: S501 chunk_size = 1024 @@ -262,8 +263,6 @@ def _download_file(path: str, url: str, progress: Progress, task_id: TaskID) -> for chunk in request.iter_content(chunk_size=chunk_size): fp.write(chunk) # type: ignore progress.update(task_id, advance=len(chunk)) - except ConnectionError: - pass def _sanitize_path(path: str, pwd: str) -> Tuple[str, bool]: diff --git a/src/lightning/app/cli/commands/ls.py b/src/lightning/app/cli/commands/ls.py index 5427435bd9..8ba0077aa2 100644 --- a/src/lightning/app/cli/commands/ls.py +++ b/src/lightning/app/cli/commands/ls.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import contextlib import os import sys from contextlib import nullcontext @@ -224,7 +224,9 @@ def _collect_artifacts( if page_token in tokens: return - try: + # Note: This is triggered when the request is wrong. + # This is currently happening due to looping through the user clusters. + with contextlib.suppress(lightning_cloud.openapi.rest.ApiException): response = client.lightningapp_instance_service_list_project_artifacts( project_id, prefix=prefix, @@ -248,10 +250,6 @@ def _collect_artifacts( page_token=response.next_page_token, tokens=tokens, ) - except lightning_cloud.openapi.rest.ApiException: - # Note: This is triggered when the request is wrong. - # This is currently happening due to looping through the user clusters. - pass def _add_resource_prefix(prefix: str, resource_path: str): diff --git a/src/lightning/app/cli/commands/rm.py b/src/lightning/app/cli/commands/rm.py index f114a27d7b..70c57a3e14 100644 --- a/src/lightning/app/cli/commands/rm.py +++ b/src/lightning/app/cli/commands/rm.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import contextlib import os import click @@ -83,7 +83,7 @@ def rm(rm_path: str, r: bool = False, recursive: bool = False) -> None: succeeded = False for cluster in clusters.clusters: - try: + with contextlib.suppress(lightning_cloud.openapi.rest.ApiException): client.lightningapp_instance_service_delete_project_artifact( project_id=project_id, cluster_id=cluster.cluster_id, @@ -91,8 +91,6 @@ def rm(rm_path: str, r: bool = False, recursive: bool = False) -> None: ) succeeded = True break - except lightning_cloud.openapi.rest.ApiException: - pass prefix = os.path.join(*splits) diff --git a/src/lightning/app/cli/connect/maverick.py b/src/lightning/app/cli/connect/maverick.py index c5ea46b606..4285456a61 100644 --- a/src/lightning/app/cli/connect/maverick.py +++ b/src/lightning/app/cli/connect/maverick.py @@ -102,9 +102,9 @@ def connect_maverick(name: str, project_name: str = "") -> None: with Live(Spinner("point", text=Text("Registering maverick...", style="white")), transient=True) as live: try: register_to_cloud(name, project_name) - except Exception as e: + except Exception as ex: live.stop() - rich.print(f"[red]Failed[/red]: Registering maverick failed with error {e}") + rich.print(f"[red]Failed[/red]: Registering maverick failed with error {ex}") return live.update(Spinner("point", text=Text("Setting up ...", style="white"))) @@ -209,8 +209,8 @@ def disconnect_maverick(name: str) -> None: with Live(Spinner("point", text=Text("disconnecting maverick...", style="white")), transient=True): try: deregister_from_cloud(name) - except Exception as e: - rich.print(f"[red]Failed[/red]: Disconnecting machine failed with error: {e}") + except Exception as ex: + rich.print(f"[red]Failed[/red]: Disconnecting machine failed with error: {ex}") return subprocess.run(f"docker stop {CODE_SERVER_CONTAINER}", shell=True, capture_output=True) subprocess.run(f"docker stop {LIGHTNING_DAEMON_CONTAINER}", shell=True, capture_output=True) diff --git a/src/lightning/app/cli/lightning_cli.py b/src/lightning/app/cli/lightning_cli.py index 2370f2dbc8..3cbc8f2807 100644 --- a/src/lightning/app/cli/lightning_cli.py +++ b/src/lightning/app/cli/lightning_cli.py @@ -244,8 +244,8 @@ def cluster_logs(cluster_id: str, to_time: arrow.Arrow, from_time: arrow.Arrow, rich.print(f"[{color}]{log_event.labels.level:5}[/{color}] {date} {log_event.message.rstrip()}") except LogLinesLimitExceeded: raise click.ClickException(f"Read {limit} log lines, but there may be more. Use --limit param to read more") - except Exception as error: - logger.error(f"⚡ Error while reading logs ({type(error)}), {error}", exc_info=DEBUG) + except Exception as ex: + logger.error(f"⚡ Error while reading logs ({type(ex)}), {ex}", exc_info=DEBUG) @_main.command() diff --git a/src/lightning/app/cli/lightning_cli_create.py b/src/lightning/app/cli/lightning_cli_create.py index bf22da9976..c9ef8767b1 100644 --- a/src/lightning/app/cli/lightning_cli_create.py +++ b/src/lightning/app/cli/lightning_cli_create.py @@ -117,9 +117,9 @@ def add_ssh_key( new_public_key = Path(str(public_key)).read_text() if os.path.isfile(str(public_key)) else public_key try: ssh_key_manager.add_key(name=key_name, comment=comment, public_key=str(new_public_key)) - except ApiException as e: + except ApiException as ex: # if we got an exception it might be the user passed the private key file if os.path.isfile(str(public_key)) and os.path.isfile(f"{public_key}.pub"): ssh_key_manager.add_key(name=key_name, comment=comment, public_key=Path(f"{public_key}.pub").read_text()) else: - raise e + raise ex diff --git a/src/lightning/app/cli/lightning_cli_delete.py b/src/lightning/app/cli/lightning_cli_delete.py index cbc2f61fdb..2d90387325 100644 --- a/src/lightning/app/cli/lightning_cli_delete.py +++ b/src/lightning/app/cli/lightning_cli_delete.py @@ -211,12 +211,12 @@ def delete_app(app_name: str, cluster_id: str, skip_user_confirm_prompt: bool) - # Delete the app! app_manager = _AppManager() app_manager.delete(app_id=selected_app_instance_id) - except Exception as e: + except Exception as ex: console.print( f'[b][red]An issue occurred while deleting app "{app_name}. If the issue persists, please ' "reach out to us at [link=mailto:support@lightning.ai]support@lightning.ai[/link][/b][/red]." ) - raise click.ClickException(str(e)) + raise click.ClickException(str(ex)) console.print(f'[b][green]App "{app_name}" has been successfully deleted from cluster "{cluster_id}"![/green][/b]') return diff --git a/src/lightning/app/components/database/utilities.py b/src/lightning/app/components/database/utilities.py index 5a8f0a3a0c..4bd1d408d7 100644 --- a/src/lightning/app/components/database/utilities.py +++ b/src/lightning/app/components/database/utilities.py @@ -258,5 +258,5 @@ def _create_database(db_filename: str, models: List[Type["SQLModel"]], echo: boo logger.debug(f"Creating the following tables {models}") try: SQLModel.metadata.create_all(engine) - except Exception as e: - logger.debug(e) + except Exception as ex: + logger.debug(ex) diff --git a/src/lightning/app/core/api.py b/src/lightning/app/core/api.py index e4207525c0..57e877628e 100644 --- a/src/lightning/app/core/api.py +++ b/src/lightning/app/core/api.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +import contextlib import json import os import queue @@ -124,23 +125,19 @@ class UIRefresher(Thread): raise ex def run_once(self) -> None: - try: + with contextlib.suppress(queue.Empty): global app_status state, app_status = self.api_publish_state_queue.get(timeout=0) with lock: global_app_state_store.set_app_state(TEST_SESSION_UUID, state) - except queue.Empty: - pass - try: + with contextlib.suppress(queue.Empty): responses = self.api_response_queue.get(timeout=0) with lock: # TODO: Abstract the responses store to support horizontal scaling. global responses_store for response in responses: responses_store[response["id"]] = response["response"] - except queue.Empty: - pass def join(self, timeout: Optional[float] = None) -> None: self._exit_event.set() @@ -179,7 +176,7 @@ fastapi_service.add_middleware( ) if _is_starsessions_available(): - fastapi_service.add_middleware(SessionMiddleware, secret_key="secret", autoload=True) + fastapi_service.add_middleware(SessionMiddleware, secret_key="secret", autoload=True) # noqa: S106 # General sequence is: @@ -490,7 +487,7 @@ def start_server( if uvicorn_run: host = host.split("//")[-1] if "//" in host else host - if host == "0.0.0.0": + if host == "0.0.0.0": # noqa: S104 logger.info("Your app has started.") else: logger.info(f"Your app has started. View it in your browser: http://{host}:{port}/view") diff --git a/src/lightning/app/core/app.py b/src/lightning/app/core/app.py index 6d24e31eca..28b5e106a2 100644 --- a/src/lightning/app/core/app.py +++ b/src/lightning/app/core/app.py @@ -355,8 +355,8 @@ class LightningApp: work = None try: work = self.get_component_by_name(delta.id) - except (KeyError, AttributeError) as e: - logger.error(f"The component {delta.id} couldn't be accessed. Exception: {e}") + except (KeyError, AttributeError) as ex: + logger.error(f"The component {delta.id} couldn't be accessed. Exception: {ex}") if work: delta = _delta_to_app_state_delta( @@ -421,8 +421,8 @@ class LightningApp: for delta in deltas: try: state += delta - except Exception as e: - raise Exception(f"Current State {state}, {delta.to_dict()}") from e + except Exception as ex: + raise Exception(f"Current State {state}, {delta.to_dict()}") from ex # new_state = self.populate_changes(self.last_state, state) self.set_state(state) diff --git a/src/lightning/app/core/constants.py b/src/lightning/app/core/constants.py index 19bc25ef0d..5a303b7d9c 100644 --- a/src/lightning/app/core/constants.py +++ b/src/lightning/app/core/constants.py @@ -93,7 +93,7 @@ ENABLE_STATE_WEBSOCKET = bool(int(os.getenv("ENABLE_STATE_WEBSOCKET", "1"))) ENABLE_UPLOAD_ENDPOINT = bool(int(os.getenv("ENABLE_UPLOAD_ENDPOINT", "1"))) # directory where system customization sync files stored -SYS_CUSTOMIZATIONS_SYNC_ROOT = "/tmp/sys-customizations-sync" +SYS_CUSTOMIZATIONS_SYNC_ROOT = "/tmp/sys-customizations-sync" # noqa: S108 # todo # directory where system customization sync files will be copied to be packed into app tarball SYS_CUSTOMIZATIONS_SYNC_PATH = ".sys-customizations-sync" diff --git a/src/lightning/app/core/queues.py b/src/lightning/app/core/queues.py index f0b1e897e6..1ff3060193 100644 --- a/src/lightning/app/core/queues.py +++ b/src/lightning/app/core/queues.py @@ -419,7 +419,6 @@ class HTTPQueue(BaseQueue): # than the default timeout if timeout > self.default_timeout: time.sleep(0.05) - pass def _get(self) -> Any: try: diff --git a/src/lightning/app/core/work.py b/src/lightning/app/core/work.py index 6c1747f5cb..d31785df80 100644 --- a/src/lightning/app/core/work.py +++ b/src/lightning/app/core/work.py @@ -483,10 +483,10 @@ class LightningWork: def __getattribute__(self, name: str) -> Any: try: attr = object.__getattribute__(self, name) - except AttributeError as e: - if str(e).endswith("'_state'"): + except AttributeError as ex: + if str(ex).endswith("'_state'"): raise AttributeError(f"Did you forget to call super().__init__() in {self}") - raise e + raise ex if isinstance(attr, ProxyWorkRun): return attr diff --git a/src/lightning/app/plugin/plugin.py b/src/lightning/app/plugin/plugin.py index 2b8d229dbe..b59553256c 100644 --- a/src/lightning/app/plugin/plugin.py +++ b/src/lightning/app/plugin/plugin.py @@ -128,28 +128,28 @@ def _run_plugin(run: _Run) -> Dict[str, Any]: with open(download_path, "wb") as f: f.write(response.content) - except Exception as e: + except Exception as ex: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error downloading plugin source: {str(e)}.", + detail=f"Error downloading plugin source: {str(ex)}.", ) # Extract try: with tarfile.open(download_path, "r:gz") as tf: tf.extractall(source_path) - except Exception as e: + except Exception as ex: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error extracting plugin source: {str(e)}.", + detail=f"Error extracting plugin source: {str(ex)}.", ) # Import the plugin try: plugin = _load_plugin_from_file(os.path.join(source_path, run.plugin_entrypoint)) - except Exception as e: + except Exception as ex: raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error loading plugin: {str(e)}." + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error loading plugin: {str(ex)}." ) # Ensure that apps are dispatched from the temp directory @@ -165,9 +165,9 @@ def _run_plugin(run: _Run) -> Dict[str, Any]: ) actions = plugin.run(**run.plugin_arguments) or [] return {"actions": [action.to_spec().to_dict() for action in actions]} - except Exception as e: + except Exception as ex: raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error running plugin: {str(e)}." + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error running plugin: {str(ex)}." ) finally: os.chdir(cwd) diff --git a/src/lightning/app/runners/backends/mp_process.py b/src/lightning/app/runners/backends/mp_process.py index 10d9e14d12..554a03c5c8 100644 --- a/src/lightning/app/runners/backends/mp_process.py +++ b/src/lightning/app/runners/backends/mp_process.py @@ -80,7 +80,7 @@ class MultiProcessingBackend(Backend): if constants.LIGHTNING_CLOUDSPACE_HOST is not None: # Override the port if set by the user work._port = find_free_network_port() - work._host = "0.0.0.0" + work._host = "0.0.0.0" # noqa: S104 work._future_url = f"https://{work.port}-{constants.LIGHTNING_CLOUDSPACE_HOST}" app.processes[work.name] = MultiProcessWorkManager(app, work) @@ -121,7 +121,7 @@ class CloudMultiProcessingBackend(MultiProcessingBackend): self.ports = [] def create_work(self, app, work) -> None: - work._host = "0.0.0.0" + work._host = "0.0.0.0" # noqa: S104 nc = enable_port() self.ports.append(nc.port) work._port = nc.port diff --git a/src/lightning/app/runners/cloud.py b/src/lightning/app/runners/cloud.py index f1dcd2eca4..269db0215e 100644 --- a/src/lightning/app/runners/cloud.py +++ b/src/lightning/app/runners/cloud.py @@ -186,8 +186,8 @@ class CloudRuntime(Runtime): if "PYTEST_CURRENT_TEST" not in os.environ: click.launch(self._get_cloudspace_url(project, cloudspace_name, "code", needs_credits)) - except ApiException as e: - logger.error(e.body) + except ApiException as ex: + logger.error(ex.body) sys.exit(1) def cloudspace_dispatch( @@ -384,8 +384,8 @@ class CloudRuntime(Runtime): if bool(int(os.getenv("LIGHTING_TESTING", "0"))): print(f"APP_LOGS_URL: {self._get_app_url(project, run_instance, 'logs')}") - except ApiException as e: - logger.error(e.body) + except ApiException as ex: + logger.error(ex.body) sys.exit(1) finally: if cleanup_handle: @@ -399,8 +399,8 @@ class CloudRuntime(Runtime): try: app = load_app_from_file(filepath, raise_exception=True, mock_imports=True, env_vars=env_vars) - except FileNotFoundError as e: - raise e + except FileNotFoundError as ex: + raise ex except Exception: from lightning.app.testing.helpers import EmptyFlow @@ -770,7 +770,7 @@ class CloudRuntime(Runtime): if cloudspace is not None and cloudspace.code_config is not None: data_connection_mounts = cloudspace.code_config.data_connection_mounts - random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5)) + random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5)) # noqa: S311 work_spec = V1LightningworkSpec( build_spec=build_spec, drives=drives + mounts, diff --git a/src/lightning/app/runners/multiprocess.py b/src/lightning/app/runners/multiprocess.py index ce66ba6d0b..93f091f870 100644 --- a/src/lightning/app/runners/multiprocess.py +++ b/src/lightning/app/runners/multiprocess.py @@ -51,7 +51,7 @@ class MultiProcessRuntime(Runtime): # Note: In case the runtime is used in the cloud. in_cloudspace = constants.LIGHTNING_CLOUDSPACE_HOST is not None - self.host = "0.0.0.0" if constants.APP_SERVER_IN_CLOUD or in_cloudspace else self.host + self.host = "0.0.0.0" if constants.APP_SERVER_IN_CLOUD or in_cloudspace else self.host # noqa: S104 self.app.backend = self.backend self.backend._prepare_queues(self.app) @@ -68,7 +68,7 @@ class MultiProcessRuntime(Runtime): for frontend in self.app.frontends.values(): port = find_free_network_port() - server_host = "0.0.0.0" if in_cloudspace else "localhost" + server_host = "0.0.0.0" if in_cloudspace else "localhost" # noqa: S104 server_target = ( f"https://{port}-{constants.LIGHTNING_CLOUDSPACE_HOST}" if in_cloudspace diff --git a/src/lightning/app/source_code/hashing.py b/src/lightning/app/source_code/hashing.py index a2aa5df83d..362d32f259 100644 --- a/src/lightning/app/source_code/hashing.py +++ b/src/lightning/app/source_code/hashing.py @@ -38,7 +38,7 @@ def _get_hash(files: List[str], algorithm: str = "blake2", chunk_num_blocks: int if algorithm == "blake2": h = hashlib.blake2b(digest_size=20) elif algorithm == "md5": - h = hashlib.md5() + h = hashlib.md5() # noqa: S324 else: raise ValueError(f"Algorithm {algorithm} not supported") diff --git a/src/lightning/app/storage/copier.py b/src/lightning/app/storage/copier.py index 6715e19a15..3f2c2c3a50 100644 --- a/src/lightning/app/storage/copier.py +++ b/src/lightning/app/storage/copier.py @@ -128,9 +128,9 @@ def _copy_files( fs.makedirs(str(to_path.parent), exist_ok=True) fs.put(str(from_path), str(to_path), recursive=False) - except Exception as e: + except Exception as ex: # Return the exception so that it can be handled in the main thread - return e + return ex # NOTE: Cannot use `S3FileSystem.put(recursive=True)` because it tries to access parent directories # which it does not have access to. diff --git a/src/lightning/app/storage/path.py b/src/lightning/app/storage/path.py index 0c12c648cd..f0b7ee9560 100644 --- a/src/lightning/app/storage/path.py +++ b/src/lightning/app/storage/path.py @@ -129,7 +129,7 @@ class Path(PathlibPath): if self._origin is None: return None contents = f"{self.origin_name}/{self}" - return hashlib.sha1(contents.encode("utf-8")).hexdigest() + return hashlib.sha1(contents.encode("utf-8")).hexdigest() # noqa: S324 @property def parents(self) -> Sequence["Path"]: @@ -363,8 +363,8 @@ class Path(PathlibPath): try: _copy_files(source_path, destination_path) _logger.debug(f"All files copied from {request.path} to {response.path}.") - except Exception as e: - response.exception = e + except Exception as ex: + response.exception = ex return response diff --git a/src/lightning/app/storage/payload.py b/src/lightning/app/storage/payload.py index 8e32e0145d..f90415fe18 100644 --- a/src/lightning/app/storage/payload.py +++ b/src/lightning/app/storage/payload.py @@ -65,7 +65,7 @@ class _BasePayload(ABC): if self._origin is None: return None contents = f"{self.origin_name}/{self.consumer_name}/{self.name}" - return hashlib.sha1(contents.encode("utf-8")).hexdigest() + return hashlib.sha1(contents.encode("utf-8")).hexdigest() # noqa: S324 @property def origin_name(self) -> str: @@ -251,8 +251,8 @@ class _BasePayload(ABC): response.size = source_path.stat().st_size _copy_files(source_path, destination_path) _logger.debug(f"All files copied from {request.path} to {response.path}.") - except Exception as e: - response.exception = e + except Exception as ex: + response.exception = ex return response diff --git a/src/lightning/app/utilities/cli_helpers.py b/src/lightning/app/utilities/cli_helpers.py index d334dedc5e..1dc318d0f7 100644 --- a/src/lightning/app/utilities/cli_helpers.py +++ b/src/lightning/app/utilities/cli_helpers.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import contextlib import functools import json import os @@ -209,14 +209,12 @@ class _LightningAppOpenAPIRetriever: # 2: If no identifier has been provided, evaluate the local application if self.app_id_or_name_or_url is None: - try: + with contextlib.suppress(requests.exceptions.ConnectionError): self.url = f"http://localhost:{APP_SERVER_PORT}" resp = requests.get(f"{self.url}/openapi.json") if resp.status_code != 200: raise Exception(f"The server didn't process the request properly. Found {resp.json()}") self.openapi = resp.json() - except requests.exceptions.ConnectionError: - pass # 3: If an identified was provided or the local evaluation has failed, evaluate the cloud. else: diff --git a/src/lightning/app/utilities/clusters.py b/src/lightning/app/utilities/clusters.py index 2a43186e5a..a083e41c71 100644 --- a/src/lightning/app/utilities/clusters.py +++ b/src/lightning/app/utilities/clusters.py @@ -48,4 +48,4 @@ def _get_default_cluster(client: LightningClient, project_id: str) -> str: if len(clusters) == 0: raise RuntimeError(f"No clusters found on `{client.api_client.configuration.host}`.") - return random.choice(clusters).id + return random.choice(clusters).id # noqa: S311 diff --git a/src/lightning/app/utilities/commands/base.py b/src/lightning/app/utilities/commands/base.py index d8ca604261..93f1a3612d 100644 --- a/src/lightning/app/utilities/commands/base.py +++ b/src/lightning/app/utilities/commands/base.py @@ -44,9 +44,9 @@ def makedirs(path: str): r"""Recursive directory creation function.""" try: os.makedirs(osp.expanduser(osp.normpath(path))) - except OSError as e: - if e.errno != errno.EEXIST and osp.isdir(path): - raise e + except OSError as ex: + if ex.errno != errno.EEXIST and osp.isdir(path): + raise ex class ClientCommand: @@ -226,9 +226,9 @@ def _process_api_request(app, request: _APIRequest): method = getattr(flow, request.method_name) try: response = _RequestResponse(content=method(*request.args, **request.kwargs), status_code=200) - except HTTPException as e: - logger.error(repr(e)) - response = _RequestResponse(status_code=e.status_code, content=e.detail) + except HTTPException as ex: + logger.error(repr(ex)) + response = _RequestResponse(status_code=ex.status_code, content=ex.detail) except Exception: logger.error(traceback.print_exc()) response = _RequestResponse(status_code=500) @@ -244,9 +244,9 @@ def _process_command_requests(app, request: _CommandRequest): # Validation is done on the CLI side. try: response = _RequestResponse(content=method(*request.args, **request.kwargs), status_code=200) - except HTTPException as e: - logger.error(repr(e)) - response = _RequestResponse(status_code=e.status_code, content=e.detail) + except HTTPException as ex: + logger.error(repr(ex)) + response = _RequestResponse(status_code=ex.status_code, content=ex.detail) except Exception: logger.error(traceback.print_exc()) response = _RequestResponse(status_code=500) diff --git a/src/lightning/app/utilities/layout.py b/src/lightning/app/utilities/layout.py index d9d2175834..7b02bd6e7e 100644 --- a/src/lightning/app/utilities/layout.py +++ b/src/lightning/app/utilities/layout.py @@ -38,7 +38,7 @@ def _add_comment_to_literal_code(method, contains, comment): return "\n".join(lines) - except Exception as e: # noqa + except Exception: return "" diff --git a/src/lightning/app/utilities/load_app.py b/src/lightning/app/utilities/load_app.py index 008cdc05ce..f22aafa3ce 100644 --- a/src/lightning/app/utilities/load_app.py +++ b/src/lightning/app/utilities/load_app.py @@ -74,12 +74,12 @@ def _load_objects_from_file( with _add_to_env(env_vars), _patch_sys_argv(): if mock_imports: with _mock_missing_imports(): - exec(code, module.__dict__) + exec(code, module.__dict__) # noqa: S102 else: - exec(code, module.__dict__) - except Exception as e: + exec(code, module.__dict__) # noqa: S102 + except Exception as ex: if raise_exception: - raise e + raise ex _prettifiy_exception(filepath) return [v for v in module.__dict__.values() if isinstance(v, target_type)], module diff --git a/src/lightning/app/utilities/login.py b/src/lightning/app/utilities/login.py index d4ad97c2a4..b5aa5aef3e 100644 --- a/src/lightning/app/utilities/login.py +++ b/src/lightning/app/utilities/login.py @@ -161,13 +161,13 @@ class AuthServer: try: # check if server is reachable or catch any network errors requests.head(url) - except requests.ConnectionError as e: + except requests.ConnectionError as ex: raise requests.ConnectionError( - f"No internet connection available. Please connect to a stable internet connection \n{e}" # E501 + f"No internet connection available. Please connect to a stable internet connection \n{ex}" # E501 ) - except requests.RequestException as e: + except requests.RequestException as ex: raise requests.RequestException( - f"An error occurred with the request. Please report this issue to Lightning Team \n{e}" # E501 + f"An error occurred with the request. Please report this issue to Lightning Team \n{ex}" # E501 ) logger.info( diff --git a/src/lightning/app/utilities/name_generator.py b/src/lightning/app/utilities/name_generator.py index 4966927566..28c43c241c 100644 --- a/src/lightning/app/utilities/name_generator.py +++ b/src/lightning/app/utilities/name_generator.py @@ -1354,5 +1354,5 @@ def get_unique_name(): >>> get_unique_name() 'truthful-dijkstra-2286' """ - adjective, surname, i = choice(_adjectives), choice(_surnames), randint(0, 9999) + adjective, surname, i = choice(_adjectives), choice(_surnames), randint(0, 9999) # noqa: S311 return f"{adjective}-{surname}-{i}" diff --git a/src/lightning/app/utilities/network.py b/src/lightning/app/utilities/network.py index 87f556717f..6ff4871e7c 100644 --- a/src/lightning/app/utilities/network.py +++ b/src/lightning/app/utilities/network.py @@ -136,23 +136,23 @@ def _retry_wrapper(self, func: Callable) -> Callable: while _get_next_backoff_time(consecutive_errors) != _DEFAULT_BACKOFF_MAX: try: return func(self, *args, **kwargs) - except lightning_cloud.openapi.rest.ApiException as e: + except lightning_cloud.openapi.rest.ApiException as ex: # retry if the control plane fails with all errors except 4xx but not 408 - (Request Timeout) - if e.status == 408 or e.status == 409 or not str(e.status).startswith("4"): + if ex.status == 408 or ex.status == 409 or not str(ex.status).startswith("4"): consecutive_errors += 1 backoff_time = _get_next_backoff_time(consecutive_errors) logger.debug( - f"The {func.__name__} request failed to reach the server, got a response {e.status}." + f"The {func.__name__} request failed to reach the server, got a response {ex.status}." f" Retrying after {backoff_time} seconds." ) time.sleep(backoff_time) else: - raise e - except urllib3.exceptions.HTTPError as e: + raise ex + except urllib3.exceptions.HTTPError as ex: consecutive_errors += 1 backoff_time = _get_next_backoff_time(consecutive_errors) logger.debug( - f"The {func.__name__} request failed to reach the server, got a an error {str(e)}." + f"The {func.__name__} request failed to reach the server, got a an error {str(ex)}." f" Retrying after {backoff_time} seconds." ) time.sleep(backoff_time) diff --git a/src/lightning/app/utilities/openapi.py b/src/lightning/app/utilities/openapi.py index 2482bee225..d12b9da373 100644 --- a/src/lightning/app/utilities/openapi.py +++ b/src/lightning/app/utilities/openapi.py @@ -35,8 +35,8 @@ def string2dict(text): try: js = json.loads(text, object_pairs_hook=_duplicate_checker) return js - except ValueError as e: - raise ValueError(f"Unable to load JSON: {str(e)}.") + except ValueError as ex: + raise ValueError(f"Unable to load JSON: {str(ex)}.") def is_openapi(obj): diff --git a/src/lightning/app/utilities/packaging/lightning_utils.py b/src/lightning/app/utilities/packaging/lightning_utils.py index b99988de02..378a4251e6 100644 --- a/src/lightning/app/utilities/packaging/lightning_utils.py +++ b/src/lightning/app/utilities/packaging/lightning_utils.py @@ -49,7 +49,7 @@ def download_frontend(root: str = _PROJECT_ROOT): shutil.rmtree(frontend_dir, ignore_errors=True) - response = urllib.request.urlopen(LIGHTNING_FRONTEND_RELEASE_URL) + response = urllib.request.urlopen(LIGHTNING_FRONTEND_RELEASE_URL) # noqa: S310 file = tarfile.open(fileobj=response, mode="r|gz") file.extractall(path=download_dir) diff --git a/src/lightning/app/utilities/proxies.py b/src/lightning/app/utilities/proxies.py index 2b061f895b..fda1175f50 100644 --- a/src/lightning/app/utilities/proxies.py +++ b/src/lightning/app/utilities/proxies.py @@ -297,10 +297,10 @@ class WorkStateObserver(Thread): try: with _state_observer_lock: self._work.apply_flow_delta(Delta(deep_diff, raise_errors=True)) - except Exception as e: + except Exception as ex: print(traceback.print_exc()) - self._error_queue.put(e) - raise e + self._error_queue.put(ex) + raise ex def join(self, timeout: Optional[float] = None) -> None: self._exit_event.set() @@ -415,19 +415,19 @@ class WorkRunner: self.state_observer.join(0) self.state_observer = None self.copier.join(0) - except LightningSigtermStateException as e: + except LightningSigtermStateException as ex: logger.debug("Exiting") - os._exit(e.exit_code) - except Exception as e: + os._exit(ex.exit_code) + except Exception as ex: # Inform the flow the work failed. This would fail the entire application. - self.error_queue.put(e) + self.error_queue.put(ex) # Terminate the threads if self.state_observer: if self.state_observer.started: self.state_observer.join(0) self.state_observer = None self.copier.join(0) - raise e + raise ex def setup(self): from lightning.app.utilities.state import AppState @@ -494,7 +494,7 @@ class WorkRunner: # Set the internal IP address. # Set this here after the state observer is initialized, since it needs to record it as a change and send # it back to the flow - default_internal_ip = "127.0.0.1" if constants.LIGHTNING_CLOUDSPACE_HOST is None else "0.0.0.0" + default_internal_ip = "127.0.0.1" if constants.LIGHTNING_CLOUDSPACE_HOST is None else "0.0.0.0" # noqa: S104 self.work._internal_ip = os.environ.get("LIGHTNING_NODE_IP", default_internal_ip) # 8. Patch the setattr method of the work. This needs to be done after step 4, so we don't @@ -525,9 +525,9 @@ class WorkRunner: # If an exception is raised, send a `FAILED` status delta to the flow and call the `on_exception` hook. try: ret = self.run_executor_cls(self.work, work_run, self.delta_queue)(*args, **kwargs) - except LightningSigtermStateException as e: - raise e - except BaseException as e: + except LightningSigtermStateException as ex: + raise ex + except BaseException as ex: # 10.2 Send failed delta to the flow. reference_state = deepcopy(self.work.state) exp, val, tb = sys.exc_info() @@ -559,7 +559,7 @@ class WorkRunner: id=self.work_name, delta=Delta(DeepDiff(reference_state, self.work.state, verbose_level=2)) ) ) - self.work.on_exception(e) + self.work.on_exception(ex) print("########## CAPTURED EXCEPTION ###########") print(traceback.print_exc()) print("########## CAPTURED EXCEPTION ###########") diff --git a/src/lightning/app/utilities/scheduler.py b/src/lightning/app/utilities/scheduler.py index 5a766c8dcc..aa661aaac8 100644 --- a/src/lightning/app/utilities/scheduler.py +++ b/src/lightning/app/utilities/scheduler.py @@ -32,12 +32,9 @@ class SchedulerThread(threading.Thread): self._app = app def run(self) -> None: - try: - while not self._exit_event.is_set(): - self._exit_event.wait(self._sleep_time) - self.run_once() - except Exception as e: - raise e + while not self._exit_event.is_set(): + self._exit_event.wait(self._sleep_time) + self.run_once() def run_once(self): for call_hash in list(self._app._schedules.keys()): diff --git a/src/lightning/app/utilities/state.py b/src/lightning/app/utilities/state.py index aedacfe827..993a3a2791 100644 --- a/src/lightning/app/utilities/state.py +++ b/src/lightning/app/utilities/state.py @@ -166,8 +166,8 @@ class AppState: try: # TODO: Send the delta directly to the REST API. response = self._session.post(app_url, json=data, headers=headers) - except ConnectionError as e: - raise AttributeError("Failed to connect and send the app state. Is the app running?") from e + except ConnectionError as ex: + raise AttributeError("Failed to connect and send the app state. Is the app running?") from ex if response and response.status_code != 200: raise Exception(f"The response from the server was {response.status_code}. Your inputs were rejected.") @@ -186,8 +186,8 @@ class AppState: sleep(0.5) try: response = self._session.get(app_url, headers=headers, timeout=1) - except ConnectionError as e: - raise AttributeError("Failed to connect and fetch the app state. Is the app running?") from e + except ConnectionError as ex: + raise AttributeError("Failed to connect and fetch the app state. Is the app running?") from ex self._authorized = response.status_code if self._authorized != 200: diff --git a/src/lightning/fabric/accelerators/cuda.py b/src/lightning/fabric/accelerators/cuda.py index 4977b0e1a7..b6432f8111 100644 --- a/src/lightning/fabric/accelerators/cuda.py +++ b/src/lightning/fabric/accelerators/cuda.py @@ -331,9 +331,7 @@ def _device_count_nvml() -> int: for idx, val in enumerate(visible_devices): if cast(int, val) >= raw_cnt: return idx - except OSError: - return -1 - except AttributeError: + except (OSError, AttributeError): return -1 return len(visible_devices) diff --git a/src/lightning/fabric/loggers/tensorboard.py b/src/lightning/fabric/loggers/tensorboard.py index 1ea3a19fa8..3a5b468f3a 100644 --- a/src/lightning/fabric/loggers/tensorboard.py +++ b/src/lightning/fabric/loggers/tensorboard.py @@ -200,8 +200,9 @@ class TensorBoardLogger(Logger): self.experiment.add_scalar(k, v, step) # TODO(fabric): specify the possible exception except Exception as ex: - m = f"\n you tried to log {v} which is currently not supported. Try a dict or a scalar/tensor." - raise ValueError(m) from ex + raise ValueError( + f"\n you tried to log {v} which is currently not supported. Try a dict or a scalar/tensor." + ) from ex @rank_zero_only def log_hyperparams( diff --git a/src/lightning/fabric/utilities/data.py b/src/lightning/fabric/utilities/data.py index 14b033f3ec..fc7c2fe782 100644 --- a/src/lightning/fabric/utilities/data.py +++ b/src/lightning/fabric/utilities/data.py @@ -218,10 +218,10 @@ def _dataloader_init_kwargs_resolve_sampler( batch_size=batch_sampler.batch_size, drop_last=batch_sampler.drop_last, ) - except TypeError as e: + except TypeError as ex: import re - match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(e)) + match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(ex)) if not match: # an unexpected `TypeError`, continue failure raise @@ -232,7 +232,7 @@ def _dataloader_init_kwargs_resolve_sampler( "We tried to re-instantiate your custom batch sampler and failed. " "To mitigate this, either follow the API of `BatchSampler` or instantiate " "your custom batch sampler inside `*_dataloader` hooks of your module." - ) from e + ) from ex return { "sampler": None, @@ -257,12 +257,12 @@ def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optio try: result = constructor(*args, **kwargs) - except TypeError as e: + except TypeError as ex: # improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass # `__init__` arguments map to one `DataLoader.__init__` argument import re - match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e)) + match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(ex)) if not match: # an unexpected `TypeError`, continue failure raise @@ -274,7 +274,7 @@ def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optio f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key." " This argument was automatically passed to your object by PyTorch Lightning." ) - raise MisconfigurationException(message) from e + raise MisconfigurationException(message) from ex attrs_record = getattr(orig_object, "__pl_attrs_record", []) for args, fn in attrs_record: diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py index 710d58004a..c3c6852a76 100644 --- a/src/lightning/fabric/utilities/seed.py +++ b/src/lightning/fabric/utilities/seed.py @@ -63,7 +63,7 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: def _select_seed_randomly(min_seed_value: int = min_seed_value, max_seed_value: int = max_seed_value) -> int: - return random.randint(min_seed_value, max_seed_value) + return random.randint(min_seed_value, max_seed_value) # noqa: S311 def reset_seed() -> None: diff --git a/src/lightning/pytorch/_graveyard/legacy_import_unpickler.py b/src/lightning/pytorch/_graveyard/legacy_import_unpickler.py index 6dd6da172c..f236d43cd7 100644 --- a/src/lightning/pytorch/_graveyard/legacy_import_unpickler.py +++ b/src/lightning/pytorch/_graveyard/legacy_import_unpickler.py @@ -1,3 +1,4 @@ +import contextlib import pickle import warnings from typing import Any, Callable @@ -32,15 +33,12 @@ def compare_version(package: str, op: Callable, version: str, use_base_version: # patching is necessary, since up to v.0.7.3 torchmetrics has a hardcoded reference to lightning.pytorch, # which has to be redirected to the unified package: # https://github.com/Lightning-AI/metrics/blob/v0.7.3/torchmetrics/metric.py#L96 -try: +with contextlib.suppress(AttributeError): if hasattr(torchmetrics.utilities.imports, "_compare_version"): torchmetrics.utilities.imports._compare_version = compare_version # type: ignore -except AttributeError: - pass -try: +with contextlib.suppress(AttributeError): if hasattr(torchmetrics.metric, "_compare_version"): torchmetrics.metric._compare_version = compare_version -except AttributeError: - pass + pickle.Unpickler = RedirectingUnpickler # type: ignore diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py index 5aa4abe613..2891f3db01 100644 --- a/src/lightning/pytorch/core/saving.py +++ b/src/lightning/pytorch/core/saving.py @@ -13,6 +13,7 @@ # limitations under the License. import ast +import contextlib import csv import inspect import logging @@ -272,10 +273,8 @@ def load_hparams_from_yaml(config_yaml: _PATH, use_omegaconf: bool = True) -> Di hparams = yaml.full_load(fp) if _OMEGACONF_AVAILABLE and use_omegaconf: - try: + with contextlib.suppress(UnsupportedValueType, ValidationError): return OmegaConf.create(hparams) - except (UnsupportedValueType, ValidationError): - pass return hparams diff --git a/src/lightning/pytorch/demos/mnist_datamodule.py b/src/lightning/pytorch/demos/mnist_datamodule.py index 09c57e0471..389aacfe0c 100644 --- a/src/lightning/pytorch/demos/mnist_datamodule.py +++ b/src/lightning/pytorch/demos/mnist_datamodule.py @@ -98,7 +98,7 @@ class _MNIST(Dataset): for url in self.RESOURCES: logging.info(f"Downloading {url}") fpath = os.path.join(data_folder, os.path.basename(url)) - urllib.request.urlretrieve(url, fpath) + urllib.request.urlretrieve(url, fpath) # noqa: S310 @staticmethod def _try_load(path_data: str, trials: int = 30, delta: float = 1.0) -> Tuple[Tensor, Tensor]: @@ -112,7 +112,7 @@ class _MNIST(Dataset): # todo: specify the possible exception except Exception as ex: exception = ex - time.sleep(delta * random.random()) + time.sleep(delta * random.random()) # noqa: S311 else: break assert res is not None @@ -135,8 +135,8 @@ def MNIST(*args: Any, **kwargs: Any) -> Dataset: from torchvision.datasets import MNIST MNIST(_DATASETS_PATH, download=True) - except HTTPError as e: - print(f"Error {e} downloading `torchvision.datasets.MNIST`") + except HTTPError as ex: + print(f"Error {ex} downloading `torchvision.datasets.MNIST`") torchvision_mnist_available = False if not torchvision_mnist_available: print("`torchvision.datasets.MNIST` not available. Using our hosted version") diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 5c237ddfbc..5dccdcd5c6 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -41,9 +41,9 @@ if _COMET_AVAILABLE: try: from comet_ml.api import API - except ModuleNotFoundError: # pragma: no-cover + except ModuleNotFoundError: # For more information, see: https://www.comet.ml/docs/python-sdk/releases/#release-300 - from comet_ml.papi import API # pragma: no-cover + from comet_ml.papi import API else: # needed for test mocks, these tests shall be updated comet_ml = None diff --git a/src/lightning/pytorch/loggers/neptune.py b/src/lightning/pytorch/loggers/neptune.py index faecc0a5e7..8c409762bf 100644 --- a/src/lightning/pytorch/loggers/neptune.py +++ b/src/lightning/pytorch/loggers/neptune.py @@ -294,11 +294,9 @@ class NeptuneLogger(Logger): args["run"] = self._run_short_id # Backward compatibility in case of previous version retrieval - try: + with contextlib.suppress(AttributeError): if self._run_name is not None: args["name"] = self._run_name - except AttributeError: - pass return args diff --git a/src/lightning/pytorch/loops/fetchers.py b/src/lightning/pytorch/loops/fetchers.py index 8facb02042..8df73c891c 100644 --- a/src/lightning/pytorch/loops/fetchers.py +++ b/src/lightning/pytorch/loops/fetchers.py @@ -53,9 +53,9 @@ class _DataFetcher(Iterator): assert self.iterator is not None try: data = next(self.iterator) - except StopIteration as e: + except StopIteration as ex: self.done = True - raise e + raise ex finally: self._stop_profiler() self.fetched += 1 @@ -128,9 +128,9 @@ class _PrefetchDataFetcher(_DataFetcher): self._fetch_next_batch(self.iterator) # consume the batch we just fetched batch = self.batches.pop(0) - except StopIteration as e: + except StopIteration as ex: self.done = True - raise e + raise ex else: # the iterator is empty raise StopIteration diff --git a/src/lightning/pytorch/plugins/io/async_plugin.py b/src/lightning/pytorch/plugins/io/async_plugin.py index f169e8feb9..509f40de0f 100644 --- a/src/lightning/pytorch/plugins/io/async_plugin.py +++ b/src/lightning/pytorch/plugins/io/async_plugin.py @@ -41,8 +41,8 @@ class AsyncCheckpointIO(_WrappingCheckpointIO): try: assert self.checkpoint_io is not None self.checkpoint_io.save_checkpoint(*args, **kwargs) - except BaseException as e: - self._error = e + except BaseException as ex: + self._error = ex self._executor.submit(_save_checkpoint, *args, **kwargs) diff --git a/src/lightning/pytorch/serve/servable_module_validator.py b/src/lightning/pytorch/serve/servable_module_validator.py index 3abc1935ec..e6db99091e 100644 --- a/src/lightning/pytorch/serve/servable_module_validator.py +++ b/src/lightning/pytorch/serve/servable_module_validator.py @@ -1,3 +1,4 @@ +import contextlib import logging import time from multiprocessing import Process @@ -99,11 +100,9 @@ class ServableModuleValidator(Callback): ready = False t0 = time.time() while not ready: - try: + with contextlib.suppress(requests.exceptions.ConnectionError): resp = requests.get(f"http://{self.host}:{self.port}/ping") ready = resp.status_code == 200 - except requests.exceptions.ConnectionError: - pass if time.time() - t0 > self.timeout: process.kill() raise Exception(f"The server didn't start within {self.timeout} seconds.") diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index afe3e2a8b7..53ee9c571e 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -282,8 +282,8 @@ class FSDPStrategy(ParallelStrategy): invalid_params_error = False try: super().setup_optimizers(trainer) - except ValueError as e: - if "optimizer got an empty parameter list" not in str(e): + except ValueError as ex: + if "optimizer got an empty parameter list" not in str(ex): raise invalid_params_error = True diff --git a/src/lightning/pytorch/utilities/combined_loader.py b/src/lightning/pytorch/utilities/combined_loader.py index 47f9738ed4..96299126b4 100644 --- a/src/lightning/pytorch/utilities/combined_loader.py +++ b/src/lightning/pytorch/utilities/combined_loader.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib from collections.abc import Iterable from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Tuple, Type, TypeVar, Union @@ -151,11 +152,9 @@ class _MaxSize(_ModeIterator[List]): out = [None] * n all_exhausted = True for i in range(n): - try: + with contextlib.suppress(StopIteration): out[i] = next(self.iterators[i]) all_exhausted = False - except StopIteration: - pass if all_exhausted: raise StopIteration return out diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py index c94eea7398..c9c2c71517 100644 --- a/src/lightning/pytorch/utilities/data.py +++ b/src/lightning/pytorch/utilities/data.py @@ -297,10 +297,10 @@ def _dataloader_init_kwargs_resolve_sampler( batch_size=batch_sampler.batch_size, drop_last=(False if is_predicting else batch_sampler.drop_last), ) - except TypeError as e: + except TypeError as ex: import re - match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(e)) + match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(ex)) if not match: # an unexpected `TypeError`, continue failure raise @@ -311,7 +311,7 @@ def _dataloader_init_kwargs_resolve_sampler( "We tried to re-instantiate your custom batch sampler and failed. " "To mitigate this, either follow the API of `BatchSampler` or instantiate " "your custom batch sampler inside `*_dataloader` hooks of your module." - ) from e + ) from ex if is_predicting: batch_sampler = _IndexBatchSamplerWrapper(batch_sampler) diff --git a/src/lightning/pytorch/utilities/testing/_runif.py b/src/lightning/pytorch/utilities/testing/_runif.py index b990a648a4..e2daae14e8 100644 --- a/src/lightning/pytorch/utilities/testing/_runif.py +++ b/src/lightning/pytorch/utilities/testing/_runif.py @@ -113,12 +113,12 @@ def _RunIf( if bf16_cuda: try: cond = not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) - except (AssertionError, RuntimeError) as e: + except (AssertionError, RuntimeError) as ex: # AssertionError: Torch not compiled with CUDA enabled # RuntimeError: Found no NVIDIA driver on your system. - is_unrelated = "Found no NVIDIA driver" not in str(e) or "Torch not compiled with CUDA" not in str(e) + is_unrelated = "Found no NVIDIA driver" not in str(ex) or "Torch not compiled with CUDA" not in str(ex) if is_unrelated: - raise e + raise ex cond = True conditions.append(cond) diff --git a/tests/integrations_app/conftest.py b/tests/integrations_app/conftest.py index 1e79f9de60..2efd435e66 100644 --- a/tests/integrations_app/conftest.py +++ b/tests/integrations_app/conftest.py @@ -1,3 +1,4 @@ +import contextlib import os import shutil import threading @@ -38,7 +39,7 @@ def pytest_sessionfinish(session, exitstatus): # TODO this isn't great. We should have each tests doing it's own cleanup current_process = psutil.Process() for child in current_process.children(recursive=True): - try: + with contextlib.suppress(psutil.NoSuchProcess): params = child.as_dict() or {} cmd_lines = params.get("cmdline", []) # we shouldn't kill the resource tracker from multiprocessing. If we do, @@ -46,8 +47,6 @@ def pytest_sessionfinish(session, exitstatus): if cmd_lines and "resource_tracker" in cmd_lines[-1]: continue child.kill() - except psutil.NoSuchProcess: - pass main_thread = threading.current_thread() for t in threading.enumerate(): diff --git a/tests/integrations_app/flagship/test_flashy.py b/tests/integrations_app/flagship/test_flashy.py index fbaad8efec..e9237e4080 100644 --- a/tests/integrations_app/flagship/test_flashy.py +++ b/tests/integrations_app/flagship/test_flashy.py @@ -1,3 +1,4 @@ +import contextlib from time import sleep import pytest @@ -21,17 +22,12 @@ def validate_app_functionalities(app_page: "Page") -> None: app_page: The UI page of the app to be validated. """ while True: - try: + with contextlib.suppress(playwright._impl._api_types.Error, playwright._impl._api_types.TimeoutError): app_page.reload() sleep(5) app_label = app_page.frame_locator("iframe").locator("text=Choose your AI task") app_label.wait_for(timeout=30 * 1000) break - except ( - playwright._impl._api_types.Error, - playwright._impl._api_types.TimeoutError, - ): - pass input_field = app_page.frame_locator("iframe").locator('input:below(:text("Data URL"))').first input_field.wait_for(timeout=1000) diff --git a/tests/tests_app/cli/test_cmd_apps.py b/tests/tests_app/cli/test_cmd_apps.py index d95a158e58..afb7ae77d7 100644 --- a/tests/tests_app/cli/test_cmd_apps.py +++ b/tests/tests_app/cli/test_cmd_apps.py @@ -81,7 +81,7 @@ def test_list_all_apps_paginated(list_memberships: mock.MagicMock, list_instance list_memberships.assert_called_once() assert list_instances.mock_calls == [ mock.call(project_id="default-project", limit=100, phase_in=[]), - mock.call(project_id="default-project", page_token="page-2", limit=100, phase_in=[]), + mock.call(project_id="default-project", page_token="page-2", limit=100, phase_in=[]), # noqa: S106 ] diff --git a/tests/tests_app/cli/test_cmd_install.py b/tests/tests_app/cli/test_cmd_install.py index bc8998722e..fff9f8c373 100644 --- a/tests/tests_app/cli/test_cmd_install.py +++ b/tests/tests_app/cli/test_cmd_install.py @@ -43,8 +43,8 @@ def test_valid_unpublished_app_name(): subprocess.check_output(f"lightning install app {real_app}", shell=True, stderr=subprocess.STDOUT) # this condition should never be hit assert False - except subprocess.CalledProcessError as e: - assert "WARNING" in str(e.output) + except subprocess.CalledProcessError as ex: + assert "WARNING" in str(ex.output) # assert aborted install result = runner.invoke(lightning_cli.cmd_install.install_app, [real_app], input="q") diff --git a/tests/tests_app/components/database/test_client_server.py b/tests/tests_app/components/database/test_client_server.py index bfcc435e80..332b18cc8e 100644 --- a/tests/tests_app/components/database/test_client_server.py +++ b/tests/tests_app/components/database/test_client_server.py @@ -54,7 +54,7 @@ def test_client_server(): secrets = [Secret(name="example", value="secret")] - general = _GeneralModel.from_obj(TestConfig(name="name", secrets=secrets), token="a") + general = _GeneralModel.from_obj(TestConfig(name="name", secrets=secrets), token="a") # noqa: S106 assert general.cls_name == "TestConfig" assert general.data == '{"id": null, "name": "name", "secrets": [{"name": "example", "value": "secret"}]}' diff --git a/tests/tests_app/conftest.py b/tests/tests_app/conftest.py index 6865f91bc5..f7704af7b7 100644 --- a/tests/tests_app/conftest.py +++ b/tests/tests_app/conftest.py @@ -1,3 +1,4 @@ +import contextlib import os import shutil import signal @@ -37,7 +38,7 @@ def pytest_sessionfinish(session, exitstatus): # TODO this isn't great. We should have each tests doing it's own cleanup current_process = psutil.Process() for child in current_process.children(recursive=True): - try: + with contextlib.suppress(psutil.NoSuchProcess): params = child.as_dict() or {} cmd_lines = params.get("cmdline", []) # we shouldn't kill the resource tracker from multiprocessing. If we do, @@ -45,8 +46,6 @@ def pytest_sessionfinish(session, exitstatus): if cmd_lines and "resource_tracker" in cmd_lines[-1]: continue child.kill() - except psutil.NoSuchProcess: - pass main_thread = threading.current_thread() for t in threading.enumerate(): diff --git a/tests/tests_app/core/test_lightning_app.py b/tests/tests_app/core/test_lightning_app.py index 428c6fb957..006e5a4516 100644 --- a/tests/tests_app/core/test_lightning_app.py +++ b/tests/tests_app/core/test_lightning_app.py @@ -1,3 +1,4 @@ +import contextlib import logging import os import pickle @@ -555,12 +556,11 @@ class CheckpointLightningApp(LightningApp): def test_snap_shotting(): - try: + with contextlib.suppress(SuccessException): app = CheckpointLightningApp(FlowA()) app.checkpointing = True MultiProcessRuntime(app, start_server=False).dispatch() - except SuccessException: - pass + checkpoint_dir = os.path.join(_storage_root_dir(), "checkpoints") checkpoints = os.listdir(checkpoint_dir) assert len(checkpoints) == 1 diff --git a/tests/tests_app/core/test_lightning_flow.py b/tests/tests_app/core/test_lightning_flow.py index 62a787ea3c..ef28d29b06 100644 --- a/tests/tests_app/core/test_lightning_flow.py +++ b/tests/tests_app/core/test_lightning_flow.py @@ -1,3 +1,4 @@ +import contextlib import os import pickle from collections import Counter @@ -374,11 +375,9 @@ def test_lightning_flow_and_work(): "changes": {}, } assert flow_a.state == state - try: + with contextlib.suppress(ExitAppException): while True: flow_a.run() - except ExitAppException: - pass state = { "vars": {"counter": 5, "_layout": ANY, "_paths": {}}, diff --git a/tests/tests_app/plugin/test_plugin.py b/tests/tests_app/plugin/test_plugin.py index 25ef908515..81983ff91b 100644 --- a/tests/tests_app/plugin/test_plugin.py +++ b/tests/tests_app/plugin/test_plugin.py @@ -24,7 +24,7 @@ def mock_plugin_server(mock_uvicorn) -> TestClient: mock_uvicorn.run.side_effect = create_test_client - _start_plugin_server("0.0.0.0", 8888) + _start_plugin_server("0.0.0.0", 8888) # noqa: S104 return test_client["client"] diff --git a/tests/tests_app/runners/test_multiprocess.py b/tests/tests_app/runners/test_multiprocess.py index 36abb8f1ba..8aa532d26f 100644 --- a/tests/tests_app/runners/test_multiprocess.py +++ b/tests/tests_app/runners/test_multiprocess.py @@ -53,7 +53,7 @@ class StartFrontendServersTestFlow(LightningFlow): "cloudspace_host, port, expected_host, expected_target", [ (None, 7000, "localhost", "http://localhost:7000"), - ("test.lightning.ai", 7000, "0.0.0.0", "https://7000-test.lightning.ai"), + ("test.lightning.ai", 7000, "0.0.0.0", "https://7000-test.lightning.ai"), # noqa: S104 ], ) @mock.patch("lightning.app.runners.multiprocess.find_free_network_port") diff --git a/tests/tests_app/utilities/test_proxies.py b/tests/tests_app/utilities/test_proxies.py index d1ac86e42d..6f2c6f63f5 100644 --- a/tests/tests_app/utilities/test_proxies.py +++ b/tests/tests_app/utilities/test_proxies.py @@ -267,9 +267,9 @@ class WorkRunnerPatch(WorkRunner): ComponentDelta(id=self.work_name, delta=Delta(DeepDiff(state, self.work.state, verbose_level=2))) ) self.counter += 1 - except Exception as e: + except Exception as ex: logger.error(traceback.format_exc()) - self.error_queue.put(e) + self.error_queue.put(ex) raise ExitAppException @@ -435,7 +435,6 @@ def test_path_attributes_to_transfer(_, origin, exists_remote, expected_get): copy_request_queue=_MockQueue(), copy_response_queue=_MockQueue(), ) - with contextlib.suppress(ExitAppException): runner() @@ -645,7 +644,7 @@ def test_state_observer(): "patch_constants, environment, expected_ip_addr", [ ({}, {}, "127.0.0.1"), - ({"LIGHTNING_CLOUDSPACE_HOST": "any"}, {}, "0.0.0.0"), + ({"LIGHTNING_CLOUDSPACE_HOST": "any"}, {}, "0.0.0.0"), # noqa: S104 ({}, {"LIGHTNING_NODE_IP": "10.10.10.5"}, "10.10.10.5"), ], indirect=["patch_constants"], diff --git a/tests/tests_fabric/helpers/runif.py b/tests/tests_fabric/helpers/runif.py index bf26270e26..b133e86e87 100644 --- a/tests/tests_fabric/helpers/runif.py +++ b/tests/tests_fabric/helpers/runif.py @@ -98,12 +98,12 @@ class RunIf: if bf16_cuda: try: cond = not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) - except (AssertionError, RuntimeError) as e: + except (AssertionError, RuntimeError) as ex: # AssertionError: Torch not compiled with CUDA enabled # RuntimeError: Found no NVIDIA driver on your system. - is_unrelated = "Found no NVIDIA driver" not in str(e) or "Torch not compiled with CUDA" not in str(e) + is_unrelated = "Found no NVIDIA driver" not in str(ex) or "Torch not compiled with CUDA" not in str(ex) if is_unrelated: - raise e + raise ex cond = True conditions.append(cond) diff --git a/tests/tests_fabric/strategies/test_single_device.py b/tests/tests_fabric/strategies/test_single_device.py index 9271fb590b..bab2744464 100644 --- a/tests/tests_fabric/strategies/test_single_device.py +++ b/tests/tests_fabric/strategies/test_single_device.py @@ -77,13 +77,13 @@ class _MyFabricGradNorm(BoringFabric): try: super().run() break - except RuntimeError as e: + except RuntimeError as ex: # nonfinite grads -> skip and continue # this may repeat until the scaler finds a factor where overflow is avoided, # so the while loop should eventually break # stop after a max of 10 tries - if i > 10 or not str(e).startswith("The total norm"): - raise e + if i > 10 or not str(ex).startswith("The total norm"): + raise ex # unscale was already called by last attempt, # but no update afterwards since optimizer step was missing. @@ -117,13 +117,13 @@ class _MyFabricGradVal(BoringFabric): try: super().run() break - except RuntimeError as e: + except RuntimeError as ex: # nonfinite grads -> skip and continue # this may repeat until the scaler finds a factor where overflow is avoided, # so the while loop should eventually break # stop after a max of 10 tries - if i > 10 or not str(e).startswith("Nonfinite grads"): - raise e + if i > 10 or not str(ex).startswith("Nonfinite grads"): + raise ex # unscale was already called by last attempt, # but no update afterwards since optimizer step was missing. diff --git a/tests/tests_pytorch/helpers/datasets.py b/tests/tests_pytorch/helpers/datasets.py index 58fba175da..e2954b0d73 100644 --- a/tests/tests_pytorch/helpers/datasets.py +++ b/tests/tests_pytorch/helpers/datasets.py @@ -103,7 +103,7 @@ class MNIST(Dataset): for url in self.RESOURCES: logging.info(f"Downloading {url}") fpath = os.path.join(data_folder, os.path.basename(url)) - urllib.request.urlretrieve(url, fpath) + urllib.request.urlretrieve(url, fpath) # noqa: S310 @staticmethod def _try_load(path_data, trials: int = 30, delta: float = 1.0): @@ -115,8 +115,8 @@ class MNIST(Dataset): try: res = torch.load(path_data) # todo: specify the possible exception - except Exception as e: - exception = e + except Exception as ex: + exception = ex time.sleep(delta * random.random()) else: break