Reduce number of times optimizers are instantiated with FSDP (#12267)

ananthsub 2022-03-21 10:18:59 -07:00 committed by GitHub
parent fa7aa0babe
commit d99625fc8d
2 changed files with 7 additions and 8 deletions

CHANGELOG.md

@@ -812,7 +812,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed to avoid common hook warning if no hook is overridden ([#12131](https://github.com/PyTorchLightning/pytorch-lightning/pull/12131))
-- Fixed the case where logger=None is passed to the Trainer ([#12249](https://github.com/PyTorchLightning/pytorch-lightning/pull/12249))
+- Fixed the case where `logger=None` is passed to the Trainer ([#12249](https://github.com/PyTorchLightning/pytorch-lightning/pull/12249))
+- Fixed initializing optimizers unnecessarily in `DDPFullyShardedStrategy` ([#12267](https://github.com/PyTorchLightning/pytorch-lightning/pull/12267))

 ## [1.5.10] - 2022-02-08

pytorch_lightning/strategies/fully_sharded.py

@@ -138,9 +138,6 @@ class DDPFullyShardedStrategy(DDPStrategy):

     def setup(self, trainer: "pl.Trainer") -> None:
         self.accelerator.setup(trainer)
-        self.setup_optimizers(trainer)
-        self.setup_precision_plugin()
-        optimizers_to_device(self.optimizers, self.root_device)

         if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
             self.model = self._layer_sync.apply(self.model)
@@ -148,6 +145,8 @@ class DDPFullyShardedStrategy(DDPStrategy):
         self.configure_ddp()
         self.barrier()
         self.setup_optimizers(trainer)
+        optimizers_to_device(self.optimizers, self.root_device)
+        self.setup_precision_plugin()

     @contextlib.contextmanager
     def model_sharded_context(self) -> Generator:
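
The net effect of the two hunks above: `setup_optimizers` now runs exactly once, after `configure_ddp()` has wrapped the LightningModule in FSDP, so the optimizers bind to the wrapped module's parameters, with `optimizers_to_device` and `setup_precision_plugin` following it. Below is a minimal, self-contained sketch of that ordering in plain PyTorch (no Lightning or FairScale required; `ShardStub` is a hypothetical stand-in for the FSDP wrapper, not anything in the diff):

```python
# Sketch: build the optimizer once, after the model has been wrapped.
# A real FSDP wrapper flattens/shards parameters, so an optimizer created
# before wrapping would be wasted work that has to be redone afterwards.
import torch
from torch import nn


class ShardStub(nn.Module):
    """Toy wrapper that re-owns the module, mimicking how FSDP re-registers parameters."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.module(x)


model = nn.Linear(4, 2)
wrapped = ShardStub(model)                                 # configure_ddp() analogue
optimizer = torch.optim.SGD(wrapped.parameters(), lr=0.1)  # setup_optimizers() analogue

loss = wrapped(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()
```

Creating the optimizer before `ShardStub(model)` would bind it to parameter objects the wrapper may replace, which is the duplicated instantiation this commit removes.
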
@@ -176,7 +175,7 @@ class DDPFullyShardedStrategy(DDPStrategy):
         log.detail(f"{self.__class__.__name__}: exiting model_sharded_context.")

     def configure_ddp(self) -> None:
-        log.detail(f"{self.__class__.__name__}: configuring DDP... (cpu_offload: [{self.cpu_offload}])")
+        log.detail(f"{self.__class__.__name__}: configuring FSDP... (cpu_offload: [{self.cpu_offload}])")
         if not self.cpu_offload:
             # When using CPU Offload, FSDP will manage the CUDA movement for us.
             # Note: this would be problematic for large model (which could not fit in one GPU)
@@ -184,9 +183,6 @@ class DDPFullyShardedStrategy(DDPStrategy):
             # (TODO: need to figure out solution)
             self.model_to_device()

-        # setup optimizers after fully sharded has wrapped the lightning module
-        self.setup_optimizers(self.lightning_module.trainer)
-
     def model_to_device(self) -> None:
         log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
         # ensure we update the device type in the lightning module
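
The comments retained in `configure_ddp()` describe a placement rule: move the module to the accelerator eagerly only when CPU offload is disabled, since with offload enabled FSDP manages host-to-device movement itself. A small sketch of that rule in plain PyTorch follows (`place_model` is a hypothetical helper for illustration, not part of the Lightning API):

```python
# Hypothetical helper illustrating the placement rule from the comments above:
# skip the eager .to(device) when CPU offload is enabled, because the sharded
# wrapper moves parameters on demand during forward/backward.
import torch
from torch import nn


def place_model(module: nn.Module, cpu_offload: bool, device: torch.device) -> nn.Module:
    if not cpu_offload:
        # No offload: parameters live on the accelerator for the whole run,
        # which assumes the full (unsharded) model fits on a single device.
        return module.to(device)
    # Offload enabled: leave parameters on the CPU and let the wrapper
    # handle host<->device movement itself.
    return module


model = place_model(nn.Linear(4, 2), cpu_offload=True, device=torch.device("cpu"))
```
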