lightning/pl_examples/domain_templates/imagenet.py

"""
This example is largely adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py
"""
import argparse
import os
import random
from collections import OrderedDict

import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.nn.parallel
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

import pytorch_lightning as pl
from pytorch_lightning.core import LightningModule

# pull out the lowercase, callable model names (e.g. resnet18) from torchvision models
MODEL_NAMES = sorted(
    name for name in models.__dict__
    if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
)


class ImageNetLightningModel(LightningModule):
    def __init__(self, hparams):
        """
        Args:
            hparams: parsed command-line arguments (see ``add_model_specific_args``
                and ``get_args``), e.g. ``arch``, ``pretrained``, ``lr``, ``batch_size``.
        """
        super().__init__()
        self.hparams = hparams
        self.model = models.__dict__[self.hparams.arch](pretrained=self.hparams.pretrained)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, target = batch
        output = self(images)
        loss_val = F.cross_entropy(output, target)
        acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))

        tqdm_dict = {'train_loss': loss_val}
        output = OrderedDict({
            'loss': loss_val,
            'acc1': acc1,
            'acc5': acc5,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict
        })
        return output

    def validation_step(self, batch, batch_idx):
        images, target = batch
        output = self(images)
        loss_val = F.cross_entropy(output, target)
        acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))

        output = OrderedDict({
            'val_loss': loss_val,
            'val_acc1': acc1,
            'val_acc5': acc5,
        })
        return output

    def validation_epoch_end(self, outputs):
        tqdm_dict = {}
        for metric_name in ["val_loss", "val_acc1", "val_acc5"]:
            metric_total = 0

            for output in outputs:
                metric_value = output[metric_name]

                # reduce manually when using dp
                if self.trainer.use_dp or self.trainer.use_ddp2:
                    metric_value = torch.mean(metric_value)

                metric_total += metric_value

            tqdm_dict[metric_name] = metric_total / len(outputs)

        result = {'progress_bar': tqdm_dict, 'log': tqdm_dict, 'val_loss': tqdm_dict["val_loss"]}
        return result

    @classmethod
    def __accuracy(cls, output, target, topk=(1,)):
        """Computes the accuracy over the k top predictions for the specified values of k"""
        with torch.no_grad():
            maxk = max(topk)
            batch_size = target.size(0)

            # indices of the maxk highest-scoring classes, shape (batch_size, maxk)
            _, pred = output.topk(maxk, 1, True, True)
            # transpose so that row i holds every sample's (i + 1)-th best guess
            pred = pred.t()
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = []
            for k in topk:
                # reshape instead of view: the sliced tensor may be non-contiguous
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
                res.append(correct_k.mul_(100.0 / batch_size))
            return res

    def configure_optimizers(self):
        optimizer = optim.SGD(
            self.parameters(),
            lr=self.hparams.lr,
            momentum=self.hparams.momentum,
            weight_decay=self.hparams.weight_decay
        )
        # NOTE: with gamma=0.1 the learning rate is divided by 10 after every epoch
        scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.1)
        return [optimizer], [scheduler]

    def train_dataloader(self):
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

        train_dir = os.path.join(self.hparams.data_path, 'train')
        train_dataset = datasets.ImageFolder(
            train_dir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        if self.use_ddp:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        else:
            train_sampler = None

        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=(train_sampler is None),
            num_workers=0,
            sampler=train_sampler
        )
        return train_loader

    def val_dataloader(self):
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
        val_dir = os.path.join(self.hparams.data_path, 'val')
        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(val_dir, transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=0,
        )
        return val_loader

    @staticmethod
    def add_model_specific_args(parent_parser):  # pragma: no-cover
        parser = argparse.ArgumentParser(parents=[parent_parser])
        parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=MODEL_NAMES,
                            help='model architecture: ' +
                                 ' | '.join(MODEL_NAMES) +
                                 ' (default: resnet18)')
        parser.add_argument('--epochs', default=90, type=int, metavar='N',
                            help='number of total epochs to run')
        parser.add_argument('--seed', type=int, default=42,
                            help='seed for initializing training (default: 42)')
        parser.add_argument('-b', '--batch-size', default=256, type=int,
                            metavar='N',
                            help='mini-batch size (default: 256), this is the total '
                                 'batch size of all GPUs on the current node when '
                                 'using Data Parallel or Distributed Data Parallel')
        parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                            metavar='LR', help='initial learning rate', dest='lr')
        parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                            help='momentum')
        parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                            metavar='W', help='weight decay (default: 1e-4)',
                            dest='weight_decay')
        parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                            help='use pre-trained model')
        return parser


def get_args():
    parent_parser = argparse.ArgumentParser(add_help=False)
    parent_parser.add_argument('--data-path', metavar='DIR', type=str,
                               help='path to dataset')
    parent_parser.add_argument('--save-path', metavar='DIR', default=".", type=str,
                               help='path to save output')
    parent_parser.add_argument('--gpus', type=int, default=1,
                               help='how many gpus')
    parent_parser.add_argument('--distributed-backend', type=str, default='dp', choices=('dp', 'ddp', 'ddp2'),
                               help='supports three options dp, ddp, ddp2')
    parent_parser.add_argument('--use-16bit', dest='use_16bit', action='store_true',
                               help='if true uses 16 bit precision')
    parent_parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                               help='evaluate model on validation set')

    parser = ImageNetLightningModel.add_model_specific_args(parent_parser)
    return parser.parse_args()


def main(hparams):
    # seed before building the model so that weight initialization is reproducible
    if hparams.seed is not None:
        random.seed(hparams.seed)
        torch.manual_seed(hparams.seed)
        cudnn.deterministic = True
    model = ImageNetLightningModel(hparams)
    trainer = pl.Trainer(
        default_root_dir=hparams.save_path,
        gpus=hparams.gpus,
        max_epochs=hparams.epochs,
        distributed_backend=hparams.distributed_backend,
        precision=16 if hparams.use_16bit else 32,
    )
    if hparams.evaluate:
        trainer.run_evaluation()
    else:
        trainer.fit(model)


if __name__ == '__main__':
    main(get_args())
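
The example splits its command-line options between a program-level parent parser (`get_args`) and model-level options (`add_model_specific_args`). The following is a minimal, standalone sketch of that `argparse` parent-parser pattern using only the standard library. The helper name `build_parser` and the path `/tmp/imagenet` are hypothetical and used only for illustration; only a subset of the example's options is reproduced.

```python
import argparse


def build_parser():
    # Parent parser: shared, program-level options. add_help=False avoids a
    # duplicate -h/--help option when the child parser inherits these arguments.
    parent = argparse.ArgumentParser(add_help=False)
    parent.add_argument('--data-path', metavar='DIR', type=str, help='path to dataset')
    parent.add_argument('--gpus', type=int, default=1, help='how many gpus')

    # Child parser: inherits the parent's options and adds model-level ones.
    parser = argparse.ArgumentParser(parents=[parent])
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, dest='lr')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, dest='weight_decay')
    parser.add_argument('--pretrained', action='store_true')
    return parser


# Parse an explicit list instead of sys.argv so the sketch runs anywhere.
# '/tmp/imagenet' is a made-up path.
hparams = build_parser().parse_args(['--data-path', '/tmp/imagenet', '--learning-rate', '0.01'])
print(hparams.data_path, hparams.gpus, hparams.lr, hparams.weight_decay, hparams.pretrained)
```

Dashes in option names become underscores in the parsed namespace (`--data-path` is read back as `hparams.data_path`), and `dest` lets several aliases such as `--lr` and `--learning-rate` write to the same attribute, which is how the model reads `self.hparams.lr` and `self.hparams.weight_decay`.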