# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import Mock, PropertyMock

import pytest
import torch
from lightning.fabric.plugins import DoublePrecision, HalfPrecision, Precision
from lightning.fabric.strategies import SingleDeviceStrategy
from lightning.fabric.utilities.types import _Stateful

from tests_fabric.helpers.runif import RunIf


@pytest.mark.parametrize("is_rank_zero", [True, False])
|
|
def test_save_checkpoint_rank_zero_only(is_rank_zero, tmp_path):
|
|
"""Test that the checkpoint only gets saved on global rank 0 in the base implementation in Strategy."""
|
|
strategy = SingleDeviceStrategy() # surrogate class to test implementation in base class
|
|
save_checkpoint_mock = Mock()
|
|
strategy.checkpoint_io.save_checkpoint = save_checkpoint_mock
|
|
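    # Patching `is_global_zero` lets this single-process test impersonate both rank 0 and a
    # non-zero rank; only rank 0 is expected to delegate to the underlying CheckpointIO.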
    with mock.patch(
        "lightning.fabric.strategies.single_device.SingleDeviceStrategy.is_global_zero",
        new_callable=PropertyMock(return_value=is_rank_zero),
    ):
        strategy.save_checkpoint(tmp_path, {"anything": 1})
    assert save_checkpoint_mock.call_count == int(is_rank_zero)


def test_save_checkpoint_empty_state(tmp_path):
    """Test that one can save an empty state with the base implementation in Strategy."""
    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class
    save_checkpoint_mock = Mock()
    strategy.checkpoint_io.save_checkpoint = save_checkpoint_mock

    state = {}
    strategy.save_checkpoint(tmp_path, state)
    save_checkpoint_mock.assert_called_with(checkpoint=state, path=tmp_path, storage_options=None)


def test_save_checkpoint_convert_stateful_objects(tmp_path):
    """Test that when modules and optimizers are at the top-level in the state, their `state_dict()` gets used."""
    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class
    save_checkpoint_mock = Mock()
    strategy.checkpoint_io.save_checkpoint = save_checkpoint_mock

    model = torch.nn.Linear(3, 3)
    optimizer = torch.optim.Adam(model.parameters())
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.99)

    anything = {"cocofruit": 1}
    state = {"model": model, "optimizer": optimizer, "scheduler": scheduler, "anything": anything}
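    # The stateful objects (module, optimizer, scheduler) should be collapsed to their
    # `state_dict()`, while plain values like `anything` pass through unchanged.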
    expected = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "anything": anything,
    }
    strategy.save_checkpoint(tmp_path, state)
    assert save_checkpoint_mock.call_args[1]["checkpoint"].keys() == expected.keys()
    saved_model_state = save_checkpoint_mock.call_args[1]["checkpoint"]["model"]
    assert all(torch.equal(p0, p1) for p0, p1 in zip(saved_model_state.values(), expected["model"].values()))
    assert save_checkpoint_mock.call_args[1]["checkpoint"]["optimizer"] == expected["optimizer"]
    assert save_checkpoint_mock.call_args[1]["checkpoint"]["scheduler"] == expected["scheduler"]
    assert save_checkpoint_mock.call_args[1]["checkpoint"]["anything"] == expected["anything"]


def test_save_load_stateful_objects(tmp_path):
    """Test that stateful objects other than modules and optimizers get converted and loaded correctly."""

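    # Note: `Fruit` never subclasses `_Stateful`; the `isinstance` check below can only pass
    # because `_Stateful` is (presumably) a runtime-checkable Protocol matched by duck typing.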
    class Fruit:
        count = 1

        def state_dict(self):
            return {"count": self.count}

        def load_state_dict(self, state_dict):
            self.count = state_dict["count"]

    state = Fruit()
    state.count = 100
    assert isinstance(state, _Stateful)
    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class
    strategy.save_checkpoint(tmp_path / "checkpoint.ckpt", {"state": state})
    state = Fruit()
    assert state.count == 1
    strategy.load_checkpoint(tmp_path / "checkpoint.ckpt", {"state": state})
    assert state.count == 100


def test_load_module_state_dict():
    """Test that `Strategy.load_module_state_dict()` calls `.load_state_dict()` on the module."""
    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class
    module = Mock()
    state_dict = Mock()
    strategy.load_module_state_dict(module, state_dict)
    module.load_state_dict.assert_called_with(state_dict, strict=True)
    strategy.load_module_state_dict(module, state_dict, strict=False)
    module.load_state_dict.assert_called_with(state_dict, strict=False)


def test_load_checkpoint_model_optimizer_from_raw_checkpoint(tmp_path):
    """Test that the `load_checkpoint` can load raw state dict checkpoints too."""
    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class

    model = torch.nn.Linear(3, 3)
    optimizer = torch.optim.Adam(model.parameters(), lr=1.0)
    torch.save(model.state_dict(), tmp_path / "model.ckpt")
    torch.save(optimizer.state_dict(), tmp_path / "optimizer.ckpt")

    new_model = torch.nn.Linear(3, 3)
    new_optimizer = torch.optim.Adam(new_model.parameters(), lr=2.0)

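    # The files above are bare `state_dict()` dumps with no wrapping key, so each object is
    # passed directly as `state` (rather than inside a dict) and restored from the raw checkpoint.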
strategy.load_checkpoint(tmp_path / "model.ckpt", state=new_model, strict=False)
|
|
assert torch.equal(new_model.weight, model.weight)
|
|
strategy.load_checkpoint(tmp_path / "optimizer.ckpt", state=new_optimizer, strict=False)
|
|
assert new_optimizer.state_dict()["param_groups"][0]["lr"] == 1.0
|
|
|
|
|
|
def test_load_checkpoint_out_of_place(tmp_path):
    """Test that one can load the full checkpoint into memory just like `torch.load()`."""
    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class
    load_checkpoint_mock = Mock()
    strategy.checkpoint_io.load_checkpoint = load_checkpoint_mock

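    # With no requested state (`None` or an empty dict), the whole checkpoint should come
    # back as the return value instead of being loaded into any objects.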
    checkpoint = strategy.load_checkpoint(tmp_path, state=None)
    assert checkpoint == load_checkpoint_mock()

    checkpoint = strategy.load_checkpoint(tmp_path, state={})
    assert checkpoint == load_checkpoint_mock()


def test_load_checkpoint_in_place(tmp_path):
    """Test that the object's state gets reloaded in-place."""
    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class

    # objects with initial state
    saved_model = torch.nn.Linear(2, 2)
    saved_optimizer = torch.optim.Adam(saved_model.parameters(), lr=0.1)
    saved_state = {"model": saved_model, "optimizer": saved_optimizer, "int": 1, "dict": {"cocofruit": 2}}
    strategy.save_checkpoint(tmp_path / "checkpoint", state=saved_state)

    # same objects with different state
    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.3)
    state = {"model": model, "optimizer": optimizer, "int": 10, "dict": {"cocofruit": 20}}
    assert not torch.equal(model.weight, saved_model.weight)
    assert optimizer.state_dict() != saved_optimizer.state_dict()

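    # `load_checkpoint` restores every requested key in place and returns only the checkpoint
    # entries that were not requested (nothing here, since all keys are requested).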
    remainder = strategy.load_checkpoint(tmp_path / "checkpoint", state)
    assert torch.equal(model.weight, saved_model.weight)
    assert optimizer.state_dict() == saved_optimizer.state_dict()
    assert state["int"] == saved_state["int"]
    assert state["dict"] == saved_state["dict"]
    assert not remainder

    # partial load - only model, no optimizer
    model = torch.nn.Linear(2, 2)
    state = {"model": model}
    remainder = strategy.load_checkpoint(tmp_path / "checkpoint", state)
    assert torch.equal(model.weight, saved_model.weight)
    assert list(remainder.keys()) == ["optimizer", "int", "dict"]


def test_load_checkpoint_strict_loading(tmp_path):
    """Test that an error is raised if a key is requested to be restored but does not exist in the checkpoint."""
    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class
    saved_state = {"a": 1, "b": 2}
    requested_state = {"a": 1, "b": 2, "c": 3}  # key `c` does not exist in the saved state
    load_checkpoint_mock = Mock(return_value=saved_state)
    strategy.checkpoint_io.load_checkpoint = load_checkpoint_mock
    with pytest.raises(KeyError, match="contains a key 'c' that does not exist"):
        strategy.load_checkpoint(tmp_path, requested_state, strict=True)


def test_load_checkpoint_non_strict_loading(tmp_path):
    """Test that no error is raised if `strict=False` and state is requested that does not exist in the checkpoint."""
    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class

    # objects with initial state
    saved_model = torch.nn.Linear(2, 2)
    saved_optimizer = torch.optim.Adam(saved_model.parameters(), lr=0.1)
    saved_state = {"model": saved_model, "optimizer": saved_optimizer, "int": 1, "str": "test"}
    strategy.save_checkpoint(tmp_path / "checkpoint.ckpt", state=saved_state)

    # same objects with different state
    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.3)
    state = {"model": model, "optimizer": optimizer, "int": 2, "new": "not_present_in_saved_state"}
    assert not torch.equal(model.weight, saved_model.weight)
    assert optimizer.state_dict() != saved_optimizer.state_dict()

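    # With `strict=False`: "new" (absent from the checkpoint) stays untouched in `state`,
    # while "str" (saved but not requested) comes back in the remainder.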
    remainder = strategy.load_checkpoint(tmp_path / "checkpoint.ckpt", state, strict=False)
    assert torch.equal(model.weight, saved_model.weight)
    assert optimizer.state_dict() == saved_optimizer.state_dict()
    assert state["int"] == saved_state["int"]
    assert "str" not in state
    assert "str" in remainder
    assert state["new"] == "not_present_in_saved_state"
    assert "new" not in remainder


@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("mps:0", marks=RunIf(mps=True)),
    ],
)
@pytest.mark.parametrize(
    ("precision", "dtype"),
    [
        (Precision(), torch.float32),
        (HalfPrecision("16-true"), torch.float16),
        pytest.param(HalfPrecision("bf16-true"), torch.bfloat16, marks=RunIf(mps=False)),
        pytest.param(DoublePrecision(), torch.float64, marks=RunIf(mps=False)),
    ],
)
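# The bf16 and float64 cases are marked RunIf(mps=False): MPS has no float64 support and,
# presumably at the time of writing, incomplete bfloat16 support.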
@pytest.mark.parametrize("empty_init", [None, True, False])
|
|
def test_module_init_context(device, precision, dtype, empty_init, monkeypatch):
|
|
"""Test that the module under the init-module-context gets moved to the right device and dtype."""
|
|
init_mock = Mock()
|
|
monkeypatch.setattr(torch.Tensor, "uniform_", init_mock)
|
|
|
|
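    # `torch.nn.Linear` initializes its parameters through `Tensor.uniform_` (via its
    # kaiming-uniform reset), so the mock records whether weight init actually ran;
    # with `empty_init=True` the context is expected to skip initialization entirely.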
    device = torch.device(device)
    strategy = SingleDeviceStrategy(device=device, precision=precision)  # surrogate class to test base class
    with strategy.module_init_context(empty_init=empty_init):
        module = torch.nn.Linear(2, 2)

    assert module.weight.device == module.bias.device == device
    assert module.weight.dtype == module.bias.dtype == dtype
    if not empty_init:
        init_mock.assert_called()
    else:
        init_mock.assert_not_called()


@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("mps:0", marks=RunIf(mps=True)),
    ],
)
@pytest.mark.parametrize(
    ("precision", "dtype"),
    [
        (Precision(), torch.float32),
        (HalfPrecision("16-true"), torch.float16),
        pytest.param(HalfPrecision("bf16-true"), torch.bfloat16, marks=RunIf(mps=False)),
        pytest.param(DoublePrecision(), torch.float64, marks=RunIf(mps=False)),
    ],
)
def test_tensor_init_context(device, precision, dtype):
    """Test that tensors under the init-tensor-context get moved to the right device and dtype."""
    device = torch.device(device)
    strategy = SingleDeviceStrategy(device=device, precision=precision)  # surrogate class to test base class
    with strategy.tensor_init_context():
        tensor0 = torch.tensor(42.0)
        tensor1 = torch.tensor(42)
        tensor2 = torch.tensor(42.0, dtype=torch.half)

    assert tensor0.device == tensor1.device == tensor2.device == device
    assert tensor0.dtype == dtype
    assert tensor1.dtype == torch.long  # `.init_tensor()` only affects floating point dtypes
    assert tensor2.dtype == torch.half  # this tensor was created with an explicit dtype assignment