fix imagenet example

This commit is contained in:
William Falcon 2019-11-09 07:15:07 -05:00
parent d56750899f
commit c10ca47ab8
1 changed files with 13 additions and 8 deletions

View File

@ -22,14 +22,10 @@ import torchvision.datasets as datasets
import pytorch_lightning as pl
# pull out resnet names from torchvision models
MODEL_NAMES = sorted(
name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name])
)
NORMALIZE = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
)
@ -133,6 +129,11 @@ class ImageNetLightningModel(pl.LightningModule):
@pl.data_loader
def train_dataloader(self):
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
)
train_dir = os.path.join(self.hparams.data, 'train')
train_dataset = datasets.ImageFolder(
train_dir,
@ -140,7 +141,7 @@ class ImageNetLightningModel(pl.LightningModule):
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
NORMALIZE,
normalize,
]))
if self.use_ddp:
@ -159,13 +160,17 @@ class ImageNetLightningModel(pl.LightningModule):
@pl.data_loader
def val_dataloader(self):
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
)
val_dir = os.path.join(self.hparams.data, 'val')
val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(val_dir, transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
NORMALIZE,
normalize,
])),
batch_size=self.hparams.batch_size,
shuffle=False,