[pre-commit.ci] pre-commit suggestions (#19229)
* [pre-commit.ci] pre-commit suggestions updates: - [github.com/pre-commit/pre-commit-hooks: v4.4.0 → v4.5.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.4.0...v4.5.0) - [github.com/asottile/pyupgrade: v3.14.0 → v3.15.0](https://github.com/asottile/pyupgrade/compare/v3.14.0...v3.15.0) - [github.com/astral-sh/ruff-pre-commit: v0.1.3 → v0.1.9](https://github.com/astral-sh/ruff-pre-commit/compare/v0.1.3...v0.1.9) - [github.com/psf/black: 23.9.1 → 23.12.1](https://github.com/psf/black/compare/23.9.1...23.12.1) - [github.com/pre-commit/mirrors-prettier: v3.0.3 → v4.0.0-alpha.8](https://github.com/pre-commit/mirrors-prettier/compare/v3.0.3...v4.0.0-alpha.8) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update .pre-commit-config.yaml * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * drop unused --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: Jirka <jirka.borovec@seznam.cz>
This commit is contained in:
parent
abf5aab6b5
commit
f120c91e9f
|
@ -234,7 +234,7 @@ def _download_frontend(pkg_path: str, version: str = "v0.0.0"):
|
|||
response = urllib.request.urlopen(frontend_release_url)
|
||||
|
||||
file = tarfile.open(fileobj=response, mode="r|gz")
|
||||
file.extractall(path=download_dir)
|
||||
file.extractall(path=download_dir) # noqa: S202
|
||||
|
||||
shutil.move(download_dir, frontend_dir)
|
||||
print("The Lightning UI has successfully been downloaded!")
|
||||
|
@ -468,7 +468,7 @@ class AssistantCLI:
|
|||
raise RuntimeError(f"Requesting file '{zip_url}' does not exist or it is just unavailable.")
|
||||
|
||||
with zipfile.ZipFile(zip_file, "r") as zip_ref:
|
||||
zip_ref.extractall(tmp)
|
||||
zip_ref.extractall(tmp) # noqa: S202
|
||||
|
||||
zip_dirs = [d for d in glob.glob(os.path.join(tmp, "*")) if os.path.isdir(d)]
|
||||
# check that the extracted archive has only repo folder
|
||||
|
|
|
@ -23,7 +23,7 @@ ci:
|
|||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
rev: v4.5.0
|
||||
hooks:
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
|
@ -51,7 +51,7 @@ repos:
|
|||
- id: detect-private-key
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.14.0
|
||||
rev: v3.15.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: ["--py38-plus"]
|
||||
|
@ -84,13 +84,13 @@ repos:
|
|||
- flake8-return
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: "v0.1.3"
|
||||
rev: "v0.1.9"
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: ["--fix", "--preview"]
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.9.1
|
||||
rev: 23.12.1
|
||||
hooks:
|
||||
- id: black
|
||||
name: Format code
|
||||
|
@ -121,7 +121,7 @@ repos:
|
|||
)$
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-prettier
|
||||
rev: v3.0.3
|
||||
rev: v3.1.0
|
||||
hooks:
|
||||
- id: prettier
|
||||
# https://prettier.io/docs/en/options.html#print-width
|
||||
|
|
|
@ -48,7 +48,7 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None:
|
|||
if ".zip" in local_filename:
|
||||
if os.path.exists(local_filename):
|
||||
with zipfile.ZipFile(local_filename, "r") as zip_ref:
|
||||
zip_ref.extractall(path)
|
||||
zip_ref.extractall(path) # noqa: S202
|
||||
elif local_filename.endswith(".tar.gz") or local_filename.endswith(".tgz"):
|
||||
extract_tarfile(local_filename, path, "r:gz")
|
||||
elif local_filename.endswith(".tar.bz2") or local_filename.endswith(".tbz"):
|
||||
|
|
|
@ -173,7 +173,7 @@ def test(
|
|||
action = agent.get_greedy_action(next_obs)
|
||||
|
||||
# Single environment step
|
||||
next_obs, reward, done, truncated, info = env.step(action.cpu().numpy())
|
||||
next_obs, reward, done, truncated, _ = env.step(action.cpu().numpy())
|
||||
done = done or truncated
|
||||
cumulative_rew += reward
|
||||
next_obs = torch.tensor(next_obs, device=device)
|
||||
|
|
|
@ -92,7 +92,7 @@ class LitClassifier(LightningModule):
|
|||
self.log("test_loss", loss)
|
||||
|
||||
def predict_step(self, batch, batch_idx, dataloader_idx=None):
|
||||
x, y = batch
|
||||
x, _ = batch
|
||||
return self(x)
|
||||
|
||||
def configure_optimizers(self):
|
||||
|
|
|
@ -107,6 +107,7 @@ ignore-init-module-imports = true
|
|||
"S113", # todo: Probable use of requests call without timeout
|
||||
"S311", # todo: Standard pseudo-random generators are not suitable for cryptographic purposes
|
||||
"S108", # todo: Probable insecure usage of temporary file or directory: "/tmp/sys-customizations-sync"
|
||||
"S202", # Uses of `tarfile.extractall()`
|
||||
"S602", # todo: `subprocess` call with `shell=True` identified, security issue
|
||||
"S603", # todo: `subprocess` call: check for execution of untrusted input
|
||||
"S605", # todo: Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell`
|
||||
|
|
|
@ -27,7 +27,7 @@ def app(app_name: str) -> None:
|
|||
app_name = _capture_valid_app_component_name(resource_type="app")
|
||||
|
||||
# generate resource template
|
||||
new_resource_name, name_for_files = _make_resource(resource_dir="app-template", resource_name=app_name)
|
||||
new_resource_name, _ = _make_resource(resource_dir="app-template", resource_name=app_name)
|
||||
|
||||
m = f"""
|
||||
⚡ Lightning app template created! ⚡
|
||||
|
|
|
@ -121,7 +121,7 @@ def download_frontend(destination: Path) -> None:
|
|||
with TemporaryDirectory() as download_dir:
|
||||
response = urllib.request.urlopen(url) # noqa: S310
|
||||
file = tarfile.open(fileobj=response, mode="r|gz")
|
||||
file.extractall(path=download_dir)
|
||||
file.extractall(path=download_dir) # noqa: S202
|
||||
shutil.move(str(Path(download_dir, build_dir_name)), destination)
|
||||
|
||||
|
||||
|
|
|
@ -154,7 +154,7 @@ def _run_plugin(run: _Run) -> Dict[str, Any]:
|
|||
logger.info("Extracting plugin source.")
|
||||
|
||||
with tarfile.open(download_path, "r:gz") as tf:
|
||||
tf.extractall(source_path)
|
||||
tf.extractall(source_path) # noqa: S202
|
||||
except Exception as ex:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
|
|
@ -334,7 +334,7 @@ class Drive:
|
|||
|
||||
|
||||
def _maybe_create_drive(component_name: str, state: Dict) -> Union[Dict, Drive]:
|
||||
if state.get("type", None) == Drive.__IDENTIFIER__:
|
||||
if state.get("type") == Drive.__IDENTIFIER__:
|
||||
drive = Drive.from_dict(state)
|
||||
drive.component_name = component_name
|
||||
return drive
|
||||
|
|
|
@ -156,7 +156,7 @@ class CloudCompute:
|
|||
f"mounts argument must be one of [None, Mount, List[Mount]], "
|
||||
f"received {mounts} of type {type(mounts)}"
|
||||
)
|
||||
_verify_mount_root_dirs_are_unique(d.get("mounts", None))
|
||||
_verify_mount_root_dirs_are_unique(d.get("mounts"))
|
||||
return cls(**d)
|
||||
|
||||
@property
|
||||
|
@ -183,6 +183,6 @@ def _verify_mount_root_dirs_are_unique(mounts: Union[None, Mount, List[Mount], T
|
|||
|
||||
|
||||
def _maybe_create_cloud_compute(state: Dict) -> Union[CloudCompute, Dict]:
|
||||
if state and state.get("type", None) == __CLOUD_COMPUTE_IDENTIFIER__:
|
||||
if state and state.get("type") == __CLOUD_COMPUTE_IDENTIFIER__:
|
||||
return CloudCompute.from_dict(state)
|
||||
return state
|
||||
|
|
|
@ -52,7 +52,7 @@ def download_frontend(root: str = _PROJECT_ROOT):
|
|||
response = urllib.request.urlopen(LIGHTNING_FRONTEND_RELEASE_URL) # noqa: S310
|
||||
|
||||
file = tarfile.open(fileobj=response, mode="r|gz")
|
||||
file.extractall(path=download_dir)
|
||||
file.extractall(path=download_dir) # noqa: S202
|
||||
|
||||
shutil.move(os.path.join(download_dir, build_dir), frontend_dir)
|
||||
print("The Lightning UI has successfully been downloaded!")
|
||||
|
|
|
@ -258,7 +258,7 @@ def _get_item_filesizes(items: List[Any], base_path: str = "") -> List[int]:
|
|||
"""Computes the total size in bytes of all file paths for every datastructure in the given list."""
|
||||
item_sizes = []
|
||||
for item in items:
|
||||
flattened_item, spec = tree_flatten(item)
|
||||
flattened_item, _ = tree_flatten(item)
|
||||
|
||||
num_bytes = 0
|
||||
for index, element in enumerate(flattened_item):
|
||||
|
|
|
@ -252,7 +252,7 @@ def _import_bitsandbytes() -> ModuleType:
|
|||
if int8params.has_fp16_weights:
|
||||
int8params.data = B
|
||||
else:
|
||||
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
|
||||
CB, CBt, SCB, SCBt, _ = bnb.functional.double_quant(B)
|
||||
del CBt
|
||||
del SCBt
|
||||
int8params.data = CB
|
||||
|
|
|
@ -163,10 +163,11 @@ def _basic_subprocess_cmd() -> Sequence[str]:
|
|||
|
||||
|
||||
def _hydra_subprocess_cmd(local_rank: int) -> Tuple[Sequence[str], str]:
|
||||
import __main__ # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from hydra.utils import get_original_cwd, to_absolute_path
|
||||
|
||||
import __main__ # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
|
||||
|
||||
# when user is using hydra find the absolute path
|
||||
if __main__.__spec__ is None: # pragma: no-cover
|
||||
command = [sys.executable, to_absolute_path(sys.argv[0])]
|
||||
|
|
|
@ -109,7 +109,7 @@ class _NotYetLoadedTensor:
|
|||
def _load_tensor(self) -> Tensor:
|
||||
from torch.storage import TypedStorage, UntypedStorage
|
||||
|
||||
name, storage_cls, fn, device, size = self.storageinfo
|
||||
_, _, fn, _, size = self.storageinfo
|
||||
dtype = self.metatensor.dtype
|
||||
|
||||
storage = self.archiveinfo.file_reader.get_storage_from_record(
|
||||
|
@ -182,7 +182,7 @@ class _LazyLoadingUnpickler(pickle.Unpickler):
|
|||
def persistent_load(self, pid: tuple) -> "TypedStorage":
|
||||
from torch.storage import TypedStorage
|
||||
|
||||
name, cls, fn, device, size = pid
|
||||
_, cls, _, _, _ = pid
|
||||
with warnings.catch_warnings():
|
||||
# The TypedStorage APIs have heavy deprecations in torch, suppress all these warnings for now
|
||||
warnings.simplefilter("ignore")
|
||||
|
|
|
@ -302,7 +302,7 @@ class ModelPruning(Callback):
|
|||
copy = deepcopy(copy) # keep the original parameters
|
||||
copy.reset_parameters()
|
||||
for i, name in names:
|
||||
new, new_name = self._parameters_to_prune[i]
|
||||
new, _ = self._parameters_to_prune[i]
|
||||
self._copy_param(new, copy, name)
|
||||
|
||||
def _apply_local_pruning(self, amount: float) -> None:
|
||||
|
|
|
@ -53,7 +53,7 @@ class Transformer(nn.Module):
|
|||
self.src_mask = None
|
||||
|
||||
def forward(self, inputs: Tensor, target: Tensor, mask: Optional[Tensor] = None) -> Tensor:
|
||||
b, t = inputs.shape
|
||||
_, t = inputs.shape
|
||||
|
||||
# we assume target is already shifted w.r.t. inputs
|
||||
if mask is None:
|
||||
|
|
|
@ -11,7 +11,7 @@ def _is_leaf_or_primitive_container(pytree: PyTree) -> bool:
|
|||
|
||||
node_type = _get_node_type(pytree)
|
||||
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
|
||||
child_pytrees, context = flatten_fn(pytree)
|
||||
child_pytrees, _ = flatten_fn(pytree)
|
||||
return all(isinstance(child, (int, float, str)) for child in child_pytrees)
|
||||
|
||||
|
||||
|
|
|
@ -69,5 +69,5 @@ def validate_app_functionalities(app_page: "Page") -> None:
|
|||
|
||||
@pytest.mark.cloud()
|
||||
def test_app_cloud() -> None:
|
||||
with run_app_in_cloud(_PATH_INTEGRATIONS_DIR) as (admin_page, view_page, fetch_logs, _):
|
||||
with run_app_in_cloud(_PATH_INTEGRATIONS_DIR) as (_, view_page, _, _):
|
||||
validate_app_functionalities(view_page)
|
||||
|
|
|
@ -68,7 +68,7 @@ def test_v0_app_example_byoc_cloud() -> None:
|
|||
with run_app_in_cloud(
|
||||
os.path.join(_PATH_EXAMPLES, "v0"),
|
||||
extra_args=["--cluster-id", os.environ.get("LIGHTNING_BYOC_CLUSTER_ID")],
|
||||
) as (_, view_page, fetch_logs, app_name):
|
||||
) as (_, view_page, fetch_logs, _):
|
||||
run_v0_app(fetch_logs, view_page)
|
||||
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ def test_fabric_run_executor_mps_forced_cpu(accelerator_given, accelerator_expec
|
|||
warning_context = no_warning_call(match=warning_str + "*")
|
||||
|
||||
with warning_context:
|
||||
ret_val, env_vars = _get_args_after_tracer_injection(accelerator=accelerator_given)
|
||||
ret_val, _ = _get_args_after_tracer_injection(accelerator=accelerator_given)
|
||||
assert ret_val["accelerator"] == accelerator_expected
|
||||
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ def test_trainer_run_executor_mps_forced_cpu(accelerator_given, accelerator_expe
|
|||
warning_context = no_warning_call(match=warning_str + "*")
|
||||
|
||||
with warning_context:
|
||||
ret_val, env_vars = _get_args_after_tracer_injection(accelerator=accelerator_given)
|
||||
ret_val, _ = _get_args_after_tracer_injection(accelerator=accelerator_given)
|
||||
assert ret_val["accelerator"] == accelerator_expected
|
||||
|
||||
|
||||
|
|
|
@ -1115,7 +1115,7 @@ class FlowWrapper(LightningFlow):
|
|||
def test_cloud_compute_binding():
|
||||
cloud_compute.ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER = True
|
||||
|
||||
assert {} == cloud_compute._CLOUD_COMPUTE_STORE
|
||||
assert cloud_compute._CLOUD_COMPUTE_STORE == {}
|
||||
flow = FlowCC()
|
||||
assert len(cloud_compute._CLOUD_COMPUTE_STORE) == 2
|
||||
assert cloud_compute._CLOUD_COMPUTE_STORE["default"].component_names == ["root.work_c"]
|
||||
|
|
|
@ -550,14 +550,14 @@ class TestAppCreationClient:
|
|||
|
||||
# testing with no-cache False
|
||||
cloud_runtime.dispatch(no_cache=False)
|
||||
(func_name, args, kwargs) = cloud_runtime.backend.client.cloud_space_service_create_lightning_run.mock_calls[0]
|
||||
_, _, kwargs = cloud_runtime.backend.client.cloud_space_service_create_lightning_run.mock_calls[0]
|
||||
body = kwargs["body"]
|
||||
assert body.dependency_cache_key == "dummy-hash"
|
||||
|
||||
# testing with no-cache True
|
||||
mock_client.reset_mock()
|
||||
cloud_runtime.dispatch(no_cache=True)
|
||||
(func_name, args, kwargs) = cloud_runtime.backend.client.cloud_space_service_create_lightning_run.mock_calls[0]
|
||||
_, _, kwargs = cloud_runtime.backend.client.cloud_space_service_create_lightning_run.mock_calls[0]
|
||||
body = kwargs["body"]
|
||||
assert body.dependency_cache_key is None
|
||||
|
||||
|
|
|
@ -560,7 +560,7 @@ def test_dataloader_kwargs_replacement_with_array_default_comparison():
|
|||
self.indices = np.random.rand(2, 2) # an attribute we can't compare with ==
|
||||
|
||||
dataloader = ArrayAttributeDataloader(dataset)
|
||||
dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler)
|
||||
_, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler)
|
||||
assert dl_kwargs["indices"] is dataloader.indices
|
||||
|
||||
|
||||
|
|
|
@ -117,17 +117,17 @@ def test_rich_progress_bar_custom_theme():
|
|||
progress_bar.on_train_start(Trainer(), BoringModel())
|
||||
|
||||
assert progress_bar.theme == theme
|
||||
args, kwargs = mocks["CustomBarColumn"].call_args
|
||||
_, kwargs = mocks["CustomBarColumn"].call_args
|
||||
assert kwargs["complete_style"] == theme.progress_bar
|
||||
assert kwargs["finished_style"] == theme.progress_bar_finished
|
||||
|
||||
args, kwargs = mocks["BatchesProcessedColumn"].call_args
|
||||
_, kwargs = mocks["BatchesProcessedColumn"].call_args
|
||||
assert kwargs["style"] == theme.batch_progress
|
||||
|
||||
args, kwargs = mocks["CustomTimeColumn"].call_args
|
||||
_, kwargs = mocks["CustomTimeColumn"].call_args
|
||||
assert kwargs["style"] == theme.time
|
||||
|
||||
args, kwargs = mocks["ProcessingSpeedColumn"].call_args
|
||||
_, kwargs = mocks["ProcessingSpeedColumn"].call_args
|
||||
assert kwargs["style"] == theme.processing_speed
|
||||
|
||||
|
||||
|
|
|
@ -61,5 +61,5 @@ def test_rich_summary_tuples(mock_table_add_row, mock_console):
|
|||
# ensure that summary was logged + the breakdown of model parameters
|
||||
assert mock_console.call_count == 2
|
||||
# assert that the input summary data was converted correctly
|
||||
args, kwargs = mock_table_add_row.call_args_list[0]
|
||||
args, _ = mock_table_add_row.call_args_list[0]
|
||||
assert args[1:] == ("0", "layer", "Linear", "66 ", "[4, 32]", "[4, 2]")
|
||||
|
|
|
@ -171,7 +171,7 @@ class ParityModuleRNN(LightningModule):
|
|||
self._loss = [] # needed for checking if the loss is the same as vanilla torch
|
||||
|
||||
def forward(self, x):
|
||||
seq, last = self.rnn(x)
|
||||
seq, _ = self.rnn(x)
|
||||
return self.linear_out(seq)
|
||||
|
||||
def training_step(self, batch, batch_nb):
|
||||
|
|
|
@ -179,9 +179,9 @@ def test_fetching_dataloader_iter_opt(automatic_optimization, tmpdir):
|
|||
def training_step(self, dataloader_iter):
|
||||
assert isinstance(self.trainer.fit_loop._data_fetcher, _DataLoaderIterDataFetcher)
|
||||
# fetch 2 batches
|
||||
batch, batch_idx, dataloader_idx = next(dataloader_iter)
|
||||
batch, batch_idx, _ = next(dataloader_iter)
|
||||
self.batches.append(batch)
|
||||
batch, batch_idx, dataloader_idx = next(dataloader_iter)
|
||||
batch, batch_idx, _ = next(dataloader_iter)
|
||||
self.batches.append(batch)
|
||||
|
||||
batch = self.batches.pop(0)
|
||||
|
@ -216,7 +216,7 @@ def test_fetching_dataloader_iter_running_stages(fn, tmp_path):
|
|||
class TestModel(BoringModel):
|
||||
def fetch(self, data_fetcher, dataloader_iter):
|
||||
assert isinstance(data_fetcher, _DataLoaderIterDataFetcher)
|
||||
batch, batch_idx, dataloader_idx = next(dataloader_iter)
|
||||
batch, batch_idx, _ = next(dataloader_iter)
|
||||
assert data_fetcher.fetched == batch_idx + 1
|
||||
return batch
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ def test_fabric_boring_lightning_module_manual():
|
|||
optimizers, _ = module.configure_optimizers()
|
||||
dataloader = module.train_dataloader()
|
||||
|
||||
model, optimizer = fabric.setup(module, optimizers[0])
|
||||
model, _ = fabric.setup(module, optimizers[0])
|
||||
dataloader = fabric.setup_dataloaders(dataloader)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
|
|
|
@ -111,8 +111,8 @@ def test_verbose_param(tmpdir, capsys):
|
|||
|
||||
with patch("torch.onnx.log", autospec=True) as test:
|
||||
model.to_onnx(file_path, verbose=True)
|
||||
args, kwargs = test.call_args
|
||||
prefix, graph = args
|
||||
args, _ = test.call_args
|
||||
prefix, _ = args
|
||||
assert prefix == "Exported graph: "
|
||||
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ class RandomFloatIntDataset(Dataset):
|
|||
|
||||
class DoublePrecisionBoringModel(BoringModel):
|
||||
def training_step(self, batch, batch_idx):
|
||||
float_data, int_data = batch
|
||||
float_data, _ = batch
|
||||
assert torch.tensor([0.0]).dtype == torch.float64
|
||||
assert torch.tensor([0.0], dtype=torch.float16).dtype == torch.float16
|
||||
assert float_data.dtype == torch.float64
|
||||
|
|
|
@ -169,7 +169,7 @@ def test_amp_skip_optimizer(tmpdir):
|
|||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
opt1, opt2 = self.optimizers()
|
||||
_, opt2 = self.optimizers()
|
||||
output = self(batch)
|
||||
loss = self.loss(output)
|
||||
opt2.zero_grad()
|
||||
|
|
|
@ -114,7 +114,7 @@ class TestBoringModel(BoringModel):
|
|||
|
||||
self.save_hyperparameters()
|
||||
self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
|
||||
self.should_be_wrapped = [(32 * 32 + 32) > wrap_min_params, None, (32 * 2 + 2) > wrap_min_params]
|
||||
self.should_be_wrapped = [wrap_min_params < (32 * 32 + 32), None, wrap_min_params < (32 * 2 + 2)]
|
||||
|
||||
def configure_optimizers(self):
|
||||
parameters = self.parameters() if _TORCH_GREATER_EQUAL_2_0 else self.trainer.model.parameters()
|
||||
|
|
|
@ -851,7 +851,7 @@ def test_native_print_results_encodings(monkeypatch, encoding):
|
|||
|
||||
# Attempt to encode everything the file is told to write with the given encoding
|
||||
for call_ in out.method_calls:
|
||||
name, args, kwargs = call_
|
||||
name, args, _ = call_
|
||||
if name == "write":
|
||||
args[0].encode(encoding)
|
||||
|
||||
|
|
|
@ -610,12 +610,12 @@ class TesManualOptimizationDDPModel(BoringModel):
|
|||
assert torch.equal(self.layer.weight.grad, grad_clone)
|
||||
|
||||
def gen_closure():
|
||||
loss_ones_gen, loss_zeros = compute_loss()
|
||||
loss_ones_gen, _ = compute_loss()
|
||||
make_manual_backward(loss_ones_gen, retain_graph=True, make_optimizer_step=make_gen_optimizer_step)
|
||||
make_manual_backward(loss_ones_gen, make_optimizer_step=make_gen_optimizer_step)
|
||||
|
||||
def dis_closure():
|
||||
loss_ones_gen, loss_zeros = compute_loss()
|
||||
loss_ones_gen, _ = compute_loss()
|
||||
make_manual_backward(loss_ones_gen, retain_graph=True, make_optimizer_step=make_dis_optimizer_step)
|
||||
make_manual_backward(loss_ones_gen, make_optimizer_step=make_dis_optimizer_step)
|
||||
|
||||
|
@ -719,12 +719,12 @@ class TestManualOptimizationDDPModelToggleModel(TesManualOptimizationDDPModel):
|
|||
assert torch.equal(self.layer.weight.grad, grad_clone)
|
||||
|
||||
def gen_closure():
|
||||
loss_ones_gen, loss_zeros = compute_loss()
|
||||
loss_ones_gen, _ = compute_loss()
|
||||
make_manual_backward(loss_ones_gen, retain_graph=True, make_optimizer_step=make_gen_optimizer_step)
|
||||
make_manual_backward(loss_ones_gen, make_optimizer_step=make_gen_optimizer_step)
|
||||
|
||||
def dis_closure():
|
||||
loss_ones_gen, loss_zeros = compute_loss()
|
||||
loss_ones_gen, _ = compute_loss()
|
||||
make_manual_backward(loss_ones_gen, retain_graph=True, make_optimizer_step=make_dis_optimizer_step)
|
||||
make_manual_backward(loss_ones_gen, make_optimizer_step=make_dis_optimizer_step)
|
||||
|
||||
|
|
|
@ -324,5 +324,5 @@ def test_dataloader_kwargs_replacement_with_array_default_comparison():
|
|||
self.indices = np.random.rand(2, 2) # an attribute we can't compare with ==
|
||||
|
||||
dataloader = ArrayAttributeDataloader(dataset)
|
||||
dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler)
|
||||
_, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler)
|
||||
assert dl_kwargs["indices"] is dataloader.indices
|
||||
|
|
Loading…
Reference in New Issue