Reapply compile in `Fabric.setup()` by default (#19382)

awaelchli 2024-02-01 21:06:18 +01:00 committed by GitHub
parent af7e79a84a
commit 89ff87def0
2 changed files with 37 additions and 33 deletions


@@ -220,6 +220,7 @@ On PyTorch 2.2 and later, ``torch.compile`` will detect dynamism automatically a
Numbers produced with NVIDIA A100 SXM4 40GB, PyTorch 2.2.0, CUDA 12.1.
----
@@ -255,32 +256,6 @@ Naturally, the tradeoff here is that it will consume a bit more memory.
You can find a full list of compile options in the `PyTorch documentation <https://pytorch.org/docs/stable/generated/torch.compile.html>`_.
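For instance, here is a minimal sketch of selecting one of the built-in ``mode`` presets (assuming ``model`` is a ``torch.nn.Module`` you already have; the snippet is illustrative, not taken from the official docs):
.. code-block:: python

    import torch

    # Pick one of the presets, for example:
    model = torch.compile(model, mode="reduce-overhead")  # lower per-call overhead, a bit more memory
    # or
    model = torch.compile(model, mode="max-autotune")  # searches for faster kernels, longer compile time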
----
*******************************************************
(Experimental) Apply torch.compile over FSDP, DDP, etc.
*******************************************************
As stated earlier, we recommend that you compile the model before calling ``fabric.setup()``.
However, if you are using DDP or FSDP with Fabric, the compilation won't incorporate the distributed calls inside these wrappers by default.
As an experimental feature, you can let ``fabric.setup()`` reapply the ``torch.compile`` call after the model gets wrapped in DDP/FSDP internally.
In the future, this option will become the default.
.. code-block:: python

    # Choose a distributed strategy like DDP or FSDP
    fabric = L.Fabric(devices=2, strategy="ddp")

    # Compile the model
    model = torch.compile(model)

    # Default: `fabric.setup()` will not reapply the compilation over DDP/FSDP
    model = fabric.setup(model, _reapply_compile=False)

    # Recompile the model over DDP/FSDP (experimental)
    model = fabric.setup(model, _reapply_compile=True)
----
@@ -296,4 +271,32 @@ On top of that, the compilation phase itself can be incredibly slow, taking seve
For these reasons, we recommend that you don't waste too much time trying to apply ``torch.compile`` during development, and rather evaluate its effectiveness toward the end when you are about to launch long-running, expensive experiments.
Always compare the speed and memory usage of the compiled model against the original model!
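One way to run that comparison is a small benchmark along these lines (a sketch only; the ``benchmark`` helper and the CUDA-specific calls are illustrative assumptions, not part of Fabric):
.. code-block:: python

    import time

    import torch


    def benchmark(model, batch, iters=10):
        # Warm-up runs so that compilation and caching don't skew the timing
        for _ in range(3):
            model(batch)
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        start = time.perf_counter()
        for _ in range(iters):
            model(batch)
        torch.cuda.synchronize()
        avg_time = (time.perf_counter() - start) / iters
        peak_mb = torch.cuda.max_memory_allocated() / 1e6
        return avg_time, peak_mb


    eager_time, eager_mem = benchmark(model, batch)
    compiled_time, compiled_mem = benchmark(torch.compile(model), batch)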
----
*************************************
Using torch.compile with FSDP and DDP
*************************************
As stated earlier, we recommend that you compile the model before calling ``fabric.setup()``.
In the case of DDP and FSDP, ``fabric.setup()`` will automatically reapply the ``torch.compile`` call after the model gets wrapped in DDP/FSDP internally.
This will ensure that the compilation can incorporate the distributed calls and optimize them.
However, should you have issues compiling DDP and FSDP models, you can opt out of this feature:
.. code-block:: python

    # Choose a distributed strategy like DDP or FSDP
    fabric = L.Fabric(devices=2, strategy="ddp")

    # Compile the model
    model = torch.compile(model)

    # Default: `fabric.setup()` will configure compilation over DDP/FSDP for you
    model = fabric.setup(model, _reapply_compile=True)

    # Turn it off if you see issues with DDP/FSDP
    model = fabric.setup(model, _reapply_compile=False)
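For context, here is a minimal end-to-end sketch; the toy model, optimizer, and random batch below are placeholders for illustration, not an official example:
.. code-block:: python

    import lightning as L
    import torch

    fabric = L.Fabric(devices=2, strategy="ddp")
    fabric.launch()

    # Placeholder model and optimizer
    model = torch.nn.Linear(32, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Compile first, then set up; Fabric reapplies the compilation over the DDP wrapper
    model = torch.compile(model)
    model, optimizer = fabric.setup(model, optimizer)

    batch = torch.randn(8, 32, device=fabric.device)
    loss = model(batch).sum()
    fabric.backward(loss)
    optimizer.step()
    optimizer.zero_grad()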
|


@@ -214,7 +214,7 @@ class Fabric:
module: nn.Module,
*optimizers: Optimizer,
move_to_device: bool = True,
_reapply_compile: Optional[bool] = None,
_reapply_compile: bool = True,
) -> Any: # no specific return because the way we want our API to look does not play well with mypy
r"""Set up a model and its optimizers for accelerated training.
@@ -223,10 +223,11 @@
*optimizers: The optimizer(s) to set up (no optimizers is also possible)
move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
and alternatively use :meth:`to_device` manually.
_reapply_compile: (Experimental) If set to ``True``, and the model was ``torch.compile``d before, the
_reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
FSDP etc.). Only supported on PyTorch >= 2.1. Defaults to ``False``, but it may change in the future.
FSDP etc.). Only applies on PyTorch >= 2.1. Set it to ``False`` if compiling DDP/FSDP is causing
issues.
Returns:
The tuple containing wrapped module and the optimizers, in the same order they were passed in.
@@ -280,7 +281,7 @@
return module
def setup_module(
self, module: nn.Module, move_to_device: bool = True, _reapply_compile: Optional[bool] = None
self, module: nn.Module, move_to_device: bool = True, _reapply_compile: bool = True
) -> _FabricModule:
r"""Set up a model for accelerated training or inference.
@@ -292,11 +293,11 @@
module: A :class:`torch.nn.Module` to set up
move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
and alternatively use :meth:`to_device` manually.
_reapply_compile: (Experimental) If set to ``True``, and the model was ``torch.compile``d before, the
_reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
FSDP etc.). Only supported on PyTorch >= 2.1. Defaults to ``False``, but it may change in the future.
FSDP etc.). Only applies on PyTorch >= 2.1. Set it to ``False`` if compiling DDP/FSDP is causing
issues.
Returns:
The wrapped model.