add doctests for example 2/n segmentation (#5083)
* draft * fix * drop folder Co-authored-by: chaton <thomas@grid.ai>
This commit is contained in:
parent
12d6437f65
commit
2438d7459b
|
@ -32,6 +32,19 @@ DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1)
|
|||
DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)
|
||||
|
||||
|
||||
def _create_synth_kitti_dataset(path_dir: str, image_dims: tuple = (1024, 512)):
|
||||
"""Create synthetic dataset with random images, just to simulate that the dataset have been already downloaded."""
|
||||
path_dir_images = os.path.join(path_dir, KITTI.IMAGE_PATH)
|
||||
path_dir_masks = os.path.join(path_dir, KITTI.MASK_PATH)
|
||||
for p_dir in (path_dir_images, path_dir_masks):
|
||||
os.makedirs(p_dir, exist_ok=True)
|
||||
for i in range(3):
|
||||
path_img = os.path.join(path_dir_images, f'dummy_kitti_{i}.png')
|
||||
Image.new('RGB', image_dims).save(path_img)
|
||||
path_mask = os.path.join(path_dir_masks, f'dummy_kitti_{i}.png')
|
||||
Image.new('L', image_dims).save(path_mask)
|
||||
|
||||
|
||||
class KITTI(Dataset):
|
||||
"""
|
||||
Class for KITTI Semantic Segmentation Benchmark dataset
|
||||
|
@ -53,6 +66,12 @@ class KITTI(Dataset):
|
|||
In the `get_item` function, images and masks are resized to the given `img_size`, masks are
|
||||
encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only
|
||||
(mask does not usually require transforms, but they can be implemented in a similar way).
|
||||
|
||||
>>> from pl_examples import DATASETS_PATH
|
||||
>>> dataset_path = os.path.join(DATASETS_PATH, "Kitti")
|
||||
>>> _create_synth_kitti_dataset(dataset_path, image_dims=(1024, 512))
|
||||
>>> KITTI(dataset_path, 'train') # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
|
||||
<...semantic_segmentation.KITTI object at ...>
|
||||
"""
|
||||
IMAGE_PATH = os.path.join('training', 'image_2')
|
||||
MASK_PATH = os.path.join('training', 'semantic')
|
||||
|
@ -141,6 +160,23 @@ class SegModel(pl.LightningModule):
|
|||
It uses the FCN ResNet50 model as an example.
|
||||
|
||||
Adam optimizer is used along with Cosine Annealing learning rate scheduler.
|
||||
|
||||
>>> from pl_examples import DATASETS_PATH
|
||||
>>> dataset_path = os.path.join(DATASETS_PATH, "Kitti")
|
||||
>>> _create_synth_kitti_dataset(dataset_path, image_dims=(1024, 512))
|
||||
>>> SegModel(dataset_path) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
|
||||
SegModel(
|
||||
(net): UNet(
|
||||
(layers): ModuleList(
|
||||
(0): DoubleConv(...)
|
||||
(1): Down(...)
|
||||
(2): Down(...)
|
||||
(3): Up(...)
|
||||
(4): Up(...)
|
||||
(5): Conv2d(64, 19, kernel_size=(1, 1), stride=(1, 1))
|
||||
)
|
||||
)
|
||||
)
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -1,13 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
Loading…
Reference in New Issue