diff --git a/pl_examples/full_examples/imagenet/imagenet_example.py b/pl_examples/full_examples/imagenet/imagenet_example.py index c2175da807..ce2fbf6a12 100644 --- a/pl_examples/full_examples/imagenet/imagenet_example.py +++ b/pl_examples/full_examples/imagenet/imagenet_example.py @@ -34,9 +34,12 @@ class ImageNetLightningModel(pl.LightningModule): self.hparams = hparams self.model = models.__dict__[self.hparams.arch](pretrained=self.hparams.pretrained) + def forward(self, x): + return self.model(x) + def training_step(self, batch, batch_idx): images, target = batch - output = self.model(images) + output = self.forward(images) loss_val = F.cross_entropy(output, target) acc1, acc5 = self.__accuracy(output, target, topk=(1, 5)) @@ -59,7 +62,7 @@ class ImageNetLightningModel(pl.LightningModule): def validation_step(self, batch, batch_idx): images, target = batch - output = self.model(images) + output = self.forward(images) loss_val = F.cross_entropy(output, target) acc1, acc5 = self.__accuracy(output, target, topk=(1, 5)) @@ -132,7 +135,7 @@ class ImageNetLightningModel(pl.LightningModule): std=[0.229, 0.224, 0.225], ) - train_dir = os.path.join(self.hparams.data, 'train') + train_dir = os.path.join(self.hparams.data_path, 'train') train_dataset = datasets.ImageFolder( train_dir, transforms.Compose([ @@ -162,7 +165,7 @@ class ImageNetLightningModel(pl.LightningModule): mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], ) - val_dir = os.path.join(self.hparams.data, 'val') + val_dir = os.path.join(self.hparams.data_path, 'val') val_loader = torch.utils.data.DataLoader( datasets.ImageFolder(val_dir, transforms.Compose([ transforms.Resize(256), @@ -185,7 +188,7 @@ class ImageNetLightningModel(pl.LightningModule): ' (default: resnet18)') parser.add_argument('--epochs', default=90, type=int, metavar='N', help='number of total epochs to run') - parser.add_argument('--seed', type=int, default=None, + parser.add_argument('--seed', type=int, default=42, help='seed for initializing training. ') parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', @@ -214,7 +217,7 @@ def get_args(): help='how many gpus') parent_parser.add_argument('--distributed-backend', type=str, default='dp', choices=('dp', 'ddp', 'ddp2'), help='supports three options dp, ddp, ddp2') - parent_parser.add_argument('--use-16bit', dest='use-16bit', action='store_true', + parent_parser.add_argument('--use-16bit', dest='use_16bit', action='store_true', help='if true uses 16 bit precision') parent_parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')