Mention specific param names and devices in warning (#18273)

This commit is contained in:
Carlos Mocholí 2023-08-16 00:51:28 +02:00 committed by GitHub
parent c98fb36b11
commit dc44fa406a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 10 deletions

View File

@ -925,14 +925,28 @@ class Fabric:
return to_run(*args, **kwargs)
def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
initial_device = next(model.parameters(), torch.tensor(0)).device
if any(param.device != initial_device for param in model.parameters()):
rank_zero_warn(
"The model passed to `Fabric.setup()` has parameters on different devices. Since `move_to_device=True`,"
" all parameters will be moved to the new device. If this is not desired, set "
" `Fabric.setup(..., move_to_device=False)`.",
category=PossibleUserWarning,
)
try:
initial_name, initial_param = next(model.named_parameters())
except StopIteration:
pass
else:
initial_device = initial_param.device
count = 0
first_name, first_device = None, None
for name, param in model.named_parameters():
if param.device != initial_device:
count += 1
if first_name is None:
first_name = name
first_device = param.device
if count > 0:
rank_zero_warn(
f"The model passed to `Fabric.setup()` has {count} parameters on different devices (for example"
f" {first_name!r} on {first_device} and {initial_name!r} on {initial_device}). Since"
" `move_to_device=True`, all parameters will be moved to the new device. If this is not"
" desired, set `Fabric.setup(..., move_to_device=False)`.",
category=PossibleUserWarning,
)
if isinstance(self._strategy, XLAStrategy):
# When the user creates the optimizer, they reference the parameters on the CPU.

View File

@ -161,8 +161,9 @@ def test_setup_module_parameters_on_different_devices(setup_method, move_to_devi
setup_method = getattr(fabric, setup_method)
match = r"has 2 parameters on different devices \(for example '1.weight' on cuda:0 and '0.weight' on cpu\)"
if move_to_device:
with pytest.warns(PossibleUserWarning, match="has parameters on different devices"):
with pytest.warns(PossibleUserWarning, match=match):
fabric_model = setup_method(model, move_to_device=move_to_device)
# both have the same device now
@ -170,7 +171,7 @@ def test_setup_module_parameters_on_different_devices(setup_method, move_to_devi
assert module0.weight.device == module0.bias.device == device1
assert module1.weight.device == module1.bias.device == device1
else:
with no_warning_call(expected_warning=PossibleUserWarning, match="has parameters on different devices"):
with no_warning_call(expected_warning=PossibleUserWarning, match=match):
setup_method(model, move_to_device=move_to_device)