Mention specific param names and devices in warning (#18273)
commit dc44fa406a
parent c98fb36b11
@@ -925,14 +925,28 @@ class Fabric:
         return to_run(*args, **kwargs)
 
     def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
-        initial_device = next(model.parameters(), torch.tensor(0)).device
-        if any(param.device != initial_device for param in model.parameters()):
-            rank_zero_warn(
-                "The model passed to `Fabric.setup()` has parameters on different devices. Since `move_to_device=True`,"
-                " all parameters will be moved to the new device. If this is not desired, set "
-                " `Fabric.setup(..., move_to_device=False)`.",
-                category=PossibleUserWarning,
-            )
+        try:
+            initial_name, initial_param = next(model.named_parameters())
+        except StopIteration:
+            pass
+        else:
+            initial_device = initial_param.device
+            count = 0
+            first_name, first_device = None, None
+            for name, param in model.named_parameters():
+                if param.device != initial_device:
+                    count += 1
+                    if first_name is None:
+                        first_name = name
+                        first_device = param.device
+            if count > 0:
+                rank_zero_warn(
+                    f"The model passed to `Fabric.setup()` has {count} parameters on different devices (for example"
+                    f" {first_name!r} on {first_device} and {initial_name!r} on {initial_device}). Since"
+                    " `move_to_device=True`, all parameters will be moved to the new device. If this is not"
+                    " desired, set `Fabric.setup(..., move_to_device=False)`.",
+                    category=PossibleUserWarning,
+                )
 
         if isinstance(self._strategy, XLAStrategy):
             # When the user creates the optimizer, they reference the parameters on the CPU.
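For readers skimming the diff: the new logic takes the first named parameter as the reference device, counts every parameter that disagrees, and remembers the first offender so the warning can name it. Below is a minimal standalone sketch of that same loop; the report_mixed_devices helper is hypothetical (not part of Fabric's API), and a meta-device layer stands in for a GPU so the snippet runs without CUDA.

import torch
from torch import nn


def report_mixed_devices(model: nn.Module) -> str:
    # Mirrors the loop added in this commit: take the first named parameter
    # as the reference, then count every parameter on a different device.
    try:
        initial_name, initial_param = next(model.named_parameters())
    except StopIteration:
        return "model has no parameters"
    initial_device = initial_param.device
    count = 0
    first_name, first_device = None, None
    for name, param in model.named_parameters():
        if param.device != initial_device:
            count += 1
            if first_name is None:
                first_name = name
                first_device = param.device
    if count == 0:
        return "all parameters on one device"
    # `!r` repr-formats the name so it appears quoted in the message, e.g. '1.weight'.
    return (
        f"{count} parameters on different devices (for example {first_name!r} on"
        f" {first_device} and {initial_name!r} on {initial_device})"
    )


# One CPU layer and one meta-device layer reproduce a mixed-device model.
model = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2, device="meta"))
print(report_mixed_devices(model))
# -> 2 parameters on different devices (for example '1.weight' on meta and '0.weight' on cpu)

Note the design choice visible in the diff: the try/except around next() replaces the old next(model.parameters(), torch.tensor(0)) default, so a model with no parameters now skips the check entirely instead of comparing against a dummy CPU tensor.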
@@ -161,8 +161,9 @@ def test_setup_module_parameters_on_different_devices(setup_method, move_to_device):
     setup_method = getattr(fabric, setup_method)
 
+    match = r"has 2 parameters on different devices \(for example '1.weight' on cuda:0 and '0.weight' on cpu\)"
     if move_to_device:
-        with pytest.warns(PossibleUserWarning, match="has parameters on different devices"):
+        with pytest.warns(PossibleUserWarning, match=match):
             fabric_model = setup_method(model, move_to_device=move_to_device)
 
         # both have the same device now
 
@@ -170,7 +171,7 @@ def test_setup_module_parameters_on_different_devices(setup_method, move_to_device):
         assert module0.weight.device == module0.bias.device == device1
         assert module1.weight.device == module1.bias.device == device1
     else:
-        with no_warning_call(expected_warning=PossibleUserWarning, match="has parameters on different devices"):
+        with no_warning_call(expected_warning=PossibleUserWarning, match=match):
             setup_method(model, move_to_device=move_to_device)
 
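One detail worth noting in the test change: pytest.warns(..., match=...) treats match as a regular expression applied with re.search, which is why the new pattern is a raw string with escaped parentheses. A quick self-contained check of that pattern follows; the message string here is assembled by hand to mirror the shape of the warning above, not captured from a real run.

import re

# The regex from the test: literal parentheses must be escaped because
# pytest.warns applies the pattern with re.search, not a substring match.
match = r"has 2 parameters on different devices \(for example '1.weight' on cuda:0 and '0.weight' on cpu\)"
message = (
    "The model passed to `Fabric.setup()` has 2 parameters on different devices (for example"
    " '1.weight' on cuda:0 and '0.weight' on cpu). Since `move_to_device=True`, all parameters"
    " will be moved to the new device. If this is not desired, set"
    " `Fabric.setup(..., move_to_device=False)`."
)
assert re.search(match, message)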