Set up dataloaders in k-fold cross validation example (#18582)

This commit is contained in:
Adrian Wälchli 2023-09-18 08:45:20 -07:00 committed by GitHub
parent fb7a0b539a
commit 01cd89fd04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 2 deletions

View File

@ -97,7 +97,7 @@ def validate_dataloader(model, data_loader, fabric, hparams, fold, acc_metric):
if hparams.dry_run:
break
# all_gather is used to aggregated the value across processes
# all_gather is used to aggregate the value across processes
loss = fabric.all_gather(loss).sum() / len(data_loader.dataset)
# compute acc
@ -118,7 +118,7 @@ def run(hparams):
# Let rank 0 download the data first, then everyone will load MNIST
with fabric.rank_zero_first():
dataset = MNIST(DATASETS_PATH, train=True, transform=transform)
dataset = MNIST(DATASETS_PATH, train=True, download=True, transform=transform)
# Loop over different folds (shuffle = False by default so reproducible)
folds = hparams.folds
@ -147,6 +147,9 @@ def run(hparams):
train_loader = DataLoader(dataset, batch_size=batch_size, sampler=SubsetRandomSampler(train_ids))
val_loader = DataLoader(dataset, batch_size=batch_size, sampler=SubsetRandomSampler(val_ids))
# set up dataloaders to move data to the correct device
train_loader, val_loader = fabric.setup_dataloaders(train_loader, val_loader)
# get model and optimizer for the current fold
model, optimizer = models[fold], optimizers[fold]