Update deepspeed model-parallel docs (#18091)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Adrian Wälchli 2023-07-17 18:02:54 +02:00 committed by GitHub
parent b8d4a70db7
commit d79eaae334
1 changed file with 1 addition and 3 deletions


@@ -187,12 +187,10 @@ Here's an example that uses ``wrap`` to create your model:

     class MyModel(pl.LightningModule):
-        def __init__(self):
-            super().__init__()
+        def configure_model(self):
             self.linear_layer = nn.Linear(32, 32)
             self.block = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))

-        def configure_model(self):
             # modules are sharded across processes
             # as soon as they are wrapped with `wrap`.
             # During the forward/backward passes, weights get synced across processes
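
For context, the assembled example after this change looks roughly like the sketch below. This is a minimal, illustrative reconstruction, not part of the commit: the forward, training_step, and configure_optimizers methods and the Trainer settings are assumptions added so the snippet is self-contained under a DeepSpeed ZeRO Stage 3 strategy.

    import torch
    import torch.nn as nn
    import pytorch_lightning as pl


    class MyModel(pl.LightningModule):
        def configure_model(self):
            # Layers are created here instead of in __init__ so the strategy
            # can shard them across processes as they are instantiated,
            # rather than materializing a full copy on every rank first.
            self.linear_layer = nn.Linear(32, 32)
            self.block = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))

        def forward(self, x):
            return self.block(self.linear_layer(x))

        def training_step(self, batch, batch_idx):
            # Toy objective for illustration only (assumption, not from the docs)
            return self(batch).sum()

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)


    # Illustrative Trainer configuration (assumption): ZeRO Stage 3 shards
    # the parameters created in configure_model across the two devices.
    trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="deepspeed_stage_3")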