From 964364b3bbfd27c448633025ae7919ef613b83ef Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 8 Nov 2023 21:29:26 +0100
Subject: [PATCH] Move torchmetrics to device when using FSDP (#18954)

---
 src/lightning/fabric/strategies/fsdp.py     | 15 ++++++++++
 src/lightning/pytorch/CHANGELOG.md          |  3 ++
 src/lightning/pytorch/strategies/fsdp.py    |  3 ++
 tests/tests_pytorch/strategies/test_fsdp.py | 31 +++++++++++++++++++++
 4 files changed, 52 insertions(+)

diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index 508fa317d1..ae125cb568 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -33,6 +33,7 @@ from typing import (
 )
 
 import torch
+from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
 from torch.nn import Module, Parameter
 from torch.optim import Optimizer
@@ -292,6 +293,8 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
             **self._fsdp_kwargs,
         )
 
+        _move_torchmetrics_to_device(module, self.root_device)
+
         # activation checkpointing needs to be set up after wrapping the model
         if _TORCH_GREATER_EQUAL_1_13:
             _setup_activation_checkpointing(module, self._activation_checkpointing_kwargs)
@@ -886,3 +889,15 @@ def _has_meta_device_parameters(obj: Union[Module, Optimizer]) -> bool:
     if isinstance(obj, Module):
         return any(t.is_meta for t in obj.parameters())
     raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}")
+
+
+def _move_torchmetrics_to_device(module: torch.nn.Module, device: torch.device) -> None:
+    # FSDP doesn't move modules without parameters (e.g. Metrics) to the device
+    # https://github.com/pytorch/pytorch/issues/113113
+    if not RequirementCache("torchmetrics"):
+        return
+
+    from torchmetrics import Metric
+
+    for metric in (m for m in module.modules() if isinstance(m, Metric)):
+        metric.to(device)  # `.to()` is in-place
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 57824dc66c..99e477bee7 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -49,6 +49,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue causing permission errors on Windows when attempting to create a symlink for the "last" checkpoint ([#18942](https://github.com/Lightning-AI/lightning/issues/18942))
 
+- Fixed an issue where Metric instances from `torchmetrics` wouldn't get moved to the device when using FSDP ([#18954](https://github.com/Lightning-AI/lightning/issues/18954))
+
+
 ## [2.1.0] - 2023-10-11
 
 ### Added
diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py
index 907497d218..1407b61d08 100644
--- a/src/lightning/pytorch/strategies/fsdp.py
+++ b/src/lightning/pytorch/strategies/fsdp.py
@@ -39,6 +39,7 @@ from lightning.fabric.strategies.fsdp import (
     _is_full_checkpoint,
     _is_sharded_checkpoint,
     _load_raw_module_state,
+    _move_torchmetrics_to_device,
     _optimizer_has_flat_params,
     _setup_activation_checkpointing,
 )
@@ -292,6 +293,8 @@ class FSDPStrategy(ParallelStrategy):
             **self.kwargs,
         )
 
+        _move_torchmetrics_to_device(model, self.root_device)
+
         # activation checkpointing needs to be set up after wrapping the model
         if _TORCH_GREATER_EQUAL_1_13:
             _setup_activation_checkpointing(model, self._activation_checkpointing_kwargs)
diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py
index f094053e98..bcb338e05e 100644
--- a/tests/tests_pytorch/strategies/test_fsdp.py
+++ b/tests/tests_pytorch/strategies/test_fsdp.py
@@ -28,6 +28,7 @@ from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
 from torch.distributed.fsdp.wrap import always_wrap_policy, size_based_auto_wrap_policy, wrap
+from torchmetrics import Accuracy
 
 from tests_pytorch.helpers.runif import RunIf
 
@@ -239,6 +240,36 @@ def test_fsdp_strategy_sync_batchnorm(tmpdir):
     _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))
 
 
+@RunIf(min_cuda_gpus=1, skip_windows=True)
+def test_fsdp_modules_without_parameters(tmp_path):
+    """Test that TorchMetrics get moved to the device despite not having any parameters."""
+
+    class MetricsModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.metric = Accuracy("multiclass", num_classes=10)
+            assert self.metric.device == self.metric.tp.device == torch.device("cpu")
+
+        def setup(self, stage) -> None:
+            assert self.metric.device == self.metric.tp.device == torch.device("cpu")
+
+        def training_step(self, batch, batch_idx):
+            loss = super().training_step(batch, batch_idx)
+            assert self.metric.device == self.metric.tp.device == torch.device("cuda", 0)
+            self.metric(torch.rand(2, 10, device=self.device), torch.randint(0, 10, size=(2,), device=self.device))
+            return loss
+
+    model = MetricsModel()
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        accelerator="cuda",
+        devices=1,
+        strategy="fsdp",
+        max_steps=1,
+    )
+    trainer.fit(model)
+
+
 @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True)
 @pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
 def test_fsdp_strategy_checkpoint(tmpdir, precision):
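Editor's note, not part of the patch: below is a minimal standalone sketch of the approach the new `_move_torchmetrics_to_device` helper takes, namely walking the module tree and moving any parameter-less `torchmetrics.Metric` submodule to the target device, since FSDP only relocates modules that own parameters. The names `ToyModel` and `move_metrics_to_device` are illustrative only; the snippet assumes `torch` and `torchmetrics` are installed and falls back to CPU when no GPU is available.

import torch
from torch import nn
from torchmetrics import Accuracy, Metric


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(10, 10)
        # Metric state (e.g. true-positive counts) lives in buffers, not parameters,
        # which is why FSDP leaves it behind on CPU when wrapping the model.
        self.metric = Accuracy(task="multiclass", num_classes=10)


def move_metrics_to_device(module: nn.Module, device: torch.device) -> None:
    # Same idea as the helper added in this patch: `Metric.to()` moves the
    # metric state in place, so no reassignment is needed.
    for submodule in module.modules():
        if isinstance(submodule, Metric):
            submodule.to(device)


if __name__ == "__main__":
    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    model = ToyModel()
    move_metrics_to_device(model, device)
    assert model.metric.device == device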