implement forward and update args (#709) (#724)

* implement forward and update args (#709)

Fixes the following issues, as discussed in issue #709:

1) Implement a forward method that wraps the underlying model (sketched below).
2) Set a default value for the seed; a default of "None" breaks TensorBoard.
3) Replace the outdated hparams.data with the new hparams.data_path.
4) Change dest='use-16bit' to dest='use_16bit' for consistent attribute access.
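
For orientation, a condensed sketch of fix 1, distilled from the diff further
below (class and attribute names mirror the example's ImageNetLightningModel):

    import pytorch_lightning as pl
    import torchvision.models as models

    class ImageNetLightningModel(pl.LightningModule):
        def __init__(self, hparams):
            super().__init__()
            self.hparams = hparams
            self.model = models.__dict__[self.hparams.arch](pretrained=self.hparams.pretrained)

        def forward(self, x):
            # Delegate to the wrapped torchvision model so training_step
            # and validation_step can call self.forward(images) rather
            # than reaching into self.model directly.
            return self.model(x)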

* Fix failing GPU tests (#722)

* Fix distributed_backend=None test

We now raise a warning instead of an exception. Update the test
to reflect this.
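
A minimal sketch of the shape of that test change, assuming a pytest-style
suite; the helper and the warning message here are illustrative stand-ins, not
the actual Trainer internals:

    import warnings
    import pytest

    def configure_backend(distributed_backend):
        # Illustrative stand-in: the real code path now warns on an
        # unset backend instead of raising an exception.
        if distributed_backend is None:
            warnings.warn('distributed_backend is None, defaulting to dp')

    def test_distributed_backend_none_warns():
        # Previously this would have asserted with pytest.raises(...).
        with pytest.warns(UserWarning):
            configure_backend(None)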

* Fix test_tube logger close when debug=True

* Clean docs (#725)

* updated gitignore

* updated links in ninja file

* updated docs

* finished callbacks

* fixed left menu

* added callbacks to menu

* added direct links to docs

* finished rebase

* making private members

* working on trainer docs

* set auto dp if no backend

* fixed lightning import

* cleared spaces

* finished lightning module

* added callbacks

* added loggers

* flake 8

* fix docs path

* Update theme_variables.jinja

* implement forward and update args (#709)


* use self.forward for val step (#709)

Co-authored-by: Nic Eggert <nic@eggert.io>
Co-authored-by: William Falcon <waf2107@columbia.edu>
Harsh Sharma authored on 2020-01-21 14:35:42 -07:00; committed by William Falcon
parent f8d9f8f773
commit eeb48ceb96
1 changed file with 9 additions and 6 deletions

@@ -34,9 +34,12 @@ class ImageNetLightningModel(pl.LightningModule):
         self.hparams = hparams
         self.model = models.__dict__[self.hparams.arch](pretrained=self.hparams.pretrained)
 
+    def forward(self, x):
+        return self.model(x)
+
     def training_step(self, batch, batch_idx):
         images, target = batch
-        output = self.model(images)
+        output = self.forward(images)
         loss_val = F.cross_entropy(output, target)
         acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
@@ -59,7 +62,7 @@ class ImageNetLightningModel(pl.LightningModule):
     def validation_step(self, batch, batch_idx):
         images, target = batch
-        output = self.model(images)
+        output = self.forward(images)
         loss_val = F.cross_entropy(output, target)
         acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
@@ -132,7 +135,7 @@ class ImageNetLightningModel(pl.LightningModule):
             std=[0.229, 0.224, 0.225],
         )
-        train_dir = os.path.join(self.hparams.data, 'train')
+        train_dir = os.path.join(self.hparams.data_path, 'train')
         train_dataset = datasets.ImageFolder(
             train_dir,
             transforms.Compose([
@@ -162,7 +165,7 @@ class ImageNetLightningModel(pl.LightningModule):
             mean=[0.485, 0.456, 0.406],
             std=[0.229, 0.224, 0.225],
         )
-        val_dir = os.path.join(self.hparams.data, 'val')
+        val_dir = os.path.join(self.hparams.data_path, 'val')
         val_loader = torch.utils.data.DataLoader(
             datasets.ImageFolder(val_dir, transforms.Compose([
                 transforms.Resize(256),
@@ -185,7 +188,7 @@ class ImageNetLightningModel(pl.LightningModule):
                             ' (default: resnet18)')
         parser.add_argument('--epochs', default=90, type=int, metavar='N',
                             help='number of total epochs to run')
-        parser.add_argument('--seed', type=int, default=None,
+        parser.add_argument('--seed', type=int, default=42,
                             help='seed for initializing training. ')
         parser.add_argument('-b', '--batch-size', default=256, type=int,
                             metavar='N',
@@ -214,7 +217,7 @@ def get_args():
                                help='how many gpus')
     parent_parser.add_argument('--distributed-backend', type=str, default='dp', choices=('dp', 'ddp', 'ddp2'),
                                help='supports three options dp, ddp, ddp2')
-    parent_parser.add_argument('--use-16bit', dest='use-16bit', action='store_true',
+    parent_parser.add_argument('--use-16bit', dest='use_16bit', action='store_true',
                                help='if true uses 16 bit precision')
     parent_parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                                help='evaluate model on validation set')
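
A note on why fix 4 matters: argparse stores each option's value under its
dest name, and dest='use-16bit' is not a valid Python identifier, so
attribute-style access like hparams.use_16bit fails (the value would only be
reachable via getattr(args, 'use-16bit')). A quick self-contained illustration:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--use-16bit', dest='use_16bit', action='store_true')
    args = parser.parse_args(['--use-16bit'])
    # With dest='use_16bit' the flag is a normal attribute:
    print(args.use_16bit)  # True
    # With the old dest='use-16bit', the value was only reachable via
    # getattr(args, 'use-16bit'), which downstream code never does.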