2022-02-18 08:36:07 +00:00
|
|
|
import collections
|
2023-03-10 15:37:18 +00:00
|
|
|
import dataclasses
|
2022-02-18 08:36:07 +00:00
|
|
|
|
|
|
|
import torch
|
2023-02-01 20:34:38 +00:00
|
|
|
from lightning.fabric.utilities.optimizer import _optimizer_to_device
|
ruff: replace isort with ruff +TPU (#17684)
* ruff: replace isort with ruff
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fixing & imports
* lines in warning test
* docs
* fix enum import
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fixing
* import
* fix lines
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* type ClusterEnvironment
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2023-09-26 15:54:55 +00:00
|
|
|
from torch import Tensor
|
2022-02-18 08:36:07 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_optimizer_to_device():
    """Check optimizer state placement after calling ``_optimizer_to_device``.

    The optimizer's ``state`` is seeded with a plain tensor (``"dummy"``) and with
    a frozen dataclass instance that holds no tensor (``"frozen"``). Presumably the
    frozen entry guards against the move helper trying to mutate immutable,
    non-tensor state (TODO confirm intent). The optimizer is moved to CPU, and to
    CUDA when available, and tensor placement is asserted after each move.
    """

    @dataclasses.dataclass(frozen=True)
    class FooState:
        bar: int

    class TestOptimizer(torch.optim.SGD):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.state["dummy"] = torch.tensor(0)
            self.state["frozen"] = FooState(0)

    layer = torch.nn.Linear(32, 2)
    opt = TestOptimizer(layer.parameters(), lr=0.1)

    _optimizer_to_device(opt, "cpu")
    # Verify the CPU move as well, not only the CUDA one.
    assert_opt_parameters_on_device(opt, "cpu")

    if torch.cuda.is_available():
        _optimizer_to_device(opt, "cuda")
        assert_opt_parameters_on_device(opt, "cuda")
|
|
|
|
|
|
|
|
|
|
|
|
def assert_opt_parameters_on_device(opt, device: str):
    """Assert that the tensors found in ``opt.state`` live on ``device``.

    Values of ``opt.state`` that are tensors are checked directly. Values that are
    mappings (nested state) have each of their tensor entries checked. All other
    values, and non-tensor entries inside mappings, are ignored.

    Args:
        opt: Object exposing a ``state`` mapping (e.g. a ``torch.optim.Optimizer``).
        device: Expected device *type*, compared against ``tensor.device.type``
            (e.g. ``"cpu"`` or ``"cuda"``).

    Raises:
        AssertionError: If a checked tensor is on a different device type.
    """
    for state_value in opt.state.values():
        # Not sure there are any global tensors in the state dict
        if isinstance(state_value, Tensor):
            assert state_value.data.device.type == device
        elif isinstance(state_value, collections.abc.Mapping):
            for entry in state_value.values():
                if isinstance(entry, Tensor):
                    # Check the nested tensor itself, not the enclosing mapping
                    # (a mapping has no ``.data`` attribute).
                    assert entry.data.device.type == device
|