Fix race condition when downloading data (#17732)

This commit is contained in:
Carlos Mocholí 2023-06-02 12:35:44 +02:00 committed by GitHub
parent 1f670a5cbd
commit 255b18823e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 12 deletions

View File

@ -337,6 +337,10 @@ Fabric is designed for the most complex models like foundation model scaling, LL
+ import lightning as L
import torch; import torchvision as tv
dataset = tv.datasets.CIFAR10("data", download=True,
train=True,
transform=tv.transforms.ToTensor())
+ fabric = L.Fabric()
+ fabric.launch()
@ -346,9 +350,6 @@ Fabric is designed for the most complex models like foundation model scaling, LL
- model.to(device)
+ model, optimizer = fabric.setup(model, optimizer)
dataset = tv.datasets.CIFAR10("data", download=True,
train=True,
transform=tv.transforms.ToTensor())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
+ dataloader = fabric.setup_dataloaders(dataloader)
@ -375,6 +376,10 @@ Fabric is designed for the most complex models like foundation model scaling, LL
import lightning as L
import torch; import torchvision as tv
dataset = tv.datasets.CIFAR10("data", download=True,
train=True,
transform=tv.transforms.ToTensor())
fabric = L.Fabric()
fabric.launch()
@ -382,9 +387,6 @@ model = tv.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
model, optimizer = fabric.setup(model, optimizer)
dataset = tv.datasets.CIFAR10("data", download=True,
train=True,
transform=tv.transforms.ToTensor())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
dataloader = fabric.setup_dataloaders(dataloader)

View File

@ -41,6 +41,10 @@ Fabric is designed for the most complex models like foundation model scaling, LL
+ import lightning as L
import torch; import torchvision as tv
dataset = tv.datasets.CIFAR10("data", download=True,
train=True,
transform=tv.transforms.ToTensor())
+ fabric = L.Fabric()
+ fabric.launch()
@ -50,9 +54,6 @@ Fabric is designed for the most complex models like foundation model scaling, LL
- model.to(device)
+ model, optimizer = fabric.setup(model, optimizer)
dataset = tv.datasets.CIFAR10("data", download=True,
train=True,
transform=tv.transforms.ToTensor())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
+ dataloader = fabric.setup_dataloaders(dataloader)
@ -78,6 +79,10 @@ Fabric is designed for the most complex models like foundation model scaling, LL
import lightning as L
import torch; import torchvision as tv
dataset = tv.datasets.CIFAR10("data", download=True,
train=True,
transform=tv.transforms.ToTensor())
fabric = L.Fabric()
fabric.launch()
@ -85,9 +90,6 @@ model = tv.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
model, optimizer = fabric.setup(model, optimizer)
dataset = tv.datasets.CIFAR10("data", download=True,
train=True,
transform=tv.transforms.ToTensor())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
dataloader = fabric.setup_dataloaders(dataloader)