add doctests for example 2/n segmentation (#5083)

* draft

* fix

* drop folder

Co-authored-by: chaton <thomas@grid.ai>
This commit is contained in:
Jirka Borovec 2020-12-21 12:04:39 +01:00 committed by Jirka Borovec
parent 12d6437f65
commit 2438d7459b
2 changed files with 36 additions and 13 deletions

View File

@ -32,6 +32,19 @@ DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1)
DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)
def _create_synth_kitti_dataset(path_dir: str, image_dims: tuple = (1024, 512)):
"""Create synthetic dataset with random images, just to simulate that the dataset have been already downloaded."""
path_dir_images = os.path.join(path_dir, KITTI.IMAGE_PATH)
path_dir_masks = os.path.join(path_dir, KITTI.MASK_PATH)
for p_dir in (path_dir_images, path_dir_masks):
os.makedirs(p_dir, exist_ok=True)
for i in range(3):
path_img = os.path.join(path_dir_images, f'dummy_kitti_{i}.png')
Image.new('RGB', image_dims).save(path_img)
path_mask = os.path.join(path_dir_masks, f'dummy_kitti_{i}.png')
Image.new('L', image_dims).save(path_mask)
class KITTI(Dataset):
"""
Class for KITTI Semantic Segmentation Benchmark dataset
@ -53,6 +66,12 @@ class KITTI(Dataset):
In the `get_item` function, images and masks are resized to the given `img_size`, masks are
encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only
(mask does not usually require transforms, but they can be implemented in a similar way).
>>> from pl_examples import DATASETS_PATH
>>> dataset_path = os.path.join(DATASETS_PATH, "Kitti")
>>> _create_synth_kitti_dataset(dataset_path, image_dims=(1024, 512))
>>> KITTI(dataset_path, 'train') # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
<...semantic_segmentation.KITTI object at ...>
"""
IMAGE_PATH = os.path.join('training', 'image_2')
MASK_PATH = os.path.join('training', 'semantic')
@ -141,6 +160,23 @@ class SegModel(pl.LightningModule):
It uses the FCN ResNet50 model as an example.
Adam optimizer is used along with Cosine Annealing learning rate scheduler.
>>> from pl_examples import DATASETS_PATH
>>> dataset_path = os.path.join(DATASETS_PATH, "Kitti")
>>> _create_synth_kitti_dataset(dataset_path, image_dims=(1024, 512))
>>> SegModel(dataset_path) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
SegModel(
(net): UNet(
(layers): ModuleList(
(0): DoubleConv(...)
(1): Down(...)
(2): Down(...)
(3): Up(...)
(4): Up(...)
(5): Conv2d(64, 19, kernel_size=(1, 1), stride=(1, 1))
)
)
)
"""
def __init__(
self,

View File

@ -1,13 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.