diff --git a/docs/source-fabric/advanced/model_init.rst b/docs/source-fabric/advanced/model_init.rst
index e3098083ed..f1e5cf846b 100644
--- a/docs/source-fabric/advanced/model_init.rst
+++ b/docs/source-fabric/advanced/model_init.rst
@@ -75,6 +75,10 @@ When training sharded models with :doc:`FSDP ` or DeepSpeed
     model = fabric.setup(model)  # parameters get sharded and initialized at once
 
+    # Make sure to create the optimizer only after the model has been set up
+    optimizer = torch.optim.Adam(model.parameters())
+    optimizer = fabric.setup_optimizers(optimizer)
+
 .. note::
     Empty-init is experimental and the behavior may change in the future.
     For FSDP on PyTorch 2.1+, it is required that all user-defined modules that manage parameters implement a ``reset_parameters()`` method (all PyTorch built-in modules have this too).
 
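
For context, a minimal end-to-end sketch of the pattern this hunk documents, assuming the public ``lightning.fabric`` API; ``MyModel`` and the two-device FSDP configuration are illustrative placeholders, not part of the diff::

    import torch
    import torch.nn as nn
    from lightning.fabric import Fabric


    class MyModel(nn.Module):  # hypothetical stand-in for the user's model
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 2)  # built-in modules provide reset_parameters()

        def forward(self, x):
            return self.layer(x)


    # Illustrative config: two CUDA devices with the FSDP strategy
    fabric = Fabric(accelerator="cuda", devices=2, strategy="fsdp")
    fabric.launch()

    # Instantiate the model with empty weights; real initialization is
    # deferred until fabric.setup() materializes and shards the parameters
    with fabric.init_module(empty_init=True):
        model = MyModel()

    model = fabric.setup(model)  # parameters get sharded and initialized at once

    # Create the optimizer only after the model has been set up, so it
    # references the sharded parameters rather than the empty originals
    optimizer = torch.optim.Adam(model.parameters())
    optimizer = fabric.setup_optimizers(optimizer)

The ordering is the point of the docs change: ``fabric.setup()`` is what shards and materializes the parameters, so an optimizer constructed beforehand would hold references to the pre-setup tensors instead of the sharded ones.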