diff --git a/pl_examples/full_examples/imagenet/imagenet_example.py b/pl_examples/full_examples/imagenet/imagenet_example.py index 3e45abd073..3dbe4a600a 100644 --- a/pl_examples/full_examples/imagenet/imagenet_example.py +++ b/pl_examples/full_examples/imagenet/imagenet_example.py @@ -22,14 +22,10 @@ import torchvision.datasets as datasets import pytorch_lightning as pl +# pull out resnet names from torchvision models MODEL_NAMES = sorted( name for name in models.__dict__ - if name.islower() and not name.startswith("__") - and callable(models.__dict__[name]) -) -NORMALIZE = transforms.Normalize( - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225], + if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) ) @@ -133,6 +129,11 @@ class ImageNetLightningModel(pl.LightningModule): @pl.data_loader def train_dataloader(self): + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ) + train_dir = os.path.join(self.hparams.data, 'train') train_dataset = datasets.ImageFolder( train_dir, @@ -140,7 +141,7 @@ class ImageNetLightningModel(pl.LightningModule): transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), - NORMALIZE, + normalize, ])) if self.use_ddp: @@ -159,13 +160,17 @@ class ImageNetLightningModel(pl.LightningModule): @pl.data_loader def val_dataloader(self): + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ) val_dir = os.path.join(self.hparams.data, 'val') val_loader = torch.utils.data.DataLoader( datasets.ImageFolder(val_dir, transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), - NORMALIZE, + normalize, ])), batch_size=self.hparams.batch_size, shuffle=False,