Set up dataloaders in k-fold cross validation example (#18582)
This commit is contained in:
parent
fb7a0b539a
commit
01cd89fd04
|
@ -97,7 +97,7 @@ def validate_dataloader(model, data_loader, fabric, hparams, fold, acc_metric):
|
|||
if hparams.dry_run:
|
||||
break
|
||||
|
||||
# all_gather is used to aggregated the value across processes
|
||||
# all_gather is used to aggregate the value across processes
|
||||
loss = fabric.all_gather(loss).sum() / len(data_loader.dataset)
|
||||
|
||||
# compute acc
|
||||
|
@ -118,7 +118,7 @@ def run(hparams):
|
|||
|
||||
# Let rank 0 download the data first, then everyone will load MNIST
|
||||
with fabric.rank_zero_first():
|
||||
dataset = MNIST(DATASETS_PATH, train=True, transform=transform)
|
||||
dataset = MNIST(DATASETS_PATH, train=True, download=True, transform=transform)
|
||||
|
||||
# Loop over different folds (shuffle = False by default so reproducible)
|
||||
folds = hparams.folds
|
||||
|
@ -147,6 +147,9 @@ def run(hparams):
|
|||
train_loader = DataLoader(dataset, batch_size=batch_size, sampler=SubsetRandomSampler(train_ids))
|
||||
val_loader = DataLoader(dataset, batch_size=batch_size, sampler=SubsetRandomSampler(val_ids))
|
||||
|
||||
# set up dataloaders to move data to the correct device
|
||||
train_loader, val_loader = fabric.setup_dataloaders(train_loader, val_loader)
|
||||
|
||||
# get model and optimizer for the current fold
|
||||
model, optimizer = models[fold], optimizers[fold]
|
||||
|
||||
|
|
Loading…
Reference in New Issue