Reapply compile in `Fabric.setup()` by default (#19382)
This commit is contained in:
parent af7e79a84a
commit 89ff87def0

@@ -220,6 +220,7 @@ On PyTorch 2.2 and later, ``torch.compile`` will detect dynamism automatically a

Numbers produced with NVIDIA A100 SXM4 40GB, PyTorch 2.2.0, CUDA 12.1.

----

@@ -255,32 +256,6 @@ Naturally, the tradeoff here is that it will consume a bit more memory.

You can find a full list of compile options in the `PyTorch documentation <https://pytorch.org/docs/stable/generated/torch.compile.html>`_.
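
For instance, a couple of commonly used options look like this (a minimal sketch; ``mode`` and ``fullgraph`` are standard ``torch.compile`` arguments, and the right choice depends on your model):

.. code-block:: python

    # Trade a bit of extra memory for lower per-step overhead, and raise an
    # error instead of silently falling back to eager mode if the model
    # cannot be captured in a single graph
    model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
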
----

*******************************************************
(Experimental) Apply torch.compile over FSDP, DDP, etc.
*******************************************************

As stated earlier, we recommend that you compile the model before calling ``fabric.setup()``.
However, if you are using DDP or FSDP with Fabric, the compilation won't incorporate the distributed calls inside these wrappers by default.
In an experimental feature, you can let ``fabric.setup()`` reapply the ``torch.compile`` call after the model gets wrapped in DDP/FSDP internally.
In the future, this option will become the default.

.. code-block:: python

    # Choose a distributed strategy like DDP or FSDP
    fabric = L.Fabric(devices=2, strategy="ddp")

    # Compile the model
    model = torch.compile(model)

    # Default: `fabric.setup()` will not reapply the compilation over DDP/FSDP
    model = fabric.setup(model, _reapply_compile=False)

    # Recompile the model over DDP/FSDP (experimental)
    model = fabric.setup(model, _reapply_compile=True)

----

@@ -296,4 +271,32 @@ On top of that, the compilation phase itself can be incredibly slow, taking seve

For these reasons, we recommend that you don't waste too much time trying to apply ``torch.compile`` during development, and rather evaluate its effectiveness toward the end when you are about to launch long-running, expensive experiments.
Always compare the speed and memory usage of the compiled model against the original model!
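
As a rough illustration, such a comparison could look like the sketch below (``model`` and ``batch`` are placeholders for your own module and input, a CUDA device is assumed, and a real benchmark should measure your full training step rather than a bare forward pass):

.. code-block:: python

    import time

    import torch


    def benchmark(m, batch, iters=50):
        # Warm-up so that compilation and CUDA initialization are not measured
        for _ in range(5):
            m(batch)
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        start = time.perf_counter()
        for _ in range(iters):
            m(batch)
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / iters, torch.cuda.max_memory_allocated()


    eager_time, eager_mem = benchmark(model, batch)
    compiled_time, compiled_mem = benchmark(torch.compile(model), batch)
    print(f"eager:    {eager_time:.4f} s/iter, {eager_mem / 1e9:.2f} GB peak")
    print(f"compiled: {compiled_time:.4f} s/iter, {compiled_mem / 1e9:.2f} GB peak")
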
----

*************************************
Using torch.compile with FSDP and DDP
*************************************

As stated earlier, we recommend that you compile the model before calling ``fabric.setup()``.
In the case of DDP and FSDP, ``fabric.setup()`` will automatically reapply the ``torch.compile`` call after the model gets wrapped in DDP/FSDP internally.
This will ensure that the compilation can incorporate the distributed calls and optimize them.
However, should you have issues compiling DDP and FSDP models, you can opt out of this feature:

.. code-block:: python

    # Choose a distributed strategy like DDP or FSDP
    fabric = L.Fabric(devices=2, strategy="ddp")

    # Compile the model
    model = torch.compile(model)

    # Default: `fabric.setup()` will configure compilation over DDP/FSDP for you
    model = fabric.setup(model, _reapply_compile=True)

    # Turn it off if you see issues with DDP/FSDP
    model = fabric.setup(model, _reapply_compile=False)

@@ -214,7 +214,7 @@ class Fabric:
        module: nn.Module,
        *optimizers: Optimizer,
        move_to_device: bool = True,
        _reapply_compile: Optional[bool] = None,
        _reapply_compile: bool = True,
    ) -> Any:  # no specific return because the way we want our API to look does not play well with mypy
        r"""Set up a model and its optimizers for accelerated training.

@@ -223,10 +223,11 @@ class Fabric:
            *optimizers: The optimizer(s) to set up (no optimizers is also possible)
            move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
                and alternatively use :meth:`to_device` manually.
            _reapply_compile: (Experimental) If set to ``True``, and the model was ``torch.compile``d before, the
            _reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
                corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
                same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
                FSDP etc.). Only supported on PyTorch >= 2.1. Defaults to ``False``, but it may change in the future.
                FSDP etc.). Only applies on PyTorch >= 2.1. Set it to ``False`` if compiling DDP/FSDP is causing
                issues.

        Returns:
            The tuple containing wrapped module and the optimizers, in the same order they were passed in.

@@ -280,7 +281,7 @@ class Fabric:
        return module

    def setup_module(
        self, module: nn.Module, move_to_device: bool = True, _reapply_compile: Optional[bool] = None
        self, module: nn.Module, move_to_device: bool = True, _reapply_compile: bool = True
    ) -> _FabricModule:
        r"""Set up a model for accelerated training or inference.

@@ -292,11 +293,11 @@ class Fabric:
            module: A :class:`torch.nn.Module` to set up
            move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
                and alternatively use :meth:`to_device` manually.
            _reapply_compile: (Experimental) If set to ``True``, and the model was ``torch.compile``d before, the
            _reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
                corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
                same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
                FSDP etc.). Only supported on PyTorch >= 2.1. Defaults to ``False``, but it may change in the future.

                FSDP etc.). Only applies on PyTorch >= 2.1. Set it to ``False`` if compiling DDP/FSDP is causing
                issues.
        Returns:
            The wrapped model.
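
For illustration only, the reapply step described in the docstrings above amounts to roughly the following sketch (hypothetical and simplified, not Fabric's actual implementation; it assumes a compiled model and an already initialized process group, and it omits carrying over the original compile settings):

.. code-block:: python

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP

    compiled = torch.compile(model)  # returns a torch._dynamo.OptimizedModule
    original = compiled._orig_mod    # unwrap the underlying nn.Module
    wrapped = DDP(original)          # the strategy (here plain DDP) wraps the raw module
    model = torch.compile(wrapped)   # recompile on top of the wrapper so the distributed calls are included in what gets optimized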