diff --git a/docs/source-app/levels/basic/hello_components/pl_multinode.py b/docs/source-app/levels/basic/hello_components/pl_multinode.py index 0ba033e0d8..44db267160 100644 --- a/docs/source-app/levels/basic/hello_components/pl_multinode.py +++ b/docs/source-app/levels/basic/hello_components/pl_multinode.py @@ -1,6 +1,6 @@ # app.py import lightning as L -from lightning.app.components import PyTorchLightningMultiNode +from lightning.app.components import LightningTrainerMultiNode from lightning.pytorch.demos.boring_classes import BoringModel @@ -12,7 +12,7 @@ class LightningTrainerDistributed(L.LightningWork): trainer.fit(model) # 8 GPU: (2 nodes of 4 x v100) -component = PyTorchLightningMultiNode( +component = LightningTrainerMultiNode( LightningTrainerDistributed, num_nodes=2, cloud_compute=L.CloudCompute("gpu-fast-multi"), # 4 x v100