* implement forward and update args (#709) Fixes the following issues as discussed in issue #709 1) Implement forward method wrapped. 2) Set default value for seed. "None" breaks tensorboard. 3) Update redundant hparams.data to new hparams.data_path. 4) Update 'use-16bit' to 'use_16bit' to maintain consistency. * Fix failing GPU tests (#722) * Fix distributed_backend=None test We now throw a warning instead of an exception. Update test to reflect this. * Fix test_tube logger close when debug=True * Clean docs (#725) * updated gitignore * updated gitignore * updated links in ninja file * updated docs * finished callbacks * finished callbacks * finished callbacks * fixed left menu * added callbacks to menu * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * updated gitignore * updated links in ninja file * updated docs * finished callbacks * finished callbacks * finished callbacks * fixed left menu * added callbacks to menu * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * finished rebase * making private members * making private members * making private members * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * set auto dp if no backend * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * fixed lightning import * cleared spaces * cleared spaces * cleared spaces * cleared spaces * cleared spaces * cleared spaces * cleared spaces * cleared spaces * cleared spaces * cleared spaces * finished lightning module * finished lightning module * finished lightning module * finished lightning module * added callbacks * added loggers * added loggers * added loggers * added loggers * added loggers * added loggers * added loggers * added loggers * set auto dp if no backend * added loggers * added loggers * added loggers * added loggers * added loggers * added loggers * flake 8 * flake 8 * fix docs path * updated gitignore * updated gitignore * updated links in ninja file * updated docs * finished callbacks * finished callbacks * finished callbacks * fixed left menu * added callbacks to menu * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * updated gitignore * updated docs * finished callbacks * finished callbacks * finished callbacks * fixed left menu * added callbacks to menu * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * added direct links to docs * finished rebase * making private members * making private members * making private members * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * set auto dp if no backend * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * working on trainer docs * fixed lightning import * cleared spaces * cleared spaces * cleared spaces * cleared spaces * cleared spaces * finished lightning module * finished lightning module * finished lightning module * finished lightning module * added callbacks * added loggers * added loggers * added loggers * added loggers * added loggers * added loggers * added loggers * added loggers * added loggers * added loggers * added loggers * flake 8 * flake 8 * fix docs path * flake 8 * Update theme_variables.jinja * implement forward and update args (#709) Fixes the following issues as discussed in issue #709 1) Implement forward method wrapped. 2) Set default value for seed. "None" breaks tensorboard. 3) Update redundant hparams.data to new hparams.data_path. 4) Update 'use-16bit' to 'use_16bit' to maintain consistency. * use self.forward for val step (#709) Co-authored-by: Nic Eggert <nic@eggert.io> Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
parent
f8d9f8f773
commit
eeb48ceb96
|
@ -34,9 +34,12 @@ class ImageNetLightningModel(pl.LightningModule):
|
|||
self.hparams = hparams
|
||||
self.model = models.__dict__[self.hparams.arch](pretrained=self.hparams.pretrained)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
images, target = batch
|
||||
output = self.model(images)
|
||||
output = self.forward(images)
|
||||
loss_val = F.cross_entropy(output, target)
|
||||
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
|
||||
|
||||
|
@ -59,7 +62,7 @@ class ImageNetLightningModel(pl.LightningModule):
|
|||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
images, target = batch
|
||||
output = self.model(images)
|
||||
output = self.forward(images)
|
||||
loss_val = F.cross_entropy(output, target)
|
||||
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
|
||||
|
||||
|
@ -132,7 +135,7 @@ class ImageNetLightningModel(pl.LightningModule):
|
|||
std=[0.229, 0.224, 0.225],
|
||||
)
|
||||
|
||||
train_dir = os.path.join(self.hparams.data, 'train')
|
||||
train_dir = os.path.join(self.hparams.data_path, 'train')
|
||||
train_dataset = datasets.ImageFolder(
|
||||
train_dir,
|
||||
transforms.Compose([
|
||||
|
@ -162,7 +165,7 @@ class ImageNetLightningModel(pl.LightningModule):
|
|||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225],
|
||||
)
|
||||
val_dir = os.path.join(self.hparams.data, 'val')
|
||||
val_dir = os.path.join(self.hparams.data_path, 'val')
|
||||
val_loader = torch.utils.data.DataLoader(
|
||||
datasets.ImageFolder(val_dir, transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
|
@ -185,7 +188,7 @@ class ImageNetLightningModel(pl.LightningModule):
|
|||
' (default: resnet18)')
|
||||
parser.add_argument('--epochs', default=90, type=int, metavar='N',
|
||||
help='number of total epochs to run')
|
||||
parser.add_argument('--seed', type=int, default=None,
|
||||
parser.add_argument('--seed', type=int, default=42,
|
||||
help='seed for initializing training. ')
|
||||
parser.add_argument('-b', '--batch-size', default=256, type=int,
|
||||
metavar='N',
|
||||
|
@ -214,7 +217,7 @@ def get_args():
|
|||
help='how many gpus')
|
||||
parent_parser.add_argument('--distributed-backend', type=str, default='dp', choices=('dp', 'ddp', 'ddp2'),
|
||||
help='supports three options dp, ddp, ddp2')
|
||||
parent_parser.add_argument('--use-16bit', dest='use-16bit', action='store_true',
|
||||
parent_parser.add_argument('--use-16bit', dest='use_16bit', action='store_true',
|
||||
help='if true uses 16 bit precision')
|
||||
parent_parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
|
||||
help='evaluate model on validation set')
|
||||
|
|
Loading…
Reference in New Issue