Move torchmetrics to device when using FSDP (#18954)
commit 964364b3bb (parent 07461a1def)
@@ -33,6 +33,7 @@ from typing import (
 )

 import torch
+from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
 from torch.nn import Module, Parameter
 from torch.optim import Optimizer

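The one line added in the hunk above is the `RequirementCache` import; it backs the optional-dependency guard used by the new helper added further down in `lightning.fabric.strategies.fsdp`. A minimal sketch of that guard pattern, assuming only that `lightning_utilities` is installed (the constant and function names here are illustrative, not part of the change):

```python
from lightning_utilities.core.imports import RequirementCache

_TORCHMETRICS_AVAILABLE = RequirementCache("torchmetrics")


def log_torchmetrics_status() -> None:
    # Evaluating the cache as a bool checks (and caches) whether the requirement is met,
    # so the guarded import below only runs when torchmetrics is actually installed.
    if not _TORCHMETRICS_AVAILABLE:
        print("torchmetrics is not installed; skipping metric handling")
        return
    from torchmetrics import Metric

    print(f"torchmetrics is available; Metric base class: {Metric.__name__}")


log_torchmetrics_status()
```
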
@@ -292,6 +293,8 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
             **self._fsdp_kwargs,
         )

+        _move_torchmetrics_to_device(module, self.root_device)
+
         # activation checkpointing needs to be set up after wrapping the model
         if _TORCH_GREATER_EQUAL_1_13:
             _setup_activation_checkpointing(module, self._activation_checkpointing_kwargs)

@@ -886,3 +889,15 @@ def _has_meta_device_parameters(obj: Union[Module, Optimizer]) -> bool:
     if isinstance(obj, Module):
         return any(t.is_meta for t in obj.parameters())
     raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}")
+
+
+def _move_torchmetrics_to_device(module: torch.nn.Module, device: torch.device) -> None:
+    # FSDP doesn't move modules without parameters (e.g. Metrics) to the device
+    # https://github.com/pytorch/pytorch/issues/113113
+    if not RequirementCache("torchmetrics"):
+        return
+
+    from torchmetrics import Metric
+
+    for metric in (m for m in module.modules() if isinstance(m, Metric)):
+        metric.to(device)  # `.to()` is in-place

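For readers unfamiliar with why the explicit move is needed: torchmetrics `Metric` objects keep their accumulated state in tensors that are not `nn.Parameter`s, so FSDP's device placement never touches them (pytorch/pytorch#113113). Below is a self-contained sketch of the traversal the new helper performs, runnable on CPU without FSDP; `TinyModel` and `move_metrics` are illustrative names, not part of the change:

```python
import torch
from torch import nn
from torchmetrics import Accuracy, Metric


class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(8, 4)
        self.metric = Accuracy("multiclass", num_classes=4)  # holds metric state, no parameters


def move_metrics(module: nn.Module, device: torch.device) -> None:
    # Same traversal as the helper above: find every torchmetrics Metric and move it explicitly.
    for m in module.modules():
        if isinstance(m, Metric):
            m.to(device)


model = TinyModel()
assert not list(model.metric.parameters())  # the metric owns no parameters, which is why FSDP skips it
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
move_metrics(model, device)
print(model.metric.device)
```
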
@@ -49,6 +49,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue causing permission errors on Windows when attempting to create a symlink for the "last" checkpoint ([#18942](https://github.com/Lightning-AI/lightning/issues/18942))


+- Fixed an issue where Metric instances from `torchmetrics` wouldn't get moved to the device when using FSDP ([#18954](https://github.com/Lightning-AI/lightning/issues/18954))
+
+
 ## [2.1.0] - 2023-10-11

 ### Added

@@ -39,6 +39,7 @@ from lightning.fabric.strategies.fsdp import (
     _is_full_checkpoint,
     _is_sharded_checkpoint,
     _load_raw_module_state,
+    _move_torchmetrics_to_device,
     _optimizer_has_flat_params,
     _setup_activation_checkpointing,
 )

@@ -292,6 +293,8 @@ class FSDPStrategy(ParallelStrategy):
             **self.kwargs,
         )

+        _move_torchmetrics_to_device(model, self.root_device)
+
         # activation checkpointing needs to be set up after wrapping the model
         if _TORCH_GREATER_EQUAL_1_13:
             _setup_activation_checkpointing(model, self._activation_checkpointing_kwargs)

@@ -28,6 +28,7 @@ from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
 from torch.distributed.fsdp.wrap import always_wrap_policy, size_based_auto_wrap_policy, wrap
+from torchmetrics import Accuracy

 from tests_pytorch.helpers.runif import RunIf

@@ -239,6 +240,36 @@ def test_fsdp_strategy_sync_batchnorm(tmpdir):
     _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))


+@RunIf(min_cuda_gpus=1, skip_windows=True)
+def test_fsdp_modules_without_parameters(tmp_path):
+    """Test that TorchMetrics get moved to the device despite not having any parameters."""
+
+    class MetricsModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.metric = Accuracy("multiclass", num_classes=10)
+            assert self.metric.device == self.metric.tp.device == torch.device("cpu")
+
+        def setup(self, stage) -> None:
+            assert self.metric.device == self.metric.tp.device == torch.device("cpu")
+
+        def training_step(self, batch, batch_idx):
+            loss = super().training_step(batch, batch_idx)
+            assert self.metric.device == self.metric.tp.device == torch.device("cuda", 0)
+            self.metric(torch.rand(2, 10, device=self.device), torch.randint(0, 10, size=(2,), device=self.device))
+            return loss
+
+    model = MetricsModel()
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        accelerator="cuda",
+        devices=1,
+        strategy="fsdp",
+        max_steps=1,
+    )
+    trainer.fit(model)
+
+
 @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True)
 @pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
 def test_fsdp_strategy_checkpoint(tmpdir, precision):
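For context, here is a hedged sketch of the symptom the test above guards against (requires a CUDA device; the snippet is illustrative and not taken from the change): updating a CPU-resident metric with CUDA tensors typically fails with a device-mismatch `RuntimeError`, which is what users hit under FSDP before the metric was moved.

```python
import torch
from torchmetrics import Accuracy

metric = Accuracy("multiclass", num_classes=10)  # metric state starts on CPU
preds = torch.rand(2, 10, device="cuda")
target = torch.randint(0, 10, (2,), device="cuda")

try:
    metric.update(preds, target)  # CPU state + CUDA inputs -> device-mismatch error
except RuntimeError as err:
    print(f"without the fix: {err}")

metric.to(torch.device("cuda", 0))  # what the new helper now does for every Metric in the module
metric.update(preds, target)        # succeeds once state and inputs share a device
print(metric.compute())
```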