fix imagenet example
This commit is contained in:
parent
d56750899f
commit
c10ca47ab8
|
@ -22,14 +22,10 @@ import torchvision.datasets as datasets
|
|||
import pytorch_lightning as pl
|
||||
|
||||
|
||||
# pull out resnet names from torchvision models
|
||||
MODEL_NAMES = sorted(
|
||||
name for name in models.__dict__
|
||||
if name.islower() and not name.startswith("__")
|
||||
and callable(models.__dict__[name])
|
||||
)
|
||||
NORMALIZE = transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225],
|
||||
if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
|
||||
)
|
||||
|
||||
|
||||
|
@ -133,6 +129,11 @@ class ImageNetLightningModel(pl.LightningModule):
|
|||
|
||||
@pl.data_loader
|
||||
def train_dataloader(self):
|
||||
normalize = transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225],
|
||||
)
|
||||
|
||||
train_dir = os.path.join(self.hparams.data, 'train')
|
||||
train_dataset = datasets.ImageFolder(
|
||||
train_dir,
|
||||
|
@ -140,7 +141,7 @@ class ImageNetLightningModel(pl.LightningModule):
|
|||
transforms.RandomResizedCrop(224),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
NORMALIZE,
|
||||
normalize,
|
||||
]))
|
||||
|
||||
if self.use_ddp:
|
||||
|
@ -159,13 +160,17 @@ class ImageNetLightningModel(pl.LightningModule):
|
|||
|
||||
@pl.data_loader
|
||||
def val_dataloader(self):
|
||||
normalize = transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225],
|
||||
)
|
||||
val_dir = os.path.join(self.hparams.data, 'val')
|
||||
val_loader = torch.utils.data.DataLoader(
|
||||
datasets.ImageFolder(val_dir, transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
NORMALIZE,
|
||||
normalize,
|
||||
])),
|
||||
batch_size=self.hparams.batch_size,
|
||||
shuffle=False,
|
||||
|
|
Loading…
Reference in New Issue