lightning/pl_examples/domain_templates/semantic_segmentation.py

import os
import random
from argparse import ArgumentParser

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

from pl_examples.models.unet import UNet

DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1)
DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)
class KITTI(Dataset):
"""
Class for KITTI Semantic Segmentation Benchmark dataset
Dataset link - http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015
There are 34 classes in the given labels. However, not all of them are useful for training
(like railings on highways, road dividers, etc.).
The pixel values of these unused classes are stored in `void_labels`, while the useful
classes are stored in `valid_labels`.
The `encode_segmap` function sets all pixels with any of the `void_labels` to `ignore_index`
(250 by default). It also remaps every valid pixel to a value in `[0, len(valid_labels))`,
one index per valid class, so the mask can be compared with the network output by the
loss function.
The `get_filenames` function retrieves the filenames of all images in the given `path` and
stores their full paths in a list.
In `__getitem__`, images and masks are resized to the given `img_size`, masks are encoded
using `encode_segmap`, and the given `transform` (if any) is applied to the image only
(masks do not usually require transforms, but they can be handled in a similar way).
"""
IMAGE_PATH = os.path.join('training', 'image_2')
MASK_PATH = os.path.join('training', 'semantic')
def __init__(
self,
data_path: str,
split: str,
img_size: tuple = (1242, 376),
void_labels: tuple = DEFAULT_VOID_LABELS,
valid_labels: tuple = DEFAULT_VALID_LABELS,
transform=None
):
self.img_size = img_size
self.void_labels = void_labels
self.valid_labels = valid_labels
self.ignore_index = 250
self.class_map = dict(zip(self.valid_labels, range(len(self.valid_labels))))
self.transform = transform
self.split = split
self.data_path = data_path
self.img_path = os.path.join(self.data_path, self.IMAGE_PATH)
self.mask_path = os.path.join(self.data_path, self.MASK_PATH)
self.img_list = self.get_filenames(self.img_path)
self.mask_list = self.get_filenames(self.mask_path)
# Split between train and valid set (80/20)
random_inst = random.Random(12345) # for repeatability
n_items = len(self.img_list)
idxs = random_inst.sample(range(n_items), n_items // 5)
if self.split == 'train':
idxs = [idx for idx in range(n_items) if idx not in idxs]
self.img_list = [self.img_list[i] for i in idxs]
self.mask_list = [self.mask_list[i] for i in idxs]
def __len__(self):
return len(self.img_list)
def __getitem__(self, idx):
img = Image.open(self.img_list[idx])
img = img.resize(self.img_size)
img = np.array(img)
mask = Image.open(self.mask_list[idx]).convert('L')
mask = mask.resize(self.img_size)
mask = np.array(mask)
mask = self.encode_segmap(mask)
if self.transform:
img = self.transform(img)
return img, mask
def encode_segmap(self, mask):
"""
Sets void classes to `ignore_index` so they won't be considered by the loss during training
"""
for voidc in self.void_labels:
mask[mask == voidc] = self.ignore_index
for validc in self.valid_labels:
mask[mask == validc] = self.class_map[validc]
# treat any label values that remain above len(valid_labels) - 1 after the remapping as void
mask[mask > 18] = self.ignore_index
return mask
def get_filenames(self, path):
"""
Returns a list of the full paths of all files inside the given `path`
"""
files_list = list()
for filename in os.listdir(path):
files_list.append(os.path.join(path, filename))
return files_list
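# A minimal usage sketch (the dataset path below is a placeholder, not shipped with this example):
#   dataset = KITTI('/path/to/kitti/data_semantics', split='train', transform=transforms.ToTensor())
#   img, mask = dataset[0]  # transformed image tensor and encoded label map (numpy array)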
class SegModel(pl.LightningModule):
"""
Semantic Segmentation Module
This is a basic semantic segmentation module implemented with Lightning.
It uses CrossEntropyLoss as the default loss function, which may be replaced with
other loss functions as required.
It is specific to the KITTI dataset, i.e. the dataloaders are built for KITTI
and the Normalize transform uses the mean and standard deviation of this dataset.
It uses the UNet model as an example.
The Adam optimizer is used along with a Cosine Annealing learning rate scheduler.
"""
def __init__(self, hparams):
super().__init__()
self.hparams = hparams
self.data_path = hparams.data_path
self.batch_size = hparams.batch_size
self.learning_rate = hparams.lr
self.net = UNet(num_classes=19, num_layers=hparams.num_layers,
features_start=hparams.features_start, bilinear=hparams.bilinear)
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
std=[0.32064945, 0.32098866, 0.32325324])
])
self.trainset = KITTI(self.data_path, split='train', transform=self.transform)
self.validset = KITTI(self.data_path, split='valid', transform=self.transform)
def forward(self, x):
return self.net(x)
def training_step(self, batch, batch_nb):
img, mask = batch
img = img.float()
mask = mask.long()
out = self(img)
loss_val = F.cross_entropy(out, mask, ignore_index=250)
log_dict = {'train_loss': loss_val}
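# in this (pre-1.0) Lightning API, the 'log' and 'progress_bar' entries of the returned
# dict are forwarded to the logger and the progress bar, respectively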
return {'loss': loss_val, 'log': log_dict, 'progress_bar': log_dict}
def validation_step(self, batch, batch_idx):
img, mask = batch
img = img.float()
mask = mask.long()
out = self(img)
loss_val = F.cross_entropy(out, mask, ignore_index=250)
return {'val_loss': loss_val}
def validation_epoch_end(self, outputs):
loss_val = torch.stack([x['val_loss'] for x in outputs]).mean()
log_dict = {'val_loss': loss_val}
return {'log': log_dict, 'val_loss': log_dict['val_loss'], 'progress_bar': log_dict}
def configure_optimizers(self):
opt = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate)
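# anneal the learning rate from `lr` towards 0 (the default eta_min) along a cosine curve
# spanning T_max (= 10) epochs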
sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
return [opt], [sch]
def train_dataloader(self):
return DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True)
def val_dataloader(self):
return DataLoader(self.validset, batch_size=self.batch_size, shuffle=False)
def main(hparams):
# ------------------------
# 1 INIT LIGHTNING MODEL
# ------------------------
model = SegModel(hparams)
# ------------------------
# 2 SET LOGGER
# ------------------------
logger = False
if hparams.log_wandb:
logger = WandbLogger()
# optional: log model topology
logger.watch(model.net)
# ------------------------
# 3 INIT TRAINER
# ------------------------
trainer = pl.Trainer(
gpus=hparams.gpus,
logger=logger,
max_epochs=hparams.epochs,
accumulate_grad_batches=hparams.grad_batches,
distributed_backend=hparams.distributed_backend,
precision=16 if hparams.use_amp else 32,
)
# ------------------------
# 4 START TRAINING
# ------------------------
trainer.fit(model)
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument("--data_path", type=str, help="path where dataset is stored")
parser.add_argument("--gpus", type=int, default=-1, help="number of available GPUs")
parser.add_argument('--distributed-backend', type=str, default='dp', choices=('dp', 'ddp', 'ddp2'),
help='supports three options dp, ddp, ddp2')
parser.add_argument('--use_amp', action='store_true', help='use 16-bit precision')
parser.add_argument("--batch_size", type=int, default=4, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate")
parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net")
parser.add_argument("--features_start", type=float, default=64, help="number of features in first layer")
parser.add_argument("--bilinear", action='store_true', default=False,
help="whether to use bilinear interpolation or transposed")
parser.add_argument("--grad_batches", type=int, default=1, help="number of batches to accumulate")
parser.add_argument("--epochs", type=int, default=20, help="number of epochs to train")
parser.add_argument("--log_wandb", action='store_true', help="log training on Weights & Biases")
hparams = parser.parse_args()
main(hparams)
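# A minimal launch sketch (the dataset path is a placeholder; adjust the flags as needed):
#   python semantic_segmentation.py --data_path /path/to/kitti/data_semantics \
#       --batch_size 4 --lr 0.001 --epochs 20 --gpus 1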