From d99625fc8d558a5c47206e1b75928c159f0d1a27 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Mon, 21 Mar 2022 10:18:59 -0700
Subject: [PATCH] Reduce number of times optimizers are instantiated with FSDP
 (#12267)

---
 CHANGELOG.md                                  |  5 ++++-
 pytorch_lightning/strategies/fully_sharded.py | 10 +++-------
 2 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 604a3396de..541b3dcc1d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -812,7 +812,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed to avoid common hook warning if no hook is overridden ([#12131](https://github.com/PyTorchLightning/pytorch-lightning/pull/12131))
 
 
-- Fixed the case where logger=None is passed to the Trainer ([#12249](https://github.com/PyTorchLightning/pytorch-lightning/pull/12249))
+- Fixed the case where `logger=None` is passed to the Trainer ([#12249](https://github.com/PyTorchLightning/pytorch-lightning/pull/12249))
+
+
+- Fixed initializing optimizers unnecessarily in `DDPFullyShardedStrategy` ([#12267](https://github.com/PyTorchLightning/pytorch-lightning/pull/12267))
 
 
 ## [1.5.10] - 2022-02-08
diff --git a/pytorch_lightning/strategies/fully_sharded.py b/pytorch_lightning/strategies/fully_sharded.py
index 3b43520cf0..b61429264d 100644
--- a/pytorch_lightning/strategies/fully_sharded.py
+++ b/pytorch_lightning/strategies/fully_sharded.py
@@ -138,9 +138,6 @@ class DDPFullyShardedStrategy(DDPStrategy):
 
     def setup(self, trainer: "pl.Trainer") -> None:
         self.accelerator.setup(trainer)
-        self.setup_optimizers(trainer)
-        self.setup_precision_plugin()
-        optimizers_to_device(self.optimizers, self.root_device)
 
         if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
             self.model = self._layer_sync.apply(self.model)
@@ -148,6 +145,8 @@ class DDPFullyShardedStrategy(DDPStrategy):
         self.configure_ddp()
         self.barrier()
         self.setup_optimizers(trainer)
+        optimizers_to_device(self.optimizers, self.root_device)
+        self.setup_precision_plugin()
 
     @contextlib.contextmanager
     def model_sharded_context(self) -> Generator:
@@ -176,7 +175,7 @@ class DDPFullyShardedStrategy(DDPStrategy):
         log.detail(f"{self.__class__.__name__}: exiting model_sharded_context.")
 
     def configure_ddp(self) -> None:
-        log.detail(f"{self.__class__.__name__}: configuring DDP... (cpu_offload: [{self.cpu_offload}])")
+        log.detail(f"{self.__class__.__name__}: configuring FSDP... (cpu_offload: [{self.cpu_offload}])")
         if not self.cpu_offload:
             # When using CPU Offload, FSDP will manage the CUDA movement for us.
             # Note: this would be problematic for large model (which could not fit in one GPU)
@@ -184,9 +183,6 @@ class DDPFullyShardedStrategy(DDPStrategy):
             # (TODO: need to figure out solution)
             self.model_to_device()
 
-        # setup optimizers after fully sharded has wrapped the lightning module
-        self.setup_optimizers(self.lightning_module.trainer)
-
     def model_to_device(self) -> None:
         log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
         # ensure we update the device type in the lightning module
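
For readers skimming the hunks, this is roughly how `DDPFullyShardedStrategy.setup()` reads once the patch is applied. It is a sketch reconstructed from the diff above, not verbatim file contents; the surrounding class, imports, and the rest of the module are assumed unchanged.

    # Reconstructed from the hunks above: optimizer setup now happens exactly
    # once, after configure_ddp() has wrapped the module with FSDP.
    def setup(self, trainer: "pl.Trainer") -> None:
        self.accelerator.setup(trainer)

        if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
            self.model = self._layer_sync.apply(self.model)

        self.configure_ddp()
        self.barrier()
        # Create optimizers against the FSDP-wrapped (flattened) parameters,
        # move their state to the target device, then set up precision.
        self.setup_optimizers(trainer)
        optimizers_to_device(self.optimizers, self.root_device)
        self.setup_precision_plugin()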