adding check for bandit vulnerabilities 1/n (#17382)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 26666cb568
commit 156786343b

@@ -68,6 +68,9 @@ repos:
- id: yesqa
name: Unused noqa
additional_dependencies:
#- pep8-naming
#- flake8-pytest-style
- flake8-bandit
- flake8-simplify

- repo: https://github.com/PyCQA/isort

@@ -7,7 +7,7 @@ print("Starting processing ...")
scaler = MinMaxScaler()

X_train, X_test, y_train, y_test = train_test_split(
df_data.values, df_target.values, test_size=0.20, random_state=random.randint(0, 42) # noqa F821
df_data.values, df_target.values, test_size=0.20, random_state=random.randint(0, 42)
)
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

@@ -376,11 +376,11 @@ class MyCustomTrainer:

try:
monitor = possible_monitor_vals[cast(Optional[str], scheduler_cfg["monitor"])]
except KeyError as e:
except KeyError as ex:
possible_keys = list(possible_monitor_vals.keys())
raise KeyError(
f"monitor {scheduler_cfg['monitor']} is invalid. Possible values are {possible_keys}."
) from e
) from ex

# rely on model hook for actual step
model.lr_scheduler_step(scheduler_cfg["scheduler"], monitor)

@@ -55,6 +55,7 @@ line-length = 120
select = [
"E", "W", # see: https://pypi.org/project/pycodestyle
"F", # see: https://pypi.org/project/pyflakes
"S", # see: https://pypi.org/project/flake8-bandit
]
extend-select = [
"C4", # see: https://pypi.org/project/flake8-comprehensions

@@ -71,6 +72,34 @@ exclude = [
]
ignore-init-module-imports = true

[tool.ruff.per-file-ignores]
".actions/*" = ["S101", "S310"]
"setup.py" = ["S101"]
"examples/**" = [
"S101", # Use of `assert` detected
"S113", # todo: Probable use of requests call without timeout
"S104", # Possible binding to all interfaces
"F821", # Undefined name `...`
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
"S501", # Probable use of `requests` call with `verify=False` disabling SSL certificate checks
"S108", # Probable insecure usage of temporary file or directory: "/tmp/data/MNIST"
]
"src/**" = [
"S101", # todo: Use of `assert` detected
"S105", "S106", "S107", # todo: Possible hardcoded password: ...
"S113", # todo: Probable use of requests call without timeout
"S301", # todo: `pickle` and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
"S324", # todo: Probable use of insecure hash functions in `hashlib`
]
"tests/**" = [
"S101", # Use of `assert` detected
"S105", "S106", # todo: Possible hardcoded password: ...
"S301", # `pickle` and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
"S113", # todo: Probable use of requests call without timeout
"S311", # todo: Standard pseudo-random generators are not suitable for cryptographic purposes
"S108", # todo: Probable insecure usage of temporary file or directory: "/tmp/sys-customizations-sync"
]

[tool.ruff.mccabe]
# Unlike Flake8, default to a complexity level of 10.
max-complexity = 10

@@ -35,7 +35,7 @@ def replace(req: str, torch_version: Optional[str] = None) -> str:
import torch

torch_version = torch.__version__
assert torch_version, f"invalid torch: {torch_version}"
assert torch_version, f"invalid torch: {torch_version}" # noqa: S101

# remove comments and strip whitespace
req = re.sub(rf"\s*#.*{os.linesep}", os.linesep, req).strip()

@@ -325,8 +325,8 @@ def _wait_for_cluster_state(
break
time.sleep(poll_duration_seconds)
elapsed = int(time.time() - start)
except lightning_cloud.openapi.rest.ApiException as e:
if e.status == 404 and target_state == V1ClusterState.DELETED:
except lightning_cloud.openapi.rest.ApiException as ex:
if ex.status == 404 and target_state == V1ClusterState.DELETED:
return
raise
else:

@@ -113,10 +113,11 @@ def _capture_valid_app_component_name(value: Optional[str] = None, resource_type
raise SystemExit(m)

except KeyboardInterrupt:
m = f"""
raise SystemExit(
f"""
⚡ {resource_type} init aborted! ⚡
"""
raise SystemExit(m)
)

return value

@@ -227,13 +227,14 @@ def _show_install_component_prompt(entry: Dict[str, str], component: str, org: s
return git_url
except KeyboardInterrupt:
repo = entry["sourceUrl"]
m = f"""
raise SystemExit(
f"""
⚡ Installation aborted! ⚡

Install the component yourself by visiting:
{repo}
"""
raise SystemExit(m)
)


def _show_non_gallery_install_component_prompt(gh_url: str, yes_arg: bool) -> str:

@@ -282,13 +283,14 @@ def _show_non_gallery_install_component_prompt(gh_url: str, yes_arg: bool) -> st

return gh_url
except KeyboardInterrupt:
m = f"""
raise SystemExit(
f"""
⚡ Installation aborted! ⚡

Install the component yourself by visiting:
{repo_url}
"""
raise SystemExit(m)
)


def _show_install_app_prompt(

@@ -332,13 +334,14 @@ def _show_install_app_prompt(
return source_url, git_url, folder_name, git_sha
except KeyboardInterrupt:
repo = entry["sourceUrl"]
m = f"""
raise SystemExit(
f"""
⚡ Installation aborted! ⚡

Install the {resource_type} yourself by visiting:
{repo}
"""
raise SystemExit(m)
)


def _show_non_gallery_install_app_prompt(gh_url: str, yes_arg: bool) -> Tuple[str, str]:

@@ -352,15 +355,16 @@ def _show_non_gallery_install_app_prompt(gh_url: str, yes_arg: bool) -> Tuple[st
folder_name = gh_url.split("/")[-1]

org = re.search(r"github.com\/(.*)\/", gh_url).group(1) # type: ignore
except Exception as e: # noqa
m = """
except Exception:
raise SystemExit(
"""
Your github url is not supported. Here's the supported format:
https://github.com/YourOrgName/your-repo-name

Example:
https://github.com/Lightning-AI/lightning
"""
raise SystemExit("")
)

# yes arg does not prompt the user for permission to install anything
# automatically creates env and sets up the project

@@ -396,20 +400,22 @@ def _show_non_gallery_install_app_prompt(gh_url: str, yes_arg: bool) -> Tuple[st

return gh_url, folder_name
except KeyboardInterrupt:
m = f"""
raise SystemExit(
f"""
⚡ Installation aborted! ⚡

Install the app yourself by visiting {gh_url}
"""
raise SystemExit(m)
)


def _validate_name(name: str, resource_type: str, example: str) -> Tuple[str, str]:
# ensure resource identifier is properly formatted
try:
org, resource = name.split("/")
except Exception as e: # noqa
m = f"""
except Exception:
raise SystemExit(
f"""
{resource_type} name format must have organization/{resource_type}-name

Examples:

@@ -418,12 +424,7 @@ def _validate_name(name: str, resource_type: str, example: str) -> Tuple[str, st

You passed in: {name}
"""
raise SystemExit(m)
m = f"""
⚡ Installing Lightning {resource_type} ⚡
{resource_type} name: {resource}
developer: {org}
"""
)
return org, resource


@@ -466,13 +467,14 @@ def _resolve_resource(
elif resource_type == "component":
gallery_entries = data["components"]
except requests.ConnectionError:
m = f"""
sys.tracebacklimit = 0
raise SystemError(
f"""
Network connection error, could not load list of available Lightning {resource_type}s.

Try again when you have a network connection!
"""
sys.tracebacklimit = 0
raise SystemError(m)
)

entries = []
all_versions = []

@@ -582,15 +584,16 @@ def _install_app_from_source(
logger.info(f"⚡ RUN: git clone {source_url}")
try:
subprocess.check_output(["git", "clone", git_url], stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as e:
if "Repository not found" in str(e.output):
m = f"""
except subprocess.CalledProcessError as ex:
if "Repository not found" in str(ex.output):
raise SystemExit(
f"""
Looks like the github url was not found or doesn't exist. Do you have a typo?
{source_url}
"""
raise SystemExit(m)
)
else:
raise Exception(e)
raise Exception(ex)

# step into the repo folder
os.chdir(f"{folder_name}")

@@ -599,11 +602,10 @@ def _install_app_from_source(
try:
if git_sha:
subprocess.check_output(["git", "checkout", git_sha], stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as e:
if "did not match any" in str(e.output):
except subprocess.CalledProcessError as ex:
if "did not match any" in str(ex.output):
raise SystemExit("Looks like the git SHA is not valid or doesn't exist in app repo.")
else:
raise Exception(e)
raise Exception(ex)

# activate and install reqs
# TODO: remove shell=True... but need to run command in venv

@@ -119,14 +119,14 @@ def download_frontend(destination: Path) -> None:
url = "https://storage.googleapis.com/grid-packages/pytorch-lightning-app/v0.0.0/build.tar.gz"
build_dir_name = "build"
with TemporaryDirectory() as download_dir:
response = urllib.request.urlopen(url)
response = urllib.request.urlopen(url) # noqa: S310
file = tarfile.open(fileobj=response, mode="r|gz")
file.extractall(path=download_dir)
shutil.move(str(Path(download_dir, build_dir_name)), destination)


def project_file_from_template(template_dir: Path, destination_dir: Path, template_name: str, **kwargs: Any) -> None:
env = Environment(loader=FileSystemLoader(template_dir))
env = Environment(loader=FileSystemLoader(template_dir)) # noqa: S701
template = env.get_template(template_name)
rendered_template = template.render(**kwargs)
with open(destination_dir / template_name, "w") as file:

@@ -59,7 +59,8 @@ class _SSHKeyManager:
console.print(ssh_keys.as_table())

def add_key(self, public_key: str, name: Optional[str], comment: Optional[str]) -> None:
key_name = name if name is not None else "-".join(random.choice(string.ascii_lowercase) for _ in range(5))
rnd = "-".join(random.choice(string.ascii_lowercase) for _ in range(5)) # noqa: S311
key_name = name if name is not None else rnd
self.api_client.s_sh_public_key_service_create_ssh_public_key(
V1CreateSSHPublicKeyRequest(
name=key_name,

@@ -13,6 +13,7 @@
# limitations under the License.

import concurrent
import contextlib
import os
import sys
from functools import partial

@@ -253,8 +254,8 @@ def _download_file(path: str, url: str, progress: Progress, task_id: TaskID) ->
# Disable warning about making an insecure request
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

try:
request = requests.get(url, stream=True, verify=False)
with contextlib.suppress(ConnectionError):
request = requests.get(url, stream=True, verify=False) # noqa: S501

chunk_size = 1024

@@ -262,8 +263,6 @@ def _download_file(path: str, url: str, progress: Progress, task_id: TaskID) ->
for chunk in request.iter_content(chunk_size=chunk_size):
fp.write(chunk) # type: ignore
progress.update(task_id, advance=len(chunk))
except ConnectionError:
pass


def _sanitize_path(path: str, pwd: str) -> Tuple[str, bool]:

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import os
import sys
from contextlib import nullcontext

@@ -224,7 +224,9 @@ def _collect_artifacts(
if page_token in tokens:
return

try:
# Note: This is triggered when the request is wrong.
# This is currently happening due to looping through the user clusters.
with contextlib.suppress(lightning_cloud.openapi.rest.ApiException):
response = client.lightningapp_instance_service_list_project_artifacts(
project_id,
prefix=prefix,

@@ -248,10 +250,6 @@ def _collect_artifacts(
page_token=response.next_page_token,
tokens=tokens,
)
except lightning_cloud.openapi.rest.ApiException:
# Note: This is triggered when the request is wrong.
# This is currently happening due to looping through the user clusters.
pass


def _add_resource_prefix(prefix: str, resource_path: str):

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import os

import click

@@ -83,7 +83,7 @@ def rm(rm_path: str, r: bool = False, recursive: bool = False) -> None:
succeeded = False

for cluster in clusters.clusters:
try:
with contextlib.suppress(lightning_cloud.openapi.rest.ApiException):
client.lightningapp_instance_service_delete_project_artifact(
project_id=project_id,
cluster_id=cluster.cluster_id,

@@ -91,8 +91,6 @@ def rm(rm_path: str, r: bool = False, recursive: bool = False) -> None:
)
succeeded = True
break
except lightning_cloud.openapi.rest.ApiException:
pass

prefix = os.path.join(*splits)

@@ -102,9 +102,9 @@ def connect_maverick(name: str, project_name: str = "") -> None:
with Live(Spinner("point", text=Text("Registering maverick...", style="white")), transient=True) as live:
try:
register_to_cloud(name, project_name)
except Exception as e:
except Exception as ex:
live.stop()
rich.print(f"[red]Failed[/red]: Registering maverick failed with error {e}")
rich.print(f"[red]Failed[/red]: Registering maverick failed with error {ex}")
return

live.update(Spinner("point", text=Text("Setting up ...", style="white")))

@@ -209,8 +209,8 @@ def disconnect_maverick(name: str) -> None:
with Live(Spinner("point", text=Text("disconnecting maverick...", style="white")), transient=True):
try:
deregister_from_cloud(name)
except Exception as e:
rich.print(f"[red]Failed[/red]: Disconnecting machine failed with error: {e}")
except Exception as ex:
rich.print(f"[red]Failed[/red]: Disconnecting machine failed with error: {ex}")
return
subprocess.run(f"docker stop {CODE_SERVER_CONTAINER}", shell=True, capture_output=True)
subprocess.run(f"docker stop {LIGHTNING_DAEMON_CONTAINER}", shell=True, capture_output=True)

@@ -244,8 +244,8 @@ def cluster_logs(cluster_id: str, to_time: arrow.Arrow, from_time: arrow.Arrow,
rich.print(f"[{color}]{log_event.labels.level:5}[/{color}] {date} {log_event.message.rstrip()}")
except LogLinesLimitExceeded:
raise click.ClickException(f"Read {limit} log lines, but there may be more. Use --limit param to read more")
except Exception as error:
logger.error(f"⚡ Error while reading logs ({type(error)}), {error}", exc_info=DEBUG)
except Exception as ex:
logger.error(f"⚡ Error while reading logs ({type(ex)}), {ex}", exc_info=DEBUG)


@_main.command()

@@ -117,9 +117,9 @@ def add_ssh_key(
new_public_key = Path(str(public_key)).read_text() if os.path.isfile(str(public_key)) else public_key
try:
ssh_key_manager.add_key(name=key_name, comment=comment, public_key=str(new_public_key))
except ApiException as e:
except ApiException as ex:
# if we got an exception it might be the user passed the private key file
if os.path.isfile(str(public_key)) and os.path.isfile(f"{public_key}.pub"):
ssh_key_manager.add_key(name=key_name, comment=comment, public_key=Path(f"{public_key}.pub").read_text())
else:
raise e
raise ex

@@ -211,12 +211,12 @@ def delete_app(app_name: str, cluster_id: str, skip_user_confirm_prompt: bool) -
# Delete the app!
app_manager = _AppManager()
app_manager.delete(app_id=selected_app_instance_id)
except Exception as e:
except Exception as ex:
console.print(
f'[b][red]An issue occurred while deleting app "{app_name}. If the issue persists, please '
"reach out to us at [link=mailto:support@lightning.ai]support@lightning.ai[/link][/b][/red]."
)
raise click.ClickException(str(e))
raise click.ClickException(str(ex))

console.print(f'[b][green]App "{app_name}" has been successfully deleted from cluster "{cluster_id}"![/green][/b]')
return

@@ -258,5 +258,5 @@ def _create_database(db_filename: str, models: List[Type["SQLModel"]], echo: boo
logger.debug(f"Creating the following tables {models}")
try:
SQLModel.metadata.create_all(engine)
except Exception as e:
logger.debug(e)
except Exception as ex:
logger.debug(ex)

@@ -13,6 +13,7 @@
# limitations under the License.

import asyncio
import contextlib
import json
import os
import queue

@@ -124,23 +125,19 @@ class UIRefresher(Thread):
raise ex

def run_once(self) -> None:
try:
with contextlib.suppress(queue.Empty):
global app_status
state, app_status = self.api_publish_state_queue.get(timeout=0)
with lock:
global_app_state_store.set_app_state(TEST_SESSION_UUID, state)
except queue.Empty:
pass

try:
with contextlib.suppress(queue.Empty):
responses = self.api_response_queue.get(timeout=0)
with lock:
# TODO: Abstract the responses store to support horizontal scaling.
global responses_store
for response in responses:
responses_store[response["id"]] = response["response"]
except queue.Empty:
pass

def join(self, timeout: Optional[float] = None) -> None:
self._exit_event.set()

@@ -179,7 +176,7 @@ fastapi_service.add_middleware(
)

if _is_starsessions_available():
fastapi_service.add_middleware(SessionMiddleware, secret_key="secret", autoload=True)
fastapi_service.add_middleware(SessionMiddleware, secret_key="secret", autoload=True) # noqa: S106


# General sequence is:

@@ -490,7 +487,7 @@ def start_server(

if uvicorn_run:
host = host.split("//")[-1] if "//" in host else host
if host == "0.0.0.0":
if host == "0.0.0.0": # noqa: S104
logger.info("Your app has started.")
else:
logger.info(f"Your app has started. View it in your browser: http://{host}:{port}/view")

@@ -355,8 +355,8 @@ class LightningApp:
work = None
try:
work = self.get_component_by_name(delta.id)
except (KeyError, AttributeError) as e:
logger.error(f"The component {delta.id} couldn't be accessed. Exception: {e}")
except (KeyError, AttributeError) as ex:
logger.error(f"The component {delta.id} couldn't be accessed. Exception: {ex}")

if work:
delta = _delta_to_app_state_delta(

@@ -421,8 +421,8 @@ class LightningApp:
for delta in deltas:
try:
state += delta
except Exception as e:
raise Exception(f"Current State {state}, {delta.to_dict()}") from e
except Exception as ex:
raise Exception(f"Current State {state}, {delta.to_dict()}") from ex

# new_state = self.populate_changes(self.last_state, state)
self.set_state(state)

@@ -93,7 +93,7 @@ ENABLE_STATE_WEBSOCKET = bool(int(os.getenv("ENABLE_STATE_WEBSOCKET", "1")))
ENABLE_UPLOAD_ENDPOINT = bool(int(os.getenv("ENABLE_UPLOAD_ENDPOINT", "1")))

# directory where system customization sync files stored
SYS_CUSTOMIZATIONS_SYNC_ROOT = "/tmp/sys-customizations-sync"
SYS_CUSTOMIZATIONS_SYNC_ROOT = "/tmp/sys-customizations-sync" # noqa: S108 # todo
# directory where system customization sync files will be copied to be packed into app tarball
SYS_CUSTOMIZATIONS_SYNC_PATH = ".sys-customizations-sync"

@@ -419,7 +419,6 @@ class HTTPQueue(BaseQueue):
# than the default timeout
if timeout > self.default_timeout:
time.sleep(0.05)
pass

def _get(self) -> Any:
try:

@@ -483,10 +483,10 @@ class LightningWork:
def __getattribute__(self, name: str) -> Any:
try:
attr = object.__getattribute__(self, name)
except AttributeError as e:
if str(e).endswith("'_state'"):
except AttributeError as ex:
if str(ex).endswith("'_state'"):
raise AttributeError(f"Did you forget to call super().__init__() in {self}")
raise e
raise ex

if isinstance(attr, ProxyWorkRun):
return attr

@@ -128,28 +128,28 @@ def _run_plugin(run: _Run) -> Dict[str, Any]:

with open(download_path, "wb") as f:
f.write(response.content)
except Exception as e:
except Exception as ex:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error downloading plugin source: {str(e)}.",
detail=f"Error downloading plugin source: {str(ex)}.",
)

# Extract
try:
with tarfile.open(download_path, "r:gz") as tf:
tf.extractall(source_path)
except Exception as e:
except Exception as ex:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error extracting plugin source: {str(e)}.",
detail=f"Error extracting plugin source: {str(ex)}.",
)

# Import the plugin
try:
plugin = _load_plugin_from_file(os.path.join(source_path, run.plugin_entrypoint))
except Exception as e:
except Exception as ex:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error loading plugin: {str(e)}."
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error loading plugin: {str(ex)}."
)

# Ensure that apps are dispatched from the temp directory

@@ -165,9 +165,9 @@ def _run_plugin(run: _Run) -> Dict[str, Any]:
)
actions = plugin.run(**run.plugin_arguments) or []
return {"actions": [action.to_spec().to_dict() for action in actions]}
except Exception as e:
except Exception as ex:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error running plugin: {str(e)}."
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error running plugin: {str(ex)}."
)
finally:
os.chdir(cwd)

@@ -80,7 +80,7 @@ class MultiProcessingBackend(Backend):
if constants.LIGHTNING_CLOUDSPACE_HOST is not None:
# Override the port if set by the user
work._port = find_free_network_port()
work._host = "0.0.0.0"
work._host = "0.0.0.0" # noqa: S104
work._future_url = f"https://{work.port}-{constants.LIGHTNING_CLOUDSPACE_HOST}"

app.processes[work.name] = MultiProcessWorkManager(app, work)

@@ -121,7 +121,7 @@ class CloudMultiProcessingBackend(MultiProcessingBackend):
self.ports = []

def create_work(self, app, work) -> None:
work._host = "0.0.0.0"
work._host = "0.0.0.0" # noqa: S104
nc = enable_port()
self.ports.append(nc.port)
work._port = nc.port

@@ -186,8 +186,8 @@ class CloudRuntime(Runtime):
if "PYTEST_CURRENT_TEST" not in os.environ:
click.launch(self._get_cloudspace_url(project, cloudspace_name, "code", needs_credits))

except ApiException as e:
logger.error(e.body)
except ApiException as ex:
logger.error(ex.body)
sys.exit(1)

def cloudspace_dispatch(

@@ -384,8 +384,8 @@ class CloudRuntime(Runtime):
if bool(int(os.getenv("LIGHTING_TESTING", "0"))):
print(f"APP_LOGS_URL: {self._get_app_url(project, run_instance, 'logs')}")

except ApiException as e:
logger.error(e.body)
except ApiException as ex:
logger.error(ex.body)
sys.exit(1)
finally:
if cleanup_handle:

@@ -399,8 +399,8 @@ class CloudRuntime(Runtime):

try:
app = load_app_from_file(filepath, raise_exception=True, mock_imports=True, env_vars=env_vars)
except FileNotFoundError as e:
raise e
except FileNotFoundError as ex:
raise ex
except Exception:
from lightning.app.testing.helpers import EmptyFlow

@@ -770,7 +770,7 @@ class CloudRuntime(Runtime):
if cloudspace is not None and cloudspace.code_config is not None:
data_connection_mounts = cloudspace.code_config.data_connection_mounts

random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5)) # noqa: S311
work_spec = V1LightningworkSpec(
build_spec=build_spec,
drives=drives + mounts,

@@ -51,7 +51,7 @@ class MultiProcessRuntime(Runtime):

# Note: In case the runtime is used in the cloud.
in_cloudspace = constants.LIGHTNING_CLOUDSPACE_HOST is not None
self.host = "0.0.0.0" if constants.APP_SERVER_IN_CLOUD or in_cloudspace else self.host
self.host = "0.0.0.0" if constants.APP_SERVER_IN_CLOUD or in_cloudspace else self.host # noqa: S104

self.app.backend = self.backend
self.backend._prepare_queues(self.app)

@@ -68,7 +68,7 @@ class MultiProcessRuntime(Runtime):
for frontend in self.app.frontends.values():
port = find_free_network_port()

server_host = "0.0.0.0" if in_cloudspace else "localhost"
server_host = "0.0.0.0" if in_cloudspace else "localhost" # noqa: S104
server_target = (
f"https://{port}-{constants.LIGHTNING_CLOUDSPACE_HOST}"
if in_cloudspace

@@ -38,7 +38,7 @@ def _get_hash(files: List[str], algorithm: str = "blake2", chunk_num_blocks: int
if algorithm == "blake2":
h = hashlib.blake2b(digest_size=20)
elif algorithm == "md5":
h = hashlib.md5()
h = hashlib.md5() # noqa: S324
else:
raise ValueError(f"Algorithm {algorithm} not supported")

@@ -128,9 +128,9 @@ def _copy_files(
fs.makedirs(str(to_path.parent), exist_ok=True)

fs.put(str(from_path), str(to_path), recursive=False)
except Exception as e:
except Exception as ex:
# Return the exception so that it can be handled in the main thread
return e
return ex

# NOTE: Cannot use `S3FileSystem.put(recursive=True)` because it tries to access parent directories
# which it does not have access to.

@@ -129,7 +129,7 @@ class Path(PathlibPath):
if self._origin is None:
return None
contents = f"{self.origin_name}/{self}"
return hashlib.sha1(contents.encode("utf-8")).hexdigest()
return hashlib.sha1(contents.encode("utf-8")).hexdigest() # noqa: S324

@property
def parents(self) -> Sequence["Path"]:

@@ -363,8 +363,8 @@ class Path(PathlibPath):
try:
_copy_files(source_path, destination_path)
_logger.debug(f"All files copied from {request.path} to {response.path}.")
except Exception as e:
response.exception = e
except Exception as ex:
response.exception = ex
return response

@@ -65,7 +65,7 @@ class _BasePayload(ABC):
if self._origin is None:
return None
contents = f"{self.origin_name}/{self.consumer_name}/{self.name}"
return hashlib.sha1(contents.encode("utf-8")).hexdigest()
return hashlib.sha1(contents.encode("utf-8")).hexdigest() # noqa: S324

@property
def origin_name(self) -> str:

@@ -251,8 +251,8 @@ class _BasePayload(ABC):
response.size = source_path.stat().st_size
_copy_files(source_path, destination_path)
_logger.debug(f"All files copied from {request.path} to {response.path}.")
except Exception as e:
response.exception = e
except Exception as ex:
response.exception = ex
return response

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import functools
import json
import os

@@ -209,14 +209,12 @@ class _LightningAppOpenAPIRetriever:

# 2: If no identifier has been provided, evaluate the local application
if self.app_id_or_name_or_url is None:
try:
with contextlib.suppress(requests.exceptions.ConnectionError):
self.url = f"http://localhost:{APP_SERVER_PORT}"
resp = requests.get(f"{self.url}/openapi.json")
if resp.status_code != 200:
raise Exception(f"The server didn't process the request properly. Found {resp.json()}")
self.openapi = resp.json()
except requests.exceptions.ConnectionError:
pass

# 3: If an identified was provided or the local evaluation has failed, evaluate the cloud.
else:

@@ -48,4 +48,4 @@ def _get_default_cluster(client: LightningClient, project_id: str) -> str:
if len(clusters) == 0:
raise RuntimeError(f"No clusters found on `{client.api_client.configuration.host}`.")

return random.choice(clusters).id
return random.choice(clusters).id # noqa: S311

@@ -44,9 +44,9 @@ def makedirs(path: str):
r"""Recursive directory creation function."""
try:
os.makedirs(osp.expanduser(osp.normpath(path)))
except OSError as e:
if e.errno != errno.EEXIST and osp.isdir(path):
raise e
except OSError as ex:
if ex.errno != errno.EEXIST and osp.isdir(path):
raise ex


class ClientCommand:

@@ -226,9 +226,9 @@ def _process_api_request(app, request: _APIRequest):
method = getattr(flow, request.method_name)
try:
response = _RequestResponse(content=method(*request.args, **request.kwargs), status_code=200)
except HTTPException as e:
logger.error(repr(e))
response = _RequestResponse(status_code=e.status_code, content=e.detail)
except HTTPException as ex:
logger.error(repr(ex))
response = _RequestResponse(status_code=ex.status_code, content=ex.detail)
except Exception:
logger.error(traceback.print_exc())
response = _RequestResponse(status_code=500)

@@ -244,9 +244,9 @@ def _process_command_requests(app, request: _CommandRequest):
# Validation is done on the CLI side.
try:
response = _RequestResponse(content=method(*request.args, **request.kwargs), status_code=200)
except HTTPException as e:
logger.error(repr(e))
response = _RequestResponse(status_code=e.status_code, content=e.detail)
except HTTPException as ex:
logger.error(repr(ex))
response = _RequestResponse(status_code=ex.status_code, content=ex.detail)
except Exception:
logger.error(traceback.print_exc())
response = _RequestResponse(status_code=500)

@@ -38,7 +38,7 @@ def _add_comment_to_literal_code(method, contains, comment):

return "\n".join(lines)

except Exception as e: # noqa
except Exception:
return ""

@@ -74,12 +74,12 @@ def _load_objects_from_file(
with _add_to_env(env_vars), _patch_sys_argv():
if mock_imports:
with _mock_missing_imports():
exec(code, module.__dict__)
exec(code, module.__dict__) # noqa: S102
else:
exec(code, module.__dict__)
except Exception as e:
exec(code, module.__dict__) # noqa: S102
except Exception as ex:
if raise_exception:
raise e
raise ex
_prettifiy_exception(filepath)

return [v for v in module.__dict__.values() if isinstance(v, target_type)], module

@@ -161,13 +161,13 @@ class AuthServer:
try:
# check if server is reachable or catch any network errors
requests.head(url)
except requests.ConnectionError as e:
except requests.ConnectionError as ex:
raise requests.ConnectionError(
f"No internet connection available. Please connect to a stable internet connection \n{e}" # E501
f"No internet connection available. Please connect to a stable internet connection \n{ex}" # E501
)
except requests.RequestException as e:
except requests.RequestException as ex:
raise requests.RequestException(
f"An error occurred with the request. Please report this issue to Lightning Team \n{e}" # E501
f"An error occurred with the request. Please report this issue to Lightning Team \n{ex}" # E501
)

logger.info(

@@ -1354,5 +1354,5 @@ def get_unique_name():
>>> get_unique_name()
'truthful-dijkstra-2286'
"""
adjective, surname, i = choice(_adjectives), choice(_surnames), randint(0, 9999)
adjective, surname, i = choice(_adjectives), choice(_surnames), randint(0, 9999) # noqa: S311
return f"{adjective}-{surname}-{i}"

@@ -136,23 +136,23 @@ def _retry_wrapper(self, func: Callable) -> Callable:
while _get_next_backoff_time(consecutive_errors) != _DEFAULT_BACKOFF_MAX:
try:
return func(self, *args, **kwargs)
except lightning_cloud.openapi.rest.ApiException as e:
except lightning_cloud.openapi.rest.ApiException as ex:
# retry if the control plane fails with all errors except 4xx but not 408 - (Request Timeout)
if e.status == 408 or e.status == 409 or not str(e.status).startswith("4"):
if ex.status == 408 or ex.status == 409 or not str(ex.status).startswith("4"):
consecutive_errors += 1
backoff_time = _get_next_backoff_time(consecutive_errors)
logger.debug(
f"The {func.__name__} request failed to reach the server, got a response {e.status}."
f"The {func.__name__} request failed to reach the server, got a response {ex.status}."
f" Retrying after {backoff_time} seconds."
)
time.sleep(backoff_time)
else:
raise e
except urllib3.exceptions.HTTPError as e:
raise ex
except urllib3.exceptions.HTTPError as ex:
consecutive_errors += 1
backoff_time = _get_next_backoff_time(consecutive_errors)
logger.debug(
f"The {func.__name__} request failed to reach the server, got a an error {str(e)}."
f"The {func.__name__} request failed to reach the server, got a an error {str(ex)}."
f" Retrying after {backoff_time} seconds."
)
time.sleep(backoff_time)

@@ -35,8 +35,8 @@ def string2dict(text):
try:
js = json.loads(text, object_pairs_hook=_duplicate_checker)
return js
except ValueError as e:
raise ValueError(f"Unable to load JSON: {str(e)}.")
except ValueError as ex:
raise ValueError(f"Unable to load JSON: {str(ex)}.")


def is_openapi(obj):

@@ -49,7 +49,7 @@ def download_frontend(root: str = _PROJECT_ROOT):

shutil.rmtree(frontend_dir, ignore_errors=True)

response = urllib.request.urlopen(LIGHTNING_FRONTEND_RELEASE_URL)
response = urllib.request.urlopen(LIGHTNING_FRONTEND_RELEASE_URL) # noqa: S310

file = tarfile.open(fileobj=response, mode="r|gz")
file.extractall(path=download_dir)

@@ -297,10 +297,10 @@ class WorkStateObserver(Thread):
try:
with _state_observer_lock:
self._work.apply_flow_delta(Delta(deep_diff, raise_errors=True))
except Exception as e:
except Exception as ex:
print(traceback.print_exc())
self._error_queue.put(e)
raise e
self._error_queue.put(ex)
raise ex

def join(self, timeout: Optional[float] = None) -> None:
self._exit_event.set()

@@ -415,19 +415,19 @@ class WorkRunner:
self.state_observer.join(0)
self.state_observer = None
self.copier.join(0)
except LightningSigtermStateException as e:
except LightningSigtermStateException as ex:
logger.debug("Exiting")
os._exit(e.exit_code)
except Exception as e:
os._exit(ex.exit_code)
except Exception as ex:
# Inform the flow the work failed. This would fail the entire application.
self.error_queue.put(e)
self.error_queue.put(ex)
# Terminate the threads
if self.state_observer:
if self.state_observer.started:
self.state_observer.join(0)
self.state_observer = None
self.copier.join(0)
raise e
raise ex

def setup(self):
from lightning.app.utilities.state import AppState

@@ -494,7 +494,7 @@ class WorkRunner:
# Set the internal IP address.
# Set this here after the state observer is initialized, since it needs to record it as a change and send
# it back to the flow
default_internal_ip = "127.0.0.1" if constants.LIGHTNING_CLOUDSPACE_HOST is None else "0.0.0.0"
default_internal_ip = "127.0.0.1" if constants.LIGHTNING_CLOUDSPACE_HOST is None else "0.0.0.0" # noqa: S104
self.work._internal_ip = os.environ.get("LIGHTNING_NODE_IP", default_internal_ip)

# 8. Patch the setattr method of the work. This needs to be done after step 4, so we don't

@@ -525,9 +525,9 @@ class WorkRunner:
# If an exception is raised, send a `FAILED` status delta to the flow and call the `on_exception` hook.
try:
ret = self.run_executor_cls(self.work, work_run, self.delta_queue)(*args, **kwargs)
except LightningSigtermStateException as e:
raise e
except BaseException as e:
except LightningSigtermStateException as ex:
raise ex
except BaseException as ex:
# 10.2 Send failed delta to the flow.
reference_state = deepcopy(self.work.state)
exp, val, tb = sys.exc_info()

@@ -559,7 +559,7 @@ class WorkRunner:
id=self.work_name, delta=Delta(DeepDiff(reference_state, self.work.state, verbose_level=2))
)
)
self.work.on_exception(e)
self.work.on_exception(ex)
print("########## CAPTURED EXCEPTION ###########")
print(traceback.print_exc())
print("########## CAPTURED EXCEPTION ###########")

@@ -32,12 +32,9 @@ class SchedulerThread(threading.Thread):
self._app = app

def run(self) -> None:
try:
while not self._exit_event.is_set():
self._exit_event.wait(self._sleep_time)
self.run_once()
except Exception as e:
raise e
while not self._exit_event.is_set():
self._exit_event.wait(self._sleep_time)
self.run_once()

def run_once(self):
for call_hash in list(self._app._schedules.keys()):

@@ -166,8 +166,8 @@ class AppState:
try:
# TODO: Send the delta directly to the REST API.
response = self._session.post(app_url, json=data, headers=headers)
except ConnectionError as e:
raise AttributeError("Failed to connect and send the app state. Is the app running?") from e
except ConnectionError as ex:
raise AttributeError("Failed to connect and send the app state. Is the app running?") from ex

if response and response.status_code != 200:
raise Exception(f"The response from the server was {response.status_code}. Your inputs were rejected.")

@@ -186,8 +186,8 @@ class AppState:
sleep(0.5)
try:
response = self._session.get(app_url, headers=headers, timeout=1)
except ConnectionError as e:
raise AttributeError("Failed to connect and fetch the app state. Is the app running?") from e
except ConnectionError as ex:
raise AttributeError("Failed to connect and fetch the app state. Is the app running?") from ex

self._authorized = response.status_code
if self._authorized != 200:

@@ -331,9 +331,7 @@ def _device_count_nvml() -> int:
for idx, val in enumerate(visible_devices):
if cast(int, val) >= raw_cnt:
return idx
except OSError:
return -1
except AttributeError:
except (OSError, AttributeError):
return -1
return len(visible_devices)

@@ -200,8 +200,9 @@ class TensorBoardLogger(Logger):
self.experiment.add_scalar(k, v, step)
# TODO(fabric): specify the possible exception
except Exception as ex:
m = f"\n you tried to log {v} which is currently not supported. Try a dict or a scalar/tensor."
raise ValueError(m) from ex
raise ValueError(
f"\n you tried to log {v} which is currently not supported. Try a dict or a scalar/tensor."
) from ex

@rank_zero_only
def log_hyperparams(

@@ -218,10 +218,10 @@ def _dataloader_init_kwargs_resolve_sampler(
batch_size=batch_sampler.batch_size,
drop_last=batch_sampler.drop_last,
)
except TypeError as e:
except TypeError as ex:
import re

match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(e))
match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(ex))
if not match:
# an unexpected `TypeError`, continue failure
raise

@@ -232,7 +232,7 @@ def _dataloader_init_kwargs_resolve_sampler(
"We tried to re-instantiate your custom batch sampler and failed. "
"To mitigate this, either follow the API of `BatchSampler` or instantiate "
"your custom batch sampler inside `*_dataloader` hooks of your module."
) from e
) from ex

return {
"sampler": None,

@@ -257,12 +257,12 @@ def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optio

try:
result = constructor(*args, **kwargs)
except TypeError as e:
except TypeError as ex:
# improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass
# `__init__` arguments map to one `DataLoader.__init__` argument
import re

match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e))
match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(ex))
if not match:
# an unexpected `TypeError`, continue failure
raise

@@ -274,7 +274,7 @@ def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optio
f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
" This argument was automatically passed to your object by PyTorch Lightning."
)
raise MisconfigurationException(message) from e
raise MisconfigurationException(message) from ex

attrs_record = getattr(orig_object, "__pl_attrs_record", [])
for args, fn in attrs_record:

@@ -63,7 +63,7 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:


def _select_seed_randomly(min_seed_value: int = min_seed_value, max_seed_value: int = max_seed_value) -> int:
return random.randint(min_seed_value, max_seed_value)
return random.randint(min_seed_value, max_seed_value) # noqa: S311


def reset_seed() -> None:

@@ -1,3 +1,4 @@
import contextlib
import pickle
import warnings
from typing import Any, Callable

@@ -32,15 +33,12 @@ def compare_version(package: str, op: Callable, version: str, use_base_version:
# patching is necessary, since up to v.0.7.3 torchmetrics has a hardcoded reference to lightning.pytorch,
# which has to be redirected to the unified package:
# https://github.com/Lightning-AI/metrics/blob/v0.7.3/torchmetrics/metric.py#L96
try:
with contextlib.suppress(AttributeError):
if hasattr(torchmetrics.utilities.imports, "_compare_version"):
torchmetrics.utilities.imports._compare_version = compare_version # type: ignore
except AttributeError:
pass

try:
with contextlib.suppress(AttributeError):
if hasattr(torchmetrics.metric, "_compare_version"):
torchmetrics.metric._compare_version = compare_version
except AttributeError:
pass

pickle.Unpickler = RedirectingUnpickler # type: ignore

@@ -13,6 +13,7 @@
# limitations under the License.

import ast
import contextlib
import csv
import inspect
import logging

@@ -272,10 +273,8 @@ def load_hparams_from_yaml(config_yaml: _PATH, use_omegaconf: bool = True) -> Di
hparams = yaml.full_load(fp)

if _OMEGACONF_AVAILABLE and use_omegaconf:
try:
with contextlib.suppress(UnsupportedValueType, ValidationError):
return OmegaConf.create(hparams)
except (UnsupportedValueType, ValidationError):
pass
return hparams

@@ -98,7 +98,7 @@ class _MNIST(Dataset):
for url in self.RESOURCES:
logging.info(f"Downloading {url}")
fpath = os.path.join(data_folder, os.path.basename(url))
urllib.request.urlretrieve(url, fpath)
urllib.request.urlretrieve(url, fpath) # noqa: S310

@staticmethod
def _try_load(path_data: str, trials: int = 30, delta: float = 1.0) -> Tuple[Tensor, Tensor]:

@@ -112,7 +112,7 @@ class _MNIST(Dataset):
# todo: specify the possible exception
except Exception as ex:
exception = ex
time.sleep(delta * random.random())
time.sleep(delta * random.random()) # noqa: S311
else:
break
assert res is not None

@@ -135,8 +135,8 @@ def MNIST(*args: Any, **kwargs: Any) -> Dataset:
from torchvision.datasets import MNIST

MNIST(_DATASETS_PATH, download=True)
except HTTPError as e:
print(f"Error {e} downloading `torchvision.datasets.MNIST`")
except HTTPError as ex:
print(f"Error {ex} downloading `torchvision.datasets.MNIST`")
torchvision_mnist_available = False
if not torchvision_mnist_available:
print("`torchvision.datasets.MNIST` not available. Using our hosted version")

@@ -41,9 +41,9 @@ if _COMET_AVAILABLE:

try:
from comet_ml.api import API
except ModuleNotFoundError: # pragma: no-cover
except ModuleNotFoundError:
# For more information, see: https://www.comet.ml/docs/python-sdk/releases/#release-300
from comet_ml.papi import API # pragma: no-cover
from comet_ml.papi import API
else:
# needed for test mocks, these tests shall be updated
comet_ml = None

@@ -294,11 +294,9 @@ class NeptuneLogger(Logger):
args["run"] = self._run_short_id

# Backward compatibility in case of previous version retrieval
try:
with contextlib.suppress(AttributeError):
if self._run_name is not None:
args["name"] = self._run_name
except AttributeError:
pass

return args

@@ -53,9 +53,9 @@ class _DataFetcher(Iterator):
assert self.iterator is not None
try:
data = next(self.iterator)
except StopIteration as e:
except StopIteration as ex:
self.done = True
raise e
raise ex
finally:
self._stop_profiler()
self.fetched += 1

@@ -128,9 +128,9 @@ class _PrefetchDataFetcher(_DataFetcher):
self._fetch_next_batch(self.iterator)
# consume the batch we just fetched
batch = self.batches.pop(0)
except StopIteration as e:
except StopIteration as ex:
self.done = True
raise e
raise ex
else:
# the iterator is empty
raise StopIteration

@@ -41,8 +41,8 @@ class AsyncCheckpointIO(_WrappingCheckpointIO):
try:
assert self.checkpoint_io is not None
self.checkpoint_io.save_checkpoint(*args, **kwargs)
except BaseException as e:
self._error = e
except BaseException as ex:
self._error = ex

self._executor.submit(_save_checkpoint, *args, **kwargs)

@@ -1,3 +1,4 @@
import contextlib
import logging
import time
from multiprocessing import Process

@@ -99,11 +100,9 @@ class ServableModuleValidator(Callback):
ready = False
t0 = time.time()
while not ready:
try:
with contextlib.suppress(requests.exceptions.ConnectionError):
resp = requests.get(f"http://{self.host}:{self.port}/ping")
ready = resp.status_code == 200
except requests.exceptions.ConnectionError:
pass
if time.time() - t0 > self.timeout:
process.kill()
raise Exception(f"The server didn't start within {self.timeout} seconds.")

@@ -282,8 +282,8 @@ class FSDPStrategy(ParallelStrategy):
invalid_params_error = False
try:
super().setup_optimizers(trainer)
except ValueError as e:
if "optimizer got an empty parameter list" not in str(e):
except ValueError as ex:
if "optimizer got an empty parameter list" not in str(ex):
raise
invalid_params_error = True

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from collections.abc import Iterable
from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Tuple, Type, TypeVar, Union

@@ -151,11 +152,9 @@ class _MaxSize(_ModeIterator[List]):
out = [None] * n
all_exhausted = True
for i in range(n):
try:
with contextlib.suppress(StopIteration):
out[i] = next(self.iterators[i])
all_exhausted = False
except StopIteration:
pass
if all_exhausted:
raise StopIteration
return out

@@ -297,10 +297,10 @@ def _dataloader_init_kwargs_resolve_sampler(
batch_size=batch_sampler.batch_size,
drop_last=(False if is_predicting else batch_sampler.drop_last),
)
except TypeError as e:
except TypeError as ex:
import re

match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(e))
match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(ex))
if not match:
# an unexpected `TypeError`, continue failure
raise

@@ -311,7 +311,7 @@ def _dataloader_init_kwargs_resolve_sampler(
"We tried to re-instantiate your custom batch sampler and failed. "
"To mitigate this, either follow the API of `BatchSampler` or instantiate "
"your custom batch sampler inside `*_dataloader` hooks of your module."
) from e
) from ex

if is_predicting:
batch_sampler = _IndexBatchSamplerWrapper(batch_sampler)

@@ -113,12 +113,12 @@ def _RunIf(
if bf16_cuda:
try:
cond = not (torch.cuda.is_available() and torch.cuda.is_bf16_supported())
except (AssertionError, RuntimeError) as e:
except (AssertionError, RuntimeError) as ex:
# AssertionError: Torch not compiled with CUDA enabled
# RuntimeError: Found no NVIDIA driver on your system.
is_unrelated = "Found no NVIDIA driver" not in str(e) or "Torch not compiled with CUDA" not in str(e)
is_unrelated = "Found no NVIDIA driver" not in str(ex) or "Torch not compiled with CUDA" not in str(ex)
if is_unrelated:
raise e
raise ex
cond = True

conditions.append(cond)

@@ -1,3 +1,4 @@
import contextlib
import os
import shutil
import threading

@@ -38,7 +39,7 @@ def pytest_sessionfinish(session, exitstatus):
# TODO this isn't great. We should have each tests doing it's own cleanup
current_process = psutil.Process()
for child in current_process.children(recursive=True):
try:
with contextlib.suppress(psutil.NoSuchProcess):
params = child.as_dict() or {}
cmd_lines = params.get("cmdline", [])
# we shouldn't kill the resource tracker from multiprocessing. If we do,

@@ -46,8 +47,6 @@ def pytest_sessionfinish(session, exitstatus):
if cmd_lines and "resource_tracker" in cmd_lines[-1]:
continue
child.kill()
except psutil.NoSuchProcess:
pass

main_thread = threading.current_thread()
for t in threading.enumerate():

@@ -1,3 +1,4 @@
import contextlib
from time import sleep

import pytest

@@ -21,17 +22,12 @@ def validate_app_functionalities(app_page: "Page") -> None:
app_page: The UI page of the app to be validated.
"""
while True:
try:
with contextlib.suppress(playwright._impl._api_types.Error, playwright._impl._api_types.TimeoutError):
app_page.reload()
sleep(5)
app_label = app_page.frame_locator("iframe").locator("text=Choose your AI task")
app_label.wait_for(timeout=30 * 1000)
break
except (
playwright._impl._api_types.Error,
playwright._impl._api_types.TimeoutError,
):
pass

input_field = app_page.frame_locator("iframe").locator('input:below(:text("Data URL"))').first
input_field.wait_for(timeout=1000)

@@ -81,7 +81,7 @@ def test_list_all_apps_paginated(list_memberships: mock.MagicMock, list_instance
list_memberships.assert_called_once()
assert list_instances.mock_calls == [
mock.call(project_id="default-project", limit=100, phase_in=[]),
mock.call(project_id="default-project", page_token="page-2", limit=100, phase_in=[]),
mock.call(project_id="default-project", page_token="page-2", limit=100, phase_in=[]), # noqa: S106
]

@@ -43,8 +43,8 @@ def test_valid_unpublished_app_name():
subprocess.check_output(f"lightning install app {real_app}", shell=True, stderr=subprocess.STDOUT)
# this condition should never be hit
assert False
except subprocess.CalledProcessError as e:
assert "WARNING" in str(e.output)
except subprocess.CalledProcessError as ex:
assert "WARNING" in str(ex.output)

# assert aborted install
result = runner.invoke(lightning_cli.cmd_install.install_app, [real_app], input="q")
@@ -54,7 +54,7 @@ def test_client_server():

    secrets = [Secret(name="example", value="secret")]

-    general = _GeneralModel.from_obj(TestConfig(name="name", secrets=secrets), token="a")
+    general = _GeneralModel.from_obj(TestConfig(name="name", secrets=secrets), token="a")  # noqa: S106
    assert general.cls_name == "TestConfig"
    assert general.data == '{"id": null, "name": "name", "secrets": [{"name": "example", "value": "secret"}]}'

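For context on the `# noqa: S106` markers appearing in these test hunks: bandit's S105-S107 flag string literals passed to credential-looking names such as token or password. In tests the values are deliberate dummies, so the narrowest fix is a line-level suppression rather than disabling the rule project-wide. A hedged sketch with a hypothetical function:

def connect(host: str, token: str) -> str:  # hypothetical API, illustration only
    return f"{host}?auth={token}"


url = connect("https://example.com", token="not-a-real-secret")  # noqa: S106
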
@@ -1,3 +1,4 @@
+import contextlib
import os
import shutil
import signal

@@ -37,7 +38,7 @@ def pytest_sessionfinish(session, exitstatus):
    # TODO this isn't great. We should have each tests doing it's own cleanup
    current_process = psutil.Process()
    for child in current_process.children(recursive=True):
-        try:
+        with contextlib.suppress(psutil.NoSuchProcess):
            params = child.as_dict() or {}
            cmd_lines = params.get("cmdline", [])
            # we shouldn't kill the resource tracker from multiprocessing. If we do,

@@ -45,8 +46,6 @@ def pytest_sessionfinish(session, exitstatus):
            if cmd_lines and "resource_tracker" in cmd_lines[-1]:
                continue
            child.kill()
-        except psutil.NoSuchProcess:
-            pass

    main_thread = threading.current_thread()
    for t in threading.enumerate():

@@ -1,3 +1,4 @@
+import contextlib
import logging
import os
import pickle

@@ -555,12 +556,11 @@ class CheckpointLightningApp(LightningApp):


def test_snap_shotting():
-    try:
+    with contextlib.suppress(SuccessException):
        app = CheckpointLightningApp(FlowA())
        app.checkpointing = True
        MultiProcessRuntime(app, start_server=False).dispatch()
-    except SuccessException:
-        pass

    checkpoint_dir = os.path.join(_storage_root_dir(), "checkpoints")
    checkpoints = os.listdir(checkpoint_dir)
    assert len(checkpoints) == 1

@@ -1,3 +1,4 @@
+import contextlib
import os
import pickle
from collections import Counter

@@ -374,11 +375,9 @@ def test_lightning_flow_and_work():
        "changes": {},
    }
    assert flow_a.state == state
-    try:
+    with contextlib.suppress(ExitAppException):
        while True:
            flow_a.run()
-    except ExitAppException:
-        pass

    state = {
        "vars": {"counter": 5, "_layout": ANY, "_paths": {}},

@@ -24,7 +24,7 @@ def mock_plugin_server(mock_uvicorn) -> TestClient:

    mock_uvicorn.run.side_effect = create_test_client

-    _start_plugin_server("0.0.0.0", 8888)
+    _start_plugin_server("0.0.0.0", 8888)  # noqa: S104

    return test_client["client"]

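The `# noqa: S104` here silences bandit's "possible binding to all interfaces" warning: "0.0.0.0" exposes a server on every network interface, which is intentional for a local test server. A self-contained sketch of what the rule guards against:

import socket

HOST = "0.0.0.0"  # noqa: S104  # deliberate: test server may accept any interface

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind((HOST, 0))  # port 0: let the OS pick a free port
print("listening on", sock.getsockname())
sock.close()
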
@@ -53,7 +53,7 @@ class StartFrontendServersTestFlow(LightningFlow):
    "cloudspace_host, port, expected_host, expected_target",
    [
        (None, 7000, "localhost", "http://localhost:7000"),
-        ("test.lightning.ai", 7000, "0.0.0.0", "https://7000-test.lightning.ai"),
+        ("test.lightning.ai", 7000, "0.0.0.0", "https://7000-test.lightning.ai"),  # noqa: S104
    ],
)
@mock.patch("lightning.app.runners.multiprocess.find_free_network_port")

@@ -267,9 +267,9 @@ class WorkRunnerPatch(WorkRunner):
                ComponentDelta(id=self.work_name, delta=Delta(DeepDiff(state, self.work.state, verbose_level=2)))
            )
            self.counter += 1
-        except Exception as e:
+        except Exception as ex:
            logger.error(traceback.format_exc())
-            self.error_queue.put(e)
+            self.error_queue.put(ex)
            raise ExitAppException

@@ -435,7 +435,6 @@ def test_path_attributes_to_transfer(_, origin, exists_remote, expected_get):
        copy_request_queue=_MockQueue(),
        copy_response_queue=_MockQueue(),
    )

    with contextlib.suppress(ExitAppException):
        runner()

@@ -645,7 +644,7 @@ def test_state_observer():
    "patch_constants, environment, expected_ip_addr",
    [
        ({}, {}, "127.0.0.1"),
-        ({"LIGHTNING_CLOUDSPACE_HOST": "any"}, {}, "0.0.0.0"),
+        ({"LIGHTNING_CLOUDSPACE_HOST": "any"}, {}, "0.0.0.0"),  # noqa: S104
    ],
    indirect=["patch_constants"],

@@ -98,12 +98,12 @@ class RunIf:
    if bf16_cuda:
        try:
            cond = not (torch.cuda.is_available() and torch.cuda.is_bf16_supported())
-        except (AssertionError, RuntimeError) as e:
+        except (AssertionError, RuntimeError) as ex:
            # AssertionError: Torch not compiled with CUDA enabled
            # RuntimeError: Found no NVIDIA driver on your system.
-            is_unrelated = "Found no NVIDIA driver" not in str(e) or "Torch not compiled with CUDA" not in str(e)
+            is_unrelated = "Found no NVIDIA driver" not in str(ex) or "Torch not compiled with CUDA" not in str(ex)
            if is_unrelated:
-                raise e
+                raise ex
            cond = True

        conditions.append(cond)

@@ -77,13 +77,13 @@ class _MyFabricGradNorm(BoringFabric):
            try:
                super().run()
                break
-            except RuntimeError as e:
+            except RuntimeError as ex:
                # nonfinite grads -> skip and continue
                # this may repeat until the scaler finds a factor where overflow is avoided,
                # so the while loop should eventually break
                # stop after a max of 10 tries
-                if i > 10 or not str(e).startswith("The total norm"):
-                    raise e
+                if i > 10 or not str(ex).startswith("The total norm"):
+                    raise ex

                # unscale was already called by last attempt,
                # but no update afterwards since optimizer step was missing.

@@ -117,13 +117,13 @@ class _MyFabricGradVal(BoringFabric):
            try:
                super().run()
                break
-            except RuntimeError as e:
+            except RuntimeError as ex:
                # nonfinite grads -> skip and continue
                # this may repeat until the scaler finds a factor where overflow is avoided,
                # so the while loop should eventually break
                # stop after a max of 10 tries
-                if i > 10 or not str(e).startswith("Nonfinite grads"):
-                    raise e
+                if i > 10 or not str(ex).startswith("Nonfinite grads"):
+                    raise ex

                # unscale was already called by last attempt,
                # but no update afterwards since optimizer step was missing.

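Both fabric test classes above share the same bounded-retry shape: retry while the exception message matches an expected prefix, give up after ten tries or on anything unexpected. A generic, self-contained version of the pattern (names are ours, not Lightning's):

def run_with_retries(fn, expected_prefix: str, max_tries: int = 10):
    # Retry fn() while it fails with the expected message; re-raise otherwise.
    for i in range(max_tries + 1):
        try:
            return fn()
        except RuntimeError as ex:
            if i >= max_tries or not str(ex).startswith(expected_prefix):
                raise


attempts = {"n": 0}


def flaky():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise RuntimeError("The total norm is non-finite")  # simulated overflow
    return attempts["n"]


print(run_with_retries(flaky, "The total norm"))  # prints 3
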
@@ -103,7 +103,7 @@ class MNIST(Dataset):
        for url in self.RESOURCES:
            logging.info(f"Downloading {url}")
            fpath = os.path.join(data_folder, os.path.basename(url))
-            urllib.request.urlretrieve(url, fpath)
+            urllib.request.urlretrieve(url, fpath)  # noqa: S310

    @staticmethod
    def _try_load(path_data, trials: int = 30, delta: float = 1.0):

@@ -115,8 +115,8 @@ class MNIST(Dataset):
        try:
            res = torch.load(path_data)
        # todo: specify the possible exception
-        except Exception as e:
-            exception = e
+        except Exception as ex:
+            exception = ex
            time.sleep(delta * random.random())
        else:
            break

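The final hunk suppresses S310 on `urllib.request.urlretrieve`, which bandit flags because urllib follows `file://` and custom schemes as readily as HTTP. Suppression is defensible here since the MNIST `RESOURCES` URLs are hard-coded; when the URL is user-supplied, validating the scheme first is the usual alternative. A hedged sketch (the `fetch` helper is ours, not from this diff):

import urllib.parse
import urllib.request


def fetch(url: str, dest: str) -> None:
    # Restrict to http(s) before handing the URL to urllib (what S310 guards).
    scheme = urllib.parse.urlparse(url).scheme
    if scheme not in ("http", "https"):
        raise ValueError(f"refusing non-HTTP(S) URL: {url}")
    urllib.request.urlretrieve(url, dest)  # noqa: S310  # scheme validated above
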