diff --git a/docs/source-pytorch/advanced/model_parallel.rst b/docs/source-pytorch/advanced/model_parallel.rst
index 2408ca9376..04f8fc2d2a 100644
--- a/docs/source-pytorch/advanced/model_parallel.rst
+++ b/docs/source-pytorch/advanced/model_parallel.rst
@@ -187,12 +187,10 @@ Here's an example using that uses ``wrap`` to create your model:
     class MyModel(pl.LightningModule):
-        def __init__(self):
-            super().__init__()
+        def configure_model(self):
             self.linear_layer = nn.Linear(32, 32)
             self.block = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))
 
-        def configure_model(self):
             # modules are sharded across processes
             # as soon as they are wrapped with `wrap`.
             # During the forward/backward passes, weights get synced across processes
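
For context, here is a minimal sketch of how the updated pattern might be exercised end to end. It assumes Lightning 2.x-style imports (``lightning.pytorch``) and the ``FSDPStrategy`` class; the ``training_step``/``configure_optimizers`` bodies, the dummy data, and the ``Trainer`` arguments are illustrative additions, not part of the patched documentation:

    # Illustrative sketch only: imports, dummy data, and Trainer flags are
    # assumptions (lightning.pytorch 2.x naming), not taken from the patched docs.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import lightning.pytorch as pl
    from lightning.pytorch.strategies import FSDPStrategy


    class MyModel(pl.LightningModule):
        def configure_model(self):
            # Creating the layers here (instead of in __init__) lets the strategy
            # shard them as they are created, so no single process has to hold
            # the full unsharded model in memory.
            self.linear_layer = nn.Linear(32, 32)
            self.block = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))

        def training_step(self, batch, batch_idx):
            x, y = batch
            return F.mse_loss(self.block(self.linear_layer(x)), y)

        def configure_optimizers(self):
            return torch.optim.AdamW(self.parameters(), lr=1e-3)


    # Requires multiple GPUs; shown only to indicate how the strategy is selected.
    dataset = torch.utils.data.TensorDataset(torch.randn(64, 32), torch.randn(64, 32))
    loader = torch.utils.data.DataLoader(dataset, batch_size=8)
    trainer = pl.Trainer(accelerator="gpu", devices=2, strategy=FSDPStrategy())
    trainer.fit(MyModel(), loader)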