lightning/tests/tests_fabric/utilities/test_optimizer.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

37 lines
1.2 KiB
Python
Raw Normal View History

import collections
import dataclasses
import torch
from lightning.fabric.utilities.optimizer import _optimizer_to_device
ruff: replace isort with ruff +TPU (#17684) * ruff: replace isort with ruff * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixing & imports * lines in warning test * docs * fix enum import * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixing * import * fix lines * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * type ClusterEnvironment * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2023-09-26 15:54:55 +00:00
from torch import Tensor
def test_optimizer_to_device():
    """Test that `_optimizer_to_device` moves all tensors in the optimizer state to the requested device.

    The optimizer state is seeded with a loose tensor and a frozen (non-tensor) dataclass entry to
    verify that tensor entries move while non-tensor entries are tolerated.
    """

    @dataclasses.dataclass(frozen=True)
    class FooState:
        bar: int

    class TestOptimizer(torch.optim.SGD):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Seed the state with a top-level tensor and a non-tensor entry to exercise both paths.
            self.state["dummy"] = torch.tensor(0)
            self.state["frozen"] = FooState(0)

    layer = torch.nn.Linear(32, 2)
    opt = TestOptimizer(layer.parameters(), lr=0.1)

    _optimizer_to_device(opt, "cpu")
    # Fix: the original never asserted the CPU result — the move was exercised but unchecked.
    assert_opt_parameters_on_device(opt, "cpu")

    if torch.cuda.is_available():
        _optimizer_to_device(opt, "cuda")
        assert_opt_parameters_on_device(opt, "cuda")
def assert_opt_parameters_on_device(opt, device: str):
    """Assert that every tensor in ``opt.state`` lives on the device with type ``device``.

    Handles both top-level tensor entries and the usual per-parameter mapping entries
    (e.g. ``{"momentum_buffer": Tensor}``); non-tensor entries are ignored.
    """
    for param in opt.state.values():
        # Not sure there are any global tensors in the state dict
        if isinstance(param, Tensor):
            assert param.data.device.type == device
        elif isinstance(param, collections.abc.Mapping):
            for subparam in param.values():
                if isinstance(subparam, Tensor):
                    # Fix: check the sub-parameter's device, not the enclosing mapping
                    # (the original asserted `param.data`, which raises AttributeError on a dict).
                    assert subparam.data.device.type == device