# lightning/tests/tests_fabric/strategies/test_strategy.py
#
# 279 lines
# 12 KiB
# Python

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import Mock, PropertyMock
import pytest
import torch
from lightning.fabric.plugins import DoublePrecision, HalfPrecision, Precision
from lightning.fabric.strategies import SingleDeviceStrategy
from lightning.fabric.utilities.types import _Stateful
from tests_fabric.helpers.runif import RunIf
@pytest.mark.parametrize("is_rank_zero", [True, False])
def test_save_checkpoint_rank_zero_only(is_rank_zero, tmp_path):
    """Check that the base `Strategy.save_checkpoint` only writes the checkpoint on global rank 0."""
    # `SingleDeviceStrategy` stands in for the base class whose implementation is exercised here
    strategy = SingleDeviceStrategy()
    io_save = Mock()
    strategy.checkpoint_io.save_checkpoint = io_save

    # Override `is_global_zero` on the class for the duration of the save call only
    rank_patch = mock.patch(
        "lightning.fabric.strategies.single_device.SingleDeviceStrategy.is_global_zero",
        new_callable=PropertyMock(return_value=is_rank_zero),
    )
    with rank_patch:
        strategy.save_checkpoint(tmp_path, {"anything": 1})

    # Exactly one write on rank zero, none otherwise
    expected_calls = 1 if is_rank_zero else 0
    assert io_save.call_count == expected_calls
def test_save_checkpoint_empty_state(tmp_path):
    """Check that the base `Strategy.save_checkpoint` accepts an empty state dictionary."""
    # `SingleDeviceStrategy` stands in for the base class whose implementation is exercised here
    strategy = SingleDeviceStrategy()
    io_save = Mock()
    strategy.checkpoint_io.save_checkpoint = io_save

    empty_state = {}
    strategy.save_checkpoint(tmp_path, empty_state)

    # The empty dict must still be handed to the checkpoint IO, with no storage options
    io_save.assert_called_with(checkpoint=empty_state, path=tmp_path, storage_options=None)
def test_save_checkpoint_convert_stateful_objects(tmp_path):
    """Test that when modules and optimizers are at the top-level in the state, their `state_dict()` gets used.

    The checkpoint IO's `save_checkpoint` is replaced by a mock, so the converted checkpoint is read back from the
    keyword arguments of the recorded call instead of from disk.

    """
    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class
    save_checkpoint_mock = Mock()
    strategy.checkpoint_io.save_checkpoint = save_checkpoint_mock

    model = torch.nn.Linear(3, 3)
    optimizer = torch.optim.Adam(model.parameters())
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.99)
    anything = {"cocofruit": 1}
    state = {"model": model, "optimizer": optimizer, "scheduler": scheduler, "anything": anything}
    expected = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "anything": anything,
    }

    strategy.save_checkpoint(tmp_path, state)

    # Fail with a clear message if nothing was saved, before unpacking the recorded call
    save_checkpoint_mock.assert_called_once()
    _, call_kwargs = save_checkpoint_mock.call_args
    saved = call_kwargs["checkpoint"]
    assert saved.keys() == expected.keys()

    # Compare the tensors key by key: pairing `.values()` with `zip()` stops at the shorter mapping, so an empty or
    # truncated saved model state would have passed the check without comparing anything.
    saved_model_state = saved["model"]
    assert saved_model_state.keys() == expected["model"].keys()
    for name, expected_tensor in expected["model"].items():
        assert torch.equal(saved_model_state[name], expected_tensor), name

    assert saved["optimizer"] == expected["optimizer"]
    assert saved["scheduler"] == expected["scheduler"]
    assert saved["anything"] == expected["anything"]
def test_save_load_stateful_objects(tmp_path):
    """Check that a stateful object that is neither a module nor an optimizer round-trips through save and load."""

    class Fruit:
        # Class-level default; instances override it through plain attribute assignment
        count = 1

        def state_dict(self):
            return {"count": self.count}

        def load_state_dict(self, state_dict):
            self.count = state_dict["count"]

    checkpoint_path = tmp_path / "checkpoint.ckpt"

    saved_fruit = Fruit()
    saved_fruit.count = 100
    # The local class must be recognized as stateful for the conversion to apply
    assert isinstance(saved_fruit, _Stateful)

    # `SingleDeviceStrategy` stands in for the base class whose implementation is exercised here
    strategy = SingleDeviceStrategy()
    strategy.save_checkpoint(checkpoint_path, {"state": saved_fruit})

    # A fresh instance starts from the class default and is updated by the load
    fresh_fruit = Fruit()
    assert fresh_fruit.count == 1
    strategy.load_checkpoint(checkpoint_path, {"state": fresh_fruit})
    assert fresh_fruit.count == 100
def test_load_module_state_dict():
    """Check that `Strategy.load_module_state_dict()` forwards to the module's `.load_state_dict()`."""
    # `SingleDeviceStrategy` stands in for the base class whose implementation is exercised here
    strategy = SingleDeviceStrategy()
    fake_module, fake_state_dict = Mock(), Mock()

    # Without an explicit `strict`, the module is expected to be called with `strict=True`
    strategy.load_module_state_dict(fake_module, fake_state_dict)
    fake_module.load_state_dict.assert_called_with(fake_state_dict, strict=True)

    # An explicit `strict=False` must be passed through unchanged
    strategy.load_module_state_dict(fake_module, fake_state_dict, strict=False)
    fake_module.load_state_dict.assert_called_with(fake_state_dict, strict=False)
def test_load_checkpoint_model_optimizer_from_raw_checkpoint(tmp_path):
    """Check that `load_checkpoint` also handles files that contain a bare state dict."""
    # `SingleDeviceStrategy` stands in for the base class whose implementation is exercised here
    strategy = SingleDeviceStrategy()
    model_file = tmp_path / "model.ckpt"
    optimizer_file = tmp_path / "optimizer.ckpt"

    # Write the raw state dicts directly with `torch.save`, bypassing the strategy
    source_model = torch.nn.Linear(3, 3)
    source_optimizer = torch.optim.Adam(source_model.parameters(), lr=1.0)
    torch.save(source_model.state_dict(), model_file)
    torch.save(source_optimizer.state_dict(), optimizer_file)

    # Targets start with different weights and a different learning rate
    new_model = torch.nn.Linear(3, 3)
    new_optimizer = torch.optim.Adam(new_model.parameters(), lr=2.0)

    strategy.load_checkpoint(model_file, state=new_model, strict=False)
    assert torch.equal(new_model.weight, source_model.weight)

    strategy.load_checkpoint(optimizer_file, state=new_optimizer, strict=False)
    restored_lr = new_optimizer.state_dict()["param_groups"][0]["lr"]
    assert restored_lr == 1.0
def test_load_checkpoint_out_of_place(tmp_path):
    """Test that one can load the full checkpoint into memory just like `torch.load()`.

    With no state to restore into (`None` or an empty dict), `load_checkpoint` is expected to return whatever the
    checkpoint IO loaded.

    """
    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class
    load_checkpoint_mock = Mock()
    strategy.checkpoint_io.load_checkpoint = load_checkpoint_mock

    # Read the expected object from `return_value` instead of calling the mock inside the assertion: calling it
    # would add to the mock's call record and hide whether the strategy itself delegated to the checkpoint IO.
    loaded = load_checkpoint_mock.return_value

    checkpoint = strategy.load_checkpoint(tmp_path, state=None)
    assert checkpoint is loaded

    checkpoint = strategy.load_checkpoint(tmp_path, state={})
    assert checkpoint is loaded

    # One delegation to the checkpoint IO per `load_checkpoint` call
    assert load_checkpoint_mock.call_count == 2
def test_load_checkpoint_in_place(tmp_path):
    """Check that loading into a state dictionary restores the objects it holds in-place."""
    # `SingleDeviceStrategy` stands in for the base class whose implementation is exercised here
    strategy = SingleDeviceStrategy()
    checkpoint_path = tmp_path / "checkpoint"

    # Reference objects whose state gets written to disk
    saved_model = torch.nn.Linear(2, 2)
    saved_optimizer = torch.optim.Adam(saved_model.parameters(), lr=0.1)
    saved_state = {"model": saved_model, "optimizer": saved_optimizer, "int": 1, "dict": {"cocofruit": 2}}
    strategy.save_checkpoint(checkpoint_path, state=saved_state)

    # Freshly built counterparts that must differ from the saved ones before loading
    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.3)
    state = {"model": model, "optimizer": optimizer, "int": 10, "dict": {"cocofruit": 20}}
    assert not torch.equal(model.weight, saved_model.weight)
    assert optimizer.state_dict() != saved_optimizer.state_dict()

    # Full load: every requested entry is restored and nothing is left over
    remainder = strategy.load_checkpoint(checkpoint_path, state)
    assert torch.equal(model.weight, saved_model.weight)
    assert optimizer.state_dict() == saved_optimizer.state_dict()
    for key in ("int", "dict"):
        assert state[key] == saved_state[key]
    assert not remainder

    # Partial load: only the model is requested, so the other entries come back as the remainder
    model = torch.nn.Linear(2, 2)
    remainder = strategy.load_checkpoint(checkpoint_path, {"model": model})
    assert torch.equal(model.weight, saved_model.weight)
    assert list(remainder) == ["optimizer", "int", "dict"]
def test_load_checkpoint_strict_loading(tmp_path):
    """Check that strict loading raises a `KeyError` for a requested key that the checkpoint does not contain."""
    # `SingleDeviceStrategy` stands in for the base class whose implementation is exercised here
    strategy = SingleDeviceStrategy()

    saved_state = {"a": 1, "b": 2}
    # Key `c` is requested but absent from what the mocked checkpoint IO returns
    requested_state = dict(saved_state, c=3)
    strategy.checkpoint_io.load_checkpoint = Mock(return_value=saved_state)

    with pytest.raises(KeyError, match="contains a key 'c' that does not exist"):
        strategy.load_checkpoint(tmp_path, requested_state, strict=True)
def test_load_checkpoint_non_strict_loading(tmp_path):
    """Check that with `strict=False` a requested key missing from the checkpoint does not raise."""
    # `SingleDeviceStrategy` stands in for the base class whose implementation is exercised here
    strategy = SingleDeviceStrategy()
    checkpoint_path = tmp_path / "checkpoint.ckpt"

    # Reference objects whose state gets written to disk
    saved_model = torch.nn.Linear(2, 2)
    saved_optimizer = torch.optim.Adam(saved_model.parameters(), lr=0.1)
    saved_state = {"model": saved_model, "optimizer": saved_optimizer, "int": 1, "str": "test"}
    strategy.save_checkpoint(checkpoint_path, state=saved_state)

    # Freshly built counterparts; `new` is requested but was never saved, `str` was saved but is not requested
    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.3)
    state = {"model": model, "optimizer": optimizer, "int": 2, "new": "not_present_in_saved_state"}
    assert not torch.equal(model.weight, saved_model.weight)
    assert optimizer.state_dict() != saved_optimizer.state_dict()

    remainder = strategy.load_checkpoint(checkpoint_path, state, strict=False)

    # Entries present on both sides are restored
    assert torch.equal(model.weight, saved_model.weight)
    assert optimizer.state_dict() == saved_optimizer.state_dict()
    assert state["int"] == saved_state["int"]

    # The saved-only key ends up in the remainder and is not added to the requested state
    assert "str" not in state
    assert "str" in remainder

    # The requested-only key keeps its value and does not appear in the remainder
    assert state["new"] == "not_present_in_saved_state"
    assert "new" not in remainder
@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("mps:0", marks=RunIf(mps=True)),
    ],
)
@pytest.mark.parametrize(
    ("precision", "dtype"),
    [
        (Precision(), torch.float32),
        (HalfPrecision("16-true"), torch.float16),
        pytest.param(HalfPrecision("bf16-true"), torch.bfloat16, marks=RunIf(mps=False)),
        pytest.param(DoublePrecision(), torch.float64, marks=RunIf(mps=False)),
    ],
)
@pytest.mark.parametrize("empty_init", [None, True, False])
def test_module_init_context(device, precision, dtype, empty_init, monkeypatch):
    """Check that a module built inside the init-module-context lands on the expected device and dtype.

    `torch.Tensor.uniform_` is swapped for a mock so the test can observe whether it was invoked while the module
    was constructed (presumably as part of `torch.nn.Linear`'s parameter initialization).

    """
    uniform_mock = Mock()
    monkeypatch.setattr(torch.Tensor, "uniform_", uniform_mock)

    device = torch.device(device)
    # `SingleDeviceStrategy` stands in for the base class whose implementation is exercised here
    strategy = SingleDeviceStrategy(device=device, precision=precision)
    with strategy.module_init_context(empty_init=empty_init):
        module = torch.nn.Linear(2, 2)

    weight, bias = module.weight, module.bias
    assert weight.device == bias.device == device
    assert weight.dtype == bias.dtype == dtype

    # A truthy `empty_init` must skip the initializer; `None` and `False` must run it
    if empty_init:
        uniform_mock.assert_not_called()
    else:
        uniform_mock.assert_called()
@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("mps:0", marks=RunIf(mps=True)),
    ],
)
@pytest.mark.parametrize(
    ("precision", "dtype"),
    [
        (Precision(), torch.float32),
        (HalfPrecision("16-true"), torch.float16),
        pytest.param(HalfPrecision("bf16-true"), torch.bfloat16, marks=RunIf(mps=False)),
        pytest.param(DoublePrecision(), torch.float64, marks=RunIf(mps=False)),
    ],
)
def test_tensor_init_context(device, precision, dtype):
    """Check that tensors created inside the init-tensor-context land on the expected device and dtype."""
    device = torch.device(device)
    # `SingleDeviceStrategy` stands in for the base class whose implementation is exercised here
    strategy = SingleDeviceStrategy(device=device, precision=precision)

    with strategy.tensor_init_context():
        default_float = torch.tensor(42.0)
        integer = torch.tensor(42)
        explicit_half = torch.tensor(42.0, dtype=torch.half)

    # Every tensor is placed on the strategy's device, whatever its dtype
    created = (default_float, integer, explicit_half)
    assert all(tensor.device == device for tensor in created)

    # Only the float tensor without an explicit dtype follows the precision setting
    assert default_float.dtype == dtype
    # Integer tensors keep their default dtype
    assert integer.dtype == torch.long
    # An explicitly requested dtype is kept as given
    assert explicit_half.dtype == torch.half