Move torchmetrics to device when using FSDP (#18954)

Adrian Wälchli 2023-11-08 21:29:26 +01:00 committed by GitHub
parent 07461a1def
commit 964364b3bb
4 changed files with 52 additions and 0 deletions

View File

@@ -33,6 +33,7 @@ from typing import (
)
import torch
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from torch.nn import Module, Parameter
from torch.optim import Optimizer
@@ -292,6 +293,8 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
                **self._fsdp_kwargs,
            )

        _move_torchmetrics_to_device(module, self.root_device)

        # activation checkpointing needs to be set up after wrapping the model
        if _TORCH_GREATER_EQUAL_1_13:
            _setup_activation_checkpointing(module, self._activation_checkpointing_kwargs)
@@ -886,3 +889,15 @@ def _has_meta_device_parameters(obj: Union[Module, Optimizer]) -> bool:
    if isinstance(obj, Module):
        return any(t.is_meta for t in obj.parameters())
    raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}")


def _move_torchmetrics_to_device(module: torch.nn.Module, device: torch.device) -> None:
    # FSDP doesn't move modules without parameters (e.g. Metrics) to the device
    # https://github.com/pytorch/pytorch/issues/113113
    if not RequirementCache("torchmetrics"):
        return

    from torchmetrics import Metric

    for metric in (m for m in module.modules() if isinstance(m, Metric)):
        metric.to(device)  # `.to()` is in-place
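
For readers outside the diff: torchmetrics `Metric` objects are `nn.Module` subclasses whose state tensors are not parameters, which is why FSDP leaves them behind on CPU (see the linked PyTorch issue). Below is a minimal, CPU-only sketch of the traversal pattern the helper uses; `ToyModel` and `target_device` are illustrative names, not part of the commit, and the sketch assumes `torchmetrics>=0.11` for the task-based `Accuracy` API.

import torch
from torch import nn
from torchmetrics import Accuracy, Metric


class ToyModel(nn.Module):
    """Hypothetical model: one layer with parameters plus a parameter-less metric."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(32, 10)
        self.metric = Accuracy(task="multiclass", num_classes=10)


model = ToyModel()
assert len(list(model.metric.parameters())) == 0  # the metric's state tensors are not parameters

target_device = torch.device("cpu")  # stands in for `self.root_device` in the strategy

# Same traversal pattern as `_move_torchmetrics_to_device`: find Metric submodules, move them in place.
for submodule in model.modules():
    if isinstance(submodule, Metric):
        submodule.to(target_device)

assert model.metric.device == target_device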

View File

@@ -49,6 +49,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed an issue causing permission errors on Windows when attempting to create a symlink for the "last" checkpoint ([#18942](https://github.com/Lightning-AI/lightning/issues/18942))

- Fixed an issue where Metric instances from `torchmetrics` wouldn't get moved to the device when using FSDP ([#18954](https://github.com/Lightning-AI/lightning/issues/18954))


## [2.1.0] - 2023-10-11

### Added

View File

@@ -39,6 +39,7 @@ from lightning.fabric.strategies.fsdp import (
    _is_full_checkpoint,
    _is_sharded_checkpoint,
    _load_raw_module_state,
    _move_torchmetrics_to_device,
    _optimizer_has_flat_params,
    _setup_activation_checkpointing,
)
@@ -292,6 +293,8 @@ class FSDPStrategy(ParallelStrategy):
                **self.kwargs,
            )

        _move_torchmetrics_to_device(model, self.root_device)

        # activation checkpointing needs to be set up after wrapping the model
        if _TORCH_GREATER_EQUAL_1_13:
            _setup_activation_checkpointing(model, self._activation_checkpointing_kwargs)

View File

@@ -28,6 +28,7 @@ from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import always_wrap_policy, size_based_auto_wrap_policy, wrap
from torchmetrics import Accuracy
from tests_pytorch.helpers.runif import RunIf
@@ -239,6 +240,36 @@ def test_fsdp_strategy_sync_batchnorm(tmpdir):
    _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))


@RunIf(min_cuda_gpus=1, skip_windows=True)
def test_fsdp_modules_without_parameters(tmp_path):
    """Test that TorchMetrics get moved to the device despite not having any parameters."""

    class MetricsModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.metric = Accuracy("multiclass", num_classes=10)
            assert self.metric.device == self.metric.tp.device == torch.device("cpu")

        def setup(self, stage) -> None:
            assert self.metric.device == self.metric.tp.device == torch.device("cpu")

        def training_step(self, batch, batch_idx):
            loss = super().training_step(batch, batch_idx)
            assert self.metric.device == self.metric.tp.device == torch.device("cuda", 0)
            self.metric(torch.rand(2, 10, device=self.device), torch.randint(0, 10, size=(2,), device=self.device))
            return loss

    model = MetricsModel()
    trainer = Trainer(
        default_root_dir=tmp_path,
        accelerator="cuda",
        devices=1,
        strategy="fsdp",
        max_steps=1,
    )
    trainer.fit(model)


@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True)
@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
def test_fsdp_strategy_checkpoint(tmpdir, precision):