Update PyTorch 2.4 tests (#20079)

This commit is contained in:
awaelchli 2024-07-13 11:09:09 +02:00 committed by GitHub
parent d5ae9ec568
commit 7d1a70752f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 9 additions and 3 deletions

View File

@@ -458,7 +458,7 @@ def _load_checkpoint(
return metadata
if _is_full_checkpoint(path):
-        checkpoint = torch.load(path, mmap=True, map_location="cpu")
+        checkpoint = torch.load(path, mmap=True, map_location="cpu", weights_only=False)
_load_raw_module_state(checkpoint.pop(module_key), module, strict=strict)
state_dict_options = StateDictOptions(

View File

@@ -225,7 +225,7 @@ def test_bitsandbytes_layers_meta_device(args, expected, tmp_path):
model = MyModel()
ckpt_path = tmp_path / "foo.ckpt"
torch.save(state_dict, ckpt_path)
-    torch.load(str(ckpt_path), mmap=True)
+    torch.load(str(ckpt_path), mmap=True, weights_only=True)
keys = model.load_state_dict(state_dict, strict=True, assign=True) # quantizes
assert not keys.missing_keys
assert model.l.weight.device.type == "cuda"
@@ -258,7 +258,7 @@ def test_load_quantized_checkpoint(tmp_path):
fabric = Fabric(accelerator="cuda", devices=1, plugins=BitsandbytesPrecision("nf4-dq"))
model = Model()
model = fabric.setup(model)
-    state_dict = torch.load(tmp_path / "checkpoint.pt")
+    state_dict = torch.load(tmp_path / "checkpoint.pt", weights_only=True)
model.load_state_dict(state_dict)
assert model.linear.weight.dtype == torch.uint8
assert model.linear.weight.shape == (128, 1)

View File

@@ -301,6 +301,7 @@ def test_train_save_load(precision, tmp_path):
assert state["coconut"] == 11
+@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
def test_save_full_state_dict(tmp_path):
"""Test that ModelParallelStrategy saves the full state into a single file with
@@ -401,6 +402,7 @@ def test_save_full_state_dict(tmp_path):
_train(fabric, model, optimizer)
+@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
def test_load_full_state_dict_into_sharded_model(tmp_path):
"""Test that the strategy can load a full-state checkpoint into a distributed model."""

View File

@@ -282,6 +282,8 @@ class BoringZeroRedundancyOptimizerModel(BoringModel):
return ZeroRedundancyOptimizer(self.layer.parameters(), optimizer_class=torch.optim.Adam, lr=0.1)
+# ZeroRedundancyOptimizer internally calls `torch.load` with `weights_only` not set, triggering the FutureWarning
+@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, skip_windows=True)
@pytest.mark.parametrize("strategy", [pytest.param("ddp", marks=RunIf(standalone=True)), "ddp_spawn"])
def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(strategy, tmp_path):

View File

@@ -418,6 +418,7 @@ def test_load_full_state_checkpoint_into_regular_model(tmp_path):
trainer.strategy.barrier()
+@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
def test_load_standard_checkpoint_into_distributed_model(tmp_path):
"""Test that a regular checkpoint (weights and optimizer states) can be loaded into a distributed model."""
@@ -458,6 +459,7 @@ def test_load_standard_checkpoint_into_distributed_model(tmp_path):
trainer.strategy.barrier()
+@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
def test_save_load_sharded_state_dict(tmp_path):
"""Test saving and loading with the distributed state dict format."""