diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py
index f8dd017cc1..6880c9784f 100644
--- a/src/lightning/fabric/strategies/model_parallel.py
+++ b/src/lightning/fabric/strategies/model_parallel.py
@@ -458,7 +458,7 @@ def _load_checkpoint(
         return metadata

     if _is_full_checkpoint(path):
-        checkpoint = torch.load(path, mmap=True, map_location="cpu")
+        checkpoint = torch.load(path, mmap=True, map_location="cpu", weights_only=False)
         _load_raw_module_state(checkpoint.pop(module_key), module, strict=strict)

         state_dict_options = StateDictOptions(
diff --git a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py
index 31d655fc4c..d531161682 100644
--- a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py
+++ b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py
@@ -225,7 +225,7 @@ def test_bitsandbytes_layers_meta_device(args, expected, tmp_path):
     model = MyModel()
     ckpt_path = tmp_path / "foo.ckpt"
     torch.save(state_dict, ckpt_path)
-    torch.load(str(ckpt_path), mmap=True)
+    torch.load(str(ckpt_path), mmap=True, weights_only=True)
     keys = model.load_state_dict(state_dict, strict=True, assign=True)  # quantizes
     assert not keys.missing_keys
     assert model.l.weight.device.type == "cuda"
@@ -258,7 +258,7 @@ def test_load_quantized_checkpoint(tmp_path):
     fabric = Fabric(accelerator="cuda", devices=1, plugins=BitsandbytesPrecision("nf4-dq"))
     model = Model()
     model = fabric.setup(model)
-    state_dict = torch.load(tmp_path / "checkpoint.pt")
+    state_dict = torch.load(tmp_path / "checkpoint.pt", weights_only=True)
     model.load_state_dict(state_dict)
     assert model.linear.weight.dtype == torch.uint8
     assert model.linear.weight.shape == (128, 1)
diff --git a/tests/tests_fabric/strategies/test_model_parallel_integration.py b/tests/tests_fabric/strategies/test_model_parallel_integration.py
index 73614b02d3..dfbdb16b10 100644
--- a/tests/tests_fabric/strategies/test_model_parallel_integration.py
+++ b/tests/tests_fabric/strategies/test_model_parallel_integration.py
@@ -301,6 +301,7 @@ def test_train_save_load(precision, tmp_path):
     assert state["coconut"] == 11


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
 def test_save_full_state_dict(tmp_path):
     """Test that ModelParallelStrategy saves the full state into a single file with
@@ -401,6 +402,7 @@ def test_save_full_state_dict(tmp_path):
     _train(fabric, model, optimizer)


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
 def test_load_full_state_dict_into_sharded_model(tmp_path):
     """Test that the strategy can load a full-state checkpoint into a distributed model."""
diff --git a/tests/tests_pytorch/strategies/test_ddp_integration.py b/tests/tests_pytorch/strategies/test_ddp_integration.py
index bd86bfa6ce..836072d36b 100644
--- a/tests/tests_pytorch/strategies/test_ddp_integration.py
+++ b/tests/tests_pytorch/strategies/test_ddp_integration.py
@@ -282,6 +282,8 @@ class BoringZeroRedundancyOptimizerModel(BoringModel):
         return ZeroRedundancyOptimizer(self.layer.parameters(), optimizer_class=torch.optim.Adam, lr=0.1)


+# ZeroRedundancyOptimizer internally calls `torch.load` with `weights_only` not set, triggering the FutureWarning
+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, skip_windows=True)
 @pytest.mark.parametrize("strategy", [pytest.param("ddp", marks=RunIf(standalone=True)), "ddp_spawn"])
 def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(strategy, tmp_path):
diff --git a/tests/tests_pytorch/strategies/test_model_parallel_integration.py b/tests/tests_pytorch/strategies/test_model_parallel_integration.py
index 5947277b05..57d2739175 100644
--- a/tests/tests_pytorch/strategies/test_model_parallel_integration.py
+++ b/tests/tests_pytorch/strategies/test_model_parallel_integration.py
@@ -418,6 +418,7 @@ def test_load_full_state_checkpoint_into_regular_model(tmp_path):
     trainer.strategy.barrier()


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
 def test_load_standard_checkpoint_into_distributed_model(tmp_path):
     """Test that a regular checkpoint (weights and optimizer states) can be loaded into a distributed model."""
@@ -458,6 +459,7 @@ def test_load_standard_checkpoint_into_distributed_model(tmp_path):
     trainer.strategy.barrier()


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
 def test_save_load_sharded_state_dict(tmp_path):
     """Test saving and loading with the distributed state dict format."""