diff --git a/README.md b/README.md
index 6bc4f3ac2f..7d99eebec3 100644
--- a/README.md
+++ b/README.md
@@ -337,6 +337,10 @@ Fabric is designed for the most complex models like foundation model scaling, LL
 + import lightning as L
   import torch; import torchvision as tv
 
+  dataset = tv.datasets.CIFAR10("data", download=True,
+                                train=True,
+                                transform=tv.transforms.ToTensor())
+
 + fabric = L.Fabric()
 + fabric.launch()
 
@@ -346,9 +350,6 @@ Fabric is designed for the most complex models like foundation model scaling, LL
 - model.to(device)
 + model, optimizer = fabric.setup(model, optimizer)
 
-  dataset = tv.datasets.CIFAR10("data", download=True,
-                                train=True,
-                                transform=tv.transforms.ToTensor())
   dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
 + dataloader = fabric.setup_dataloaders(dataloader)
 
@@ -375,6 +376,10 @@ Fabric is designed for the most complex models like foundation model scaling, LL
 import lightning as L
 import torch; import torchvision as tv
 
+dataset = tv.datasets.CIFAR10("data", download=True,
+                              train=True,
+                              transform=tv.transforms.ToTensor())
+
 fabric = L.Fabric()
 fabric.launch()
 
@@ -382,9 +387,6 @@ model = tv.models.resnet18()
 optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
 model, optimizer = fabric.setup(model, optimizer)
 
-dataset = tv.datasets.CIFAR10("data", download=True,
-                              train=True,
-                              transform=tv.transforms.ToTensor())
 dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
 dataloader = fabric.setup_dataloaders(dataloader)
 
diff --git a/src/lightning_fabric/README.md b/src/lightning_fabric/README.md
index 7553ae87b4..78fff6beae 100644
--- a/src/lightning_fabric/README.md
+++ b/src/lightning_fabric/README.md
@@ -41,6 +41,10 @@ Fabric is designed for the most complex models like foundation model scaling, LL
 + import lightning as L
   import torch; import torchvision as tv
 
+  dataset = tv.datasets.CIFAR10("data", download=True,
+                                train=True,
+                                transform=tv.transforms.ToTensor())
+
 + fabric = L.Fabric()
 + fabric.launch()
 
@@ -50,9 +54,6 @@ Fabric is designed for the most complex models like foundation model scaling, LL
 - model.to(device)
 + model, optimizer = fabric.setup(model, optimizer)
 
-  dataset = tv.datasets.CIFAR10("data", download=True,
-                                train=True,
-                                transform=tv.transforms.ToTensor())
   dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
 + dataloader = fabric.setup_dataloaders(dataloader)
 
@@ -78,6 +79,10 @@ Fabric is designed for the most complex models like foundation model scaling, LL
 import lightning as L
 import torch; import torchvision as tv
 
+dataset = tv.datasets.CIFAR10("data", download=True,
+                              train=True,
+                              transform=tv.transforms.ToTensor())
+
 fabric = L.Fabric()
 fabric.launch()
 
@@ -85,9 +90,6 @@ model = tv.models.resnet18()
 optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
 model, optimizer = fabric.setup(model, optimizer)
 
-dataset = tv.datasets.CIFAR10("data", download=True,
-                              train=True,
-                              transform=tv.transforms.ToTensor())
 dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
 dataloader = fabric.setup_dataloaders(dataloader)