From 89ff87def0ea132799e1fb36851974d49101f573 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Thu, 1 Feb 2024 21:06:18 +0100
Subject: [PATCH] Reapply compile in `Fabric.setup()` by default (#19382)

---
 docs/source-fabric/advanced/compile.rst | 55 +++++++++++++------------
 src/lightning/fabric/fabric.py          | 15 +++----
 2 files changed, 37 insertions(+), 33 deletions(-)

diff --git a/docs/source-fabric/advanced/compile.rst b/docs/source-fabric/advanced/compile.rst
index 3e47991675..703e3c51ef 100644
--- a/docs/source-fabric/advanced/compile.rst
+++ b/docs/source-fabric/advanced/compile.rst
@@ -220,6 +220,7 @@ On PyTorch 2.2 and later, ``torch.compile`` will detect dynamism automatically a
 Numbers produced with NVIDIA A100 SXM4 40GB, PyTorch 2.2.0, CUDA 12.1.
 
+
 
 ----
 
 
@@ -255,32 +256,6 @@ Naturally, the tradeoff here is that it will consume a bit more memory.
 
 You can find a full list of compile options in the `PyTorch documentation `_.
 
-----
-
-
-*******************************************************
-(Experimental) Apply torch.compile over FSDP, DDP, etc.
-*******************************************************
-
-As stated earlier, we recommend that you compile the model before calling ``fabric.setup()``.
-However, if you are using DDP or FSDP with Fabric, the compilation won't incorporate the distributed calls inside these wrappers by default.
-In an experimental feature, you can let ``fabric.setup()`` reapply the ``torch.compile`` call after the model gets wrapped in DDP/FSDP internally.
-In the future, this option will become the default.
-
-.. code-block:: python
-
-    # Choose a distributed strategy like DDP or FSDP
-    fabric = L.Fabric(devices=2, strategy="ddp")
-
-    # Compile the model
-    model = torch.compile(model)
-
-    # Default: `fabric.setup()` will not reapply the compilation over DDP/FSDP
-    model = fabric.setup(model, _reapply_compile=False)
-
-    # Recompile the model over DDP/FSDP (experimental)
-    model = fabric.setup(model, _reapply_compile=True)
-
 ----
 
 
@@ -296,4 +271,32 @@ On top of that, the compilation phase itself can be incredibly slow, taking seve
 
 For these reasons, we recommend that you don't waste too much time trying to apply ``torch.compile`` during development, and rather evaluate its effectiveness toward the end when you are about to launch long-running, expensive experiments.
 
 Always compare the speed and memory usage of the compiled model against the original model!
+
+----
+
+
+*************************************
+Using torch.compile with FSDP and DDP
+*************************************
+
+As stated earlier, we recommend that you compile the model before calling ``fabric.setup()``.
+In the case of DDP and FSDP, ``fabric.setup()`` will automatically reapply the ``torch.compile`` call after the model gets wrapped in DDP/FSDP internally.
+This will ensure that the compilation can incorporate the distributed calls and optimize them.
+However, should you have issues compiling DDP and FSDP models, you can opt out of this feature:
+
+.. code-block:: python
+
+    # Choose a distributed strategy like DDP or FSDP
+    fabric = L.Fabric(devices=2, strategy="ddp")
+
+    # Compile the model
+    model = torch.compile(model)
+
+    # Default: `fabric.setup()` will configure compilation over DDP/FSDP for you
+    model = fabric.setup(model, _reapply_compile=True)
+
+    # Turn it off if you see issues with DDP/FSDP
+    model = fabric.setup(model, _reapply_compile=False)
+
+
diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index 3eb12f2afa..bc07e633a9 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -214,7 +214,7 @@ class Fabric:
         module: nn.Module,
         *optimizers: Optimizer,
         move_to_device: bool = True,
-        _reapply_compile: Optional[bool] = None,
+        _reapply_compile: bool = True,
     ) -> Any:  # no specific return because the way we want our API to look does not play well with mypy
         r"""Set up a model and its optimizers for accelerated training.
 
@@ -223,10 +223,11 @@ class Fabric:
             *optimizers: The optimizer(s) to set up (no optimizers is also possible)
             move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
                 and alternatively use :meth:`to_device` manually.
-            _reapply_compile: (Experimental) If set to ``True``, and the model was ``torch.compile``d before, the
+            _reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
                 corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
                 same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
-                FSDP etc.). Only supported on PyTorch >= 2.1. Defaults to ``False``, but it may change in the future.
+                FSDP etc.). Only applies on PyTorch >= 2.1. Set it to ``False`` if compiling DDP/FSDP is causing
+                issues.
 
         Returns:
             The tuple containing wrapped module and the optimizers, in the same order they were passed in.
@@ -280,7 +281,7 @@ class Fabric:
         return module
 
     def setup_module(
-        self, module: nn.Module, move_to_device: bool = True, _reapply_compile: Optional[bool] = None
+        self, module: nn.Module, move_to_device: bool = True, _reapply_compile: bool = True
     ) -> _FabricModule:
         r"""Set up a model for accelerated training or inference.
 
@@ -292,11 +293,11 @@ class Fabric:
             module: A :class:`torch.nn.Module` to set up
             move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
                 and alternatively use :meth:`to_device` manually.
-            _reapply_compile: (Experimental) If set to ``True``, and the model was ``torch.compile``d before, the
+            _reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
                 corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
                 same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
-                FSDP etc.). Only supported on PyTorch >= 2.1. Defaults to ``False``, but it may change in the future.
-
+                FSDP etc.). Only applies on PyTorch >= 2.1. Set it to ``False`` if compiling DDP/FSDP is causing
+                issues.
 
         Returns:
             The wrapped model.
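For context, the new default can be exercised end to end roughly as follows. This is a minimal sketch and not part of the patch: the toy model, tensor shapes, optimizer choice, and device settings are illustrative assumptions, and PyTorch >= 2.1 is assumed so that the compile reapplication takes effect.

.. code-block:: python

    import lightning as L
    import torch
    from torch import nn

    # Illustrative toy model; any nn.Module works here
    model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 2))

    fabric = L.Fabric(accelerator="auto", devices=2, strategy="ddp")
    fabric.launch()

    # Compile before setup(); with the new default (_reapply_compile=True),
    # setup() reapplies the compilation after wrapping the model in DDP
    model = torch.compile(model)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    model, optimizer = fabric.setup(model, optimizer)

    # One illustrative training step with random data
    batch = fabric.to_device(torch.randn(8, 32))
    loss = model(batch).sum()
    fabric.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

If compiling the DDP/FSDP-wrapped model causes problems, passing ``_reapply_compile=False`` to ``fabric.setup()`` opts out, as described in the documentation change above.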