From 2438d7459b108a4eda127cb6915ec170fd35044a Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 21 Dec 2020 12:04:39 +0100 Subject: [PATCH] add doctests for example 2/n segmentation (#5083) * draft * fix * drop folder Co-authored-by: chaton --- .../domain_templates/semantic_segmentation.py | 36 +++++++++++++++++++ pl_examples/pytorch_ecosystem/__init__.py | 13 ------- 2 files changed, 36 insertions(+), 13 deletions(-) delete mode 100644 pl_examples/pytorch_ecosystem/__init__.py diff --git a/pl_examples/domain_templates/semantic_segmentation.py b/pl_examples/domain_templates/semantic_segmentation.py index 7bcad597a9..2e718a37ac 100644 --- a/pl_examples/domain_templates/semantic_segmentation.py +++ b/pl_examples/domain_templates/semantic_segmentation.py @@ -32,6 +32,19 @@ DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1) DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33) +def _create_synth_kitti_dataset(path_dir: str, image_dims: tuple = (1024, 512)): + """Create synthetic dataset with random images, just to simulate that the dataset have been already downloaded.""" + path_dir_images = os.path.join(path_dir, KITTI.IMAGE_PATH) + path_dir_masks = os.path.join(path_dir, KITTI.MASK_PATH) + for p_dir in (path_dir_images, path_dir_masks): + os.makedirs(p_dir, exist_ok=True) + for i in range(3): + path_img = os.path.join(path_dir_images, f'dummy_kitti_{i}.png') + Image.new('RGB', image_dims).save(path_img) + path_mask = os.path.join(path_dir_masks, f'dummy_kitti_{i}.png') + Image.new('L', image_dims).save(path_mask) + + class KITTI(Dataset): """ Class for KITTI Semantic Segmentation Benchmark dataset @@ -53,6 +66,12 @@ class KITTI(Dataset): In the `get_item` function, images and masks are resized to the given `img_size`, masks are encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only (mask does not usually require transforms, but they can be implemented in a similar way). + + >>> from pl_examples import DATASETS_PATH + >>> dataset_path = os.path.join(DATASETS_PATH, "Kitti") + >>> _create_synth_kitti_dataset(dataset_path, image_dims=(1024, 512)) + >>> KITTI(dataset_path, 'train') # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + <...semantic_segmentation.KITTI object at ...> """ IMAGE_PATH = os.path.join('training', 'image_2') MASK_PATH = os.path.join('training', 'semantic') @@ -141,6 +160,23 @@ class SegModel(pl.LightningModule): It uses the FCN ResNet50 model as an example. Adam optimizer is used along with Cosine Annealing learning rate scheduler. + + >>> from pl_examples import DATASETS_PATH + >>> dataset_path = os.path.join(DATASETS_PATH, "Kitti") + >>> _create_synth_kitti_dataset(dataset_path, image_dims=(1024, 512)) + >>> SegModel(dataset_path) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + SegModel( + (net): UNet( + (layers): ModuleList( + (0): DoubleConv(...) + (1): Down(...) + (2): Down(...) + (3): Up(...) + (4): Up(...) + (5): Conv2d(64, 19, kernel_size=(1, 1), stride=(1, 1)) + ) + ) + ) """ def __init__( self, diff --git a/pl_examples/pytorch_ecosystem/__init__.py b/pl_examples/pytorch_ecosystem/__init__.py deleted file mode 100644 index d7aa17d7f8..0000000000 --- a/pl_examples/pytorch_ecosystem/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.