lightning/examples/pl_domain_templates/imagenet.py

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This example is largely adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py.
Before you can run this example, you will need to download the ImageNet dataset manually from the
`official website <http://image-net.org/download>`_ and place it into a folder `path/to/imagenet`.
Train on ImageNet with default parameters:
.. code-block: bash
python imagenet.py fit --model.data_path /path/to/imagenet
or show all options you can change:
.. code-block: bash
python imagenet.py --help
python imagenet.py fit --help
"""
import os
from typing import Optional
import torch
import torch.nn.functional as F
import torch.nn.parallel
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torchmetrics import Accuracy
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.strategies import ParallelStrategy
from pytorch_lightning.utilities.model_helpers import get_torchvision_model


class ImageNetLightningModel(LightningModule):
"""
>>> ImageNetLightningModel(data_path='missing') # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
ImageNetLightningModel(
(model): ResNet(...)
)
"""
def __init__(
self,
data_path: str,
arch: str = "resnet18",
weights: Optional[str] = None,
lr: float = 0.1,
momentum: float = 0.9,
weight_decay: float = 1e-4,
batch_size: int = 256,
workers: int = 4,
):
super().__init__()
self.arch = arch
self.weights = weights
self.lr = lr
self.momentum = momentum
self.weight_decay = weight_decay
self.data_path = data_path
self.batch_size = batch_size
self.workers = workers
self.model = get_torchvision_model(self.arch, weights=self.weights)
self.train_dataset: Optional[Dataset] = None
self.eval_dataset: Optional[Dataset] = None
self.train_acc1 = Accuracy(top_k=1)
self.train_acc5 = Accuracy(top_k=5)
self.eval_acc1 = Accuracy(top_k=1)
self.eval_acc5 = Accuracy(top_k=5)
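        # torchmetrics metrics are nn.Modules, so assigning them as attributes moves them
        # to the right device together with the model; logging the metric object itself
        # (as the step methods below do) lets Lightning handle accumulation,
        # cross-process reduction, and the per-epoch reset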

    def forward(self, x):
return self.model(x)

    def training_step(self, batch, batch_idx):
images, target = batch
output = self.model(images)
loss_train = F.cross_entropy(output, target)
self.log("train_loss", loss_train)
# update metrics
self.train_acc1(output, target)
self.train_acc5(output, target)
self.log("train_acc1", self.train_acc1, prog_bar=True)
self.log("train_acc5", self.train_acc5, prog_bar=True)
return loss_train
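
    # validation and test share the implementation below; ``prefix`` only changes the
    # names under which the loss and accuracies are logged ("val_*" vs. "test_*")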
def eval_step(self, batch, batch_idx, prefix: str):
images, target = batch
output = self.model(images)
loss_val = F.cross_entropy(output, target)
self.log(f"{prefix}_loss", loss_val)
# update metrics
self.eval_acc1(output, target)
self.eval_acc5(output, target)
self.log(f"{prefix}_acc1", self.eval_acc1, prog_bar=True)
self.log(f"{prefix}_acc5", self.eval_acc5, prog_bar=True)
return loss_val

    def validation_step(self, batch, batch_idx):
return self.eval_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
return self.eval_step(batch, batch_idx, "test")

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay)
        # step decay: multiply the base LR by 0.1 every 30 epochs, as in the original
        # PyTorch ImageNet recipe (epochs 0-29 -> lr, 30-59 -> 0.1 * lr, ...)
        scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.1 ** (epoch // 30))
        return [optimizer], [scheduler]

    def setup(self, stage: str):
        if isinstance(self.trainer.strategy, ParallelStrategy):
            # When using a single GPU per process with `DistributedDataParallel`, we need
            # to divide the batch size and the number of workers ourselves by the number
            # of processes
            num_processes = max(1, self.trainer.strategy.num_processes)
            self.batch_size = int(self.batch_size / num_processes)
            self.workers = int(self.workers / num_processes)
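            # e.g. with the default batch_size=256 on 4 GPUs, each process loads batches
            # of 64, keeping the effective per-node batch size at 256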
if stage == "fit":
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_dir = os.path.join(self.data_path, "train")
self.train_dataset = datasets.ImageFolder(
train_dir,
transforms.Compose(
[
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]
),
)
# all stages will use the eval dataset
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
val_dir = os.path.join(self.data_path, "val")
self.eval_dataset = datasets.ImageFolder(
val_dir,
transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize]),
)
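
        # both `ImageFolder` datasets assume the standard `root/<class>/<image>` layout,
        # i.e. ImageNet extracted into `train/` and `val/` class subdirectories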

    def train_dataloader(self):
return torch.utils.data.DataLoader(
dataset=self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.workers,
pin_memory=True,
)

    def val_dataloader(self):
return torch.utils.data.DataLoader(
self.eval_dataset, batch_size=self.batch_size, num_workers=self.workers, pin_memory=True
)

    def test_dataloader(self):
return self.val_dataloader()
if __name__ == "__main__":
LightningCLI(
ImageNetLightningModel,
trainer_defaults={
"max_epochs": 90,
"accelerator": "auto",
"devices": 1,
"logger": False,
"benchmark": True,
"callbacks": [
# the PyTorch example refreshes every 10 batches
TQDMProgressBar(refresh_rate=10),
# save when the validation top1 accuracy improves
ModelCheckpoint(monitor="val_acc1", mode="max"),
],
},
seed_everything_default=42,
save_config_overwrite=True,
)
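
# A further usage sketch (standard LightningCLI/jsonargparse behavior, not specific to
# this file): the resolved configuration can be dumped to YAML and reused, e.g.
#
#   python imagenet.py fit --print_config > config.yaml
#   python imagenet.py fit --config config.yaml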