# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import Mock, PropertyMock

import pytest
import torch
from torch import nn

from lightning.fabric.plugins import DoublePrecision, HalfPrecision, Precision
from lightning.fabric.strategies import SingleDeviceStrategy
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0

from tests_fabric.helpers.runif import RunIf


@pytest.mark.parametrize("is_rank_zero", [True, False])
def test_save_checkpoint_rank_zero_only(is_rank_zero, tmp_path):
    """Test that the checkpoint only gets saved on global rank 0 in the base implementation in Strategy."""
    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class
    save_checkpoint_mock = Mock()
    strategy.checkpoint_io.save_checkpoint = save_checkpoint_mock
    with mock.patch(
        "lightning.fabric.strategies.single_device.SingleDeviceStrategy.is_global_zero",
        new_callable=PropertyMock(return_value=is_rank_zero),
    ):
        strategy.save_checkpoint(tmp_path, {"anything": 1})
    assert save_checkpoint_mock.call_count == int(is_rank_zero)


def test_save_checkpoint_empty_state(tmp_path):
    """Test that one can save an empty state with the base implementation in Strategy."""
    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class
    save_checkpoint_mock = Mock()
    strategy.checkpoint_io.save_checkpoint = save_checkpoint_mock
    state = {}
    strategy.save_checkpoint(tmp_path, state)
    save_checkpoint_mock.assert_called_with(checkpoint=state, path=tmp_path, storage_options=None)


def test_save_checkpoint_convert_stateful_objects(tmp_path):
    """Test that when modules and optimizers are at the top-level in the state, their `state_dict()` gets used."""
    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class
    save_checkpoint_mock = Mock()
    strategy.checkpoint_io.save_checkpoint = save_checkpoint_mock
    model = nn.Linear(3, 3)
    optimizer = torch.optim.Adam(model.parameters())
    anything = {"cocofruit": 1}
    state = {"model": model, "optimizer": optimizer, "anything": anything}
    expected = {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "anything": anything}
    strategy.save_checkpoint(tmp_path, state)
    assert save_checkpoint_mock.call_args[1]["checkpoint"].keys() == expected.keys()
    saved_model_state = save_checkpoint_mock.call_args[1]["checkpoint"]["model"]
    assert all(torch.equal(p0, p1) for p0, p1 in zip(saved_model_state.values(), expected["model"].values()))
    assert save_checkpoint_mock.call_args[1]["checkpoint"]["optimizer"] == expected["optimizer"]
    assert save_checkpoint_mock.call_args[1]["checkpoint"]["anything"] == expected["anything"]


def test_load_module_state_dict():
    """Test that `Strategy.load_module_state_dict()` calls `.load_state_dict()` on the module."""
    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class
    module = Mock()
    state_dict = Mock()
    strategy.load_module_state_dict(module, state_dict)
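    # not passing `strict` should default to strict loading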
    module.load_state_dict.assert_called_with(state_dict, strict=True)

    strategy.load_module_state_dict(module, state_dict, strict=False)
    module.load_state_dict.assert_called_with(state_dict, strict=False)


def test_load_checkpoint_model_optimizer_from_raw_checkpoint(tmp_path):
    """Test that `load_checkpoint` can load raw state-dict checkpoints too."""
    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class
    model = nn.Linear(3, 3)
    optimizer = torch.optim.Adam(model.parameters(), lr=1.0)
    torch.save(model.state_dict(), tmp_path / "model.ckpt")
    torch.save(optimizer.state_dict(), tmp_path / "optimizer.ckpt")

    new_model = nn.Linear(3, 3)
    new_optimizer = torch.optim.Adam(new_model.parameters(), lr=2.0)
    strategy.load_checkpoint(tmp_path / "model.ckpt", state=new_model, strict=False)
    assert torch.equal(new_model.weight, model.weight)
    strategy.load_checkpoint(tmp_path / "optimizer.ckpt", state=new_optimizer, strict=False)
    assert new_optimizer.state_dict()["param_groups"][0]["lr"] == 1.0


def test_load_checkpoint_out_of_place(tmp_path):
    """Test that one can load the full checkpoint into memory just like `torch.load()`."""
    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class
    load_checkpoint_mock = Mock()
    strategy.checkpoint_io.load_checkpoint = load_checkpoint_mock
    checkpoint = strategy.load_checkpoint(tmp_path, state=None)
    assert checkpoint == load_checkpoint_mock()
    checkpoint = strategy.load_checkpoint(tmp_path, state={})
    assert checkpoint == load_checkpoint_mock()


def test_load_checkpoint_in_place(tmp_path):
    """Test that the objects' state gets reloaded in-place."""
    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class

    # objects with initial state
    saved_model = nn.Linear(2, 2)
    saved_optimizer = torch.optim.Adam(saved_model.parameters(), lr=0.1)
    saved_state = {"model": saved_model, "optimizer": saved_optimizer, "int": 1, "dict": {"cocofruit": 2}}
    strategy.save_checkpoint(tmp_path / "checkpoint", state=saved_state)

    # same objects with different state
    model = nn.Linear(2, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.3)
    state = {"model": model, "optimizer": optimizer, "int": 10, "dict": {"cocofruit": 20}}
    assert not torch.equal(model.weight, saved_model.weight)
    assert optimizer.state_dict() != saved_optimizer.state_dict()

    remainder = strategy.load_checkpoint(tmp_path / "checkpoint", state)
    assert torch.equal(model.weight, saved_model.weight)
    assert optimizer.state_dict() == saved_optimizer.state_dict()
    assert state["int"] == saved_state["int"]
    assert state["dict"] == saved_state["dict"]
    assert not remainder

    # partial load - only model, no optimizer
    model = nn.Linear(2, 2)
    state = {"model": model}
    remainder = strategy.load_checkpoint(tmp_path / "checkpoint", state)
    assert torch.equal(model.weight, saved_model.weight)
    assert list(remainder.keys()) == ["optimizer", "int", "dict"]


def test_load_checkpoint_strict_loading(tmp_path):
    """Test that an error is raised if a key is requested to be restored but does not exist in the checkpoint."""
    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class
    saved_state = {"a": 1, "b": 2}
    requested_state = {"a": 1, "b": 2, "c": 3}  # key `c` does not exist in the saved state
    load_checkpoint_mock = Mock(return_value=saved_state)
    strategy.checkpoint_io.load_checkpoint = load_checkpoint_mock
    with pytest.raises(KeyError, match="contains a key 'c' that does not exist"):
        strategy.load_checkpoint(tmp_path, requested_state, strict=True)
def test_load_checkpoint_non_strict_loading(tmp_path):
    """Test that no error is raised if `strict=False` and state is requested that does not exist in the checkpoint."""
    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class

    # objects with initial state
    saved_model = nn.Linear(2, 2)
    saved_optimizer = torch.optim.Adam(saved_model.parameters(), lr=0.1)
    saved_state = {"model": saved_model, "optimizer": saved_optimizer, "int": 1, "str": "test"}
    strategy.save_checkpoint(tmp_path / "checkpoint.ckpt", state=saved_state)

    # same objects with different state
    model = nn.Linear(2, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.3)
    state = {"model": model, "optimizer": optimizer, "int": 2, "new": "not_present_in_saved_state"}
    assert not torch.equal(model.weight, saved_model.weight)
    assert optimizer.state_dict() != saved_optimizer.state_dict()

    remainder = strategy.load_checkpoint(tmp_path / "checkpoint.ckpt", state, strict=False)
    assert torch.equal(model.weight, saved_model.weight)
    assert optimizer.state_dict() == saved_optimizer.state_dict()
    assert state["int"] == saved_state["int"]
    assert "str" not in state
    assert "str" in remainder
    assert state["new"] == "not_present_in_saved_state"
    assert "new" not in remainder


@RunIf(min_torch="1.13")
@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("mps:0", marks=RunIf(mps=True)),
    ],
)
@pytest.mark.parametrize(
    ("precision", "dtype"),
    [
        (Precision(), torch.float32),
        (HalfPrecision("16-true"), torch.float16),
        pytest.param(HalfPrecision("bf16-true"), torch.bfloat16, marks=RunIf(mps=False)),
        pytest.param(DoublePrecision(), torch.float64, marks=RunIf(mps=False)),
    ],
)
@pytest.mark.parametrize("empty_init", [None, True, False])
def test_module_init_context(device, precision, dtype, empty_init, monkeypatch):
    """Test that the module under the init-module-context gets moved to the right device and dtype."""
    init_mock = Mock()
    monkeypatch.setattr(torch.Tensor, "uniform_", init_mock)

    device = torch.device(device)
    strategy = SingleDeviceStrategy(device=device, precision=precision)  # surrogate class to test base class
    with strategy.module_init_context(empty_init=empty_init):
        module = torch.nn.Linear(2, 2)

    expected_device = device if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
    assert module.weight.device == module.bias.device == expected_device
    assert module.weight.dtype == module.bias.dtype == dtype

    if not empty_init:
        init_mock.assert_called()
    else:
        init_mock.assert_not_called()


@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("mps:0", marks=RunIf(mps=True)),
    ],
)
@pytest.mark.parametrize(
    ("precision", "dtype"),
    [
        (Precision(), torch.float32),
        (HalfPrecision("16-true"), torch.float16),
        pytest.param(HalfPrecision("bf16-true"), torch.bfloat16, marks=RunIf(mps=False)),
        pytest.param(DoublePrecision(), torch.float64, marks=RunIf(mps=False)),
    ],
)
def test_tensor_init_context(device, precision, dtype):
    """Test that tensors under the init-tensor-context get moved to the right device and dtype."""
    device = torch.device(device)
    strategy = SingleDeviceStrategy(device=device, precision=precision)  # surrogate class to test base class
    with strategy.tensor_init_context():
        tensor0 = torch.tensor(42.0)
        tensor1 = torch.tensor(42)
        tensor2 = torch.tensor(42.0, dtype=torch.half)

    expected_device = device if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
    assert tensor0.device == tensor1.device == tensor2.device == expected_device
    assert tensor0.dtype == dtype
    assert tensor1.dtype == torch.long  # `tensor_init_context()` only affects floating point dtypes
    assert tensor2.dtype == torch.half  # this tensor was created with an explicit dtype assignment