# pytorch_lightning/core/datamodule.py
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import inspect
from abc import abstractmethod
from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional, Tuple, Union

import torch
from torch.utils.data import DataLoader

from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
from pytorch_lightning.utilities import parsing, rank_zero_only


class _DataModuleWrapper(type):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__has_added_checks = False

    def __call__(cls, *args, **kwargs):
        """A wrapper for LightningDataModule that:

            1. Runs user defined subclass's __init__
            2. Assures prepare_data() runs on rank 0
            3. Lets you check prepare_data and setup to see if they've been called
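
        For a hypothetical subclass ``MyDataModule`` (illustrative, not defined here) this means::

            dm = MyDataModule()          # the user's __init__ runs unchanged
            dm.prepare_data()            # wrapped with rank_zero_only, so it runs on rank 0 only
            assert dm.has_prepared_data  # call tracking was added by the metaclass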
"""
if not cls.__has_added_checks:
cls.__has_added_checks = True
# Track prepare_data calls and make sure it runs on rank zero
cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data))
# Track setup calls
cls.setup = track_data_hook_calls(cls.setup)
# Get instance of LightningDataModule by mocking its __init__ via __call__
obj = type.__call__(cls, *args, **kwargs)
return obj


def track_data_hook_calls(fn):
    """A decorator that checks if prepare_data/setup have been called.

    - When dm.prepare_data() is called, dm.has_prepared_data gets set to True
    - When dm.setup('fit') is called, dm.has_setup_fit gets set to True
    - When dm.setup('test') is called, dm.has_setup_test gets set to True
    - When dm.setup() is called without a stage arg, both dm.has_setup_fit and dm.has_setup_test get set to True

    Args:
        fn (function): Function that will be tracked to see if it has been called.

    Returns:
        function: Decorated function that tracks its call status and saves it to private attrs in its obj instance.
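
    A rough illustration (``dm`` is an instance of a hypothetical LightningDataModule subclass)::

        assert not dm.has_setup_fit
        dm.setup("fit")
        assert dm.has_setup_fit and not dm.has_setup_test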
"""
@functools.wraps(fn)
def wrapped_fn(*args, **kwargs):
# The object instance from which setup or prepare_data was called
obj = args[0]
# If calling setup, we check the stage and assign stage-specific bool args
if fn.__name__ == "setup":
# Get stage either by grabbing from args or checking kwargs.
# If not provided, set call status of 'fit' and 'test' to True.
# We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test()
stage = args[1] if len(args) > 1 else kwargs.get("stage", None)
if stage == "fit" or stage is None:
obj._has_setup_fit = True
if stage == "test" or stage is None:
obj._has_setup_test = True
if fn.__name__ == "prepare_data":
obj._has_prepared_data = True
return fn(*args, **kwargs)
return wrapped_fn


class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapper):
    """
    A DataModule standardizes the training, val, test splits, data preparation and transforms.
    The main advantage is consistent data splits, data preparation and transforms across models.

    Example::

        class MyDataModule(LightningDataModule):
            def __init__(self):
                super().__init__()
            def prepare_data(self):
                # download, split, etc...
                # only called on 1 GPU/TPU in distributed
            def setup(self):
                # make assignments here (val/train/test split)
                # called on every process in DDP
            def train_dataloader(self):
                train_split = Dataset(...)
                return DataLoader(train_split)
            def val_dataloader(self):
                val_split = Dataset(...)
                return DataLoader(val_split)
            def test_dataloader(self):
                test_split = Dataset(...)
                return DataLoader(test_split)

    A DataModule implements 5 key methods:

    * **prepare_data** (things to do on 1 GPU/TPU not on every GPU/TPU in distributed mode).
    * **setup** (things to do on every accelerator in distributed mode).
    * **train_dataloader** the training dataloader.
    * **val_dataloader** the val dataloader(s).
    * **test_dataloader** the test dataloader(s).

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.
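
    A minimal sketch of the hand-off to a ``Trainer`` (``MyDataModule`` refers to the
    illustrative subclass above and ``LitModel`` stands in for any ``LightningModule``)::

        dm = MyDataModule()
        model = LitModel()
        trainer = Trainer()
        trainer.fit(model, datamodule=dm)
        trainer.test(datamodule=dm)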
"""
name: str = ...
def __init__(
self,
train_transforms=None,
val_transforms=None,
test_transforms=None,
dims=None,
):
super().__init__()
self._train_transforms = train_transforms
self._val_transforms = val_transforms
self._test_transforms = test_transforms
self._dims = dims if dims is not None else ()
# Pointer to the trainer object
self.trainer = None
# Private attrs to keep track of whether or not data hooks have been called yet
self._has_prepared_data = False
self._has_setup_fit = False
self._has_setup_test = False
    @property
    def train_transforms(self):
        """
        Optional transforms (or collection of transforms) you can apply to the train dataset
        """
        return self._train_transforms

    @train_transforms.setter
    def train_transforms(self, t):
        self._train_transforms = t

    @property
    def val_transforms(self):
        """
        Optional transforms (or collection of transforms) you can apply to the validation dataset
        """
        return self._val_transforms

    @val_transforms.setter
    def val_transforms(self, t):
        self._val_transforms = t

    @property
    def test_transforms(self):
        """
        Optional transforms (or collection of transforms) you can apply to the test dataset
        """
        return self._test_transforms

    @test_transforms.setter
    def test_transforms(self, t):
        self._test_transforms = t

    @property
    def dims(self):
        """
        A tuple describing the shape of your data. Extra functionality exposed in ``size``.
        """
        return self._dims

    @dims.setter
    def dims(self, d):
        self._dims = d
    def size(self, dim=None) -> Union[Tuple, int]:
        """
        Return the full ``dims`` tuple describing your data, or the size along a single
        dimension when ``dim`` is given. This mirrors how you would query a torch tensor's size.
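
        Illustrative, assuming the DataModule was built with ``dims=(1, 28, 28)``::

            dm.size()   # (1, 28, 28)
            dm.size(1)  # 28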
"""
if dim is not None:
return self.dims[dim]
return self.dims
    @property
    def has_prepared_data(self):
        """Return bool letting you know if datamodule.prepare_data() has been called or not.

        Returns:
            bool: True if datamodule.prepare_data() has been called. False by default.
        """
        return self._has_prepared_data

    @property
    def has_setup_fit(self):
        """Return bool letting you know if datamodule.setup('fit') has been called or not.

        Returns:
            bool: True if datamodule.setup('fit') has been called. False by default.
        """
        return self._has_setup_fit

    @property
    def has_setup_test(self):
        """Return bool letting you know if datamodule.setup('test') has been called or not.

        Returns:
            bool: True if datamodule.setup('test') has been called. False by default.
        """
        return self._has_setup_test
    @abstractmethod
    def prepare_data(self, *args, **kwargs):
        pass

    @abstractmethod
    def setup(self, stage: Optional[str] = None):
        pass

    @abstractmethod
    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        pass

    @abstractmethod
    def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        pass

    @abstractmethod
    def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        pass

    @abstractmethod
    def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
        pass
    @classmethod
    def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
        r"""Extends an existing argparse parser with the default ``LightningDataModule`` arguments.
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False)
added_args = [x.dest for x in parser._actions]
blacklist = ["kwargs"]
depr_arg_names = blacklist + added_args
depr_arg_names = set(depr_arg_names)
allowed_types = (str, int, float, bool)
# TODO: get "help" from docstring :)
for arg, arg_types, arg_default in (
at
for at in cls.get_init_arguments_and_types()
if at[0] not in depr_arg_names
):
arg_types = [at for at in allowed_types if at in arg_types]
if not arg_types:
# skip argument with not supported type
continue
arg_kwargs = {}
if bool in arg_types:
arg_kwargs.update(nargs="?", const=True)
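                # Note: nargs="?" together with const=True lets a bool flag be passed bare
                # (e.g. `--my_flag`) and still parse as True, while `--my_flag false` keeps working.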
                # if the only arg type is bool
                if len(arg_types) == 1:
                    use_type = parsing.str_to_bool
                # if only two args (str, bool)
                elif len(arg_types) == 2 and set(arg_types) == {str, bool}:
                    use_type = parsing.str_to_bool_or_str
                else:
                    # filter out the bool as we need to use more general
                    use_type = [at for at in arg_types if at is not bool][0]
            else:
                use_type = arg_types[0]

            if arg_default == inspect._empty:
                arg_default = None

            parser.add_argument(
                f"--{arg}",
                dest=arg,
                default=arg_default,
                type=use_type,
                help=f"autogenerated by plb.{cls.__name__}",
                **arg_kwargs,
            )

        return parser
    @classmethod
    def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
        """
        Create an instance from CLI arguments.

        Args:
            args: The parser or namespace to take arguments from. Only known arguments will be
                parsed and passed to the :class:`LightningDataModule`.
            **kwargs: Additional keyword arguments that may override ones in the parser or namespace.
                These must be valid DataModule arguments.

        Example::

            parser = ArgumentParser(add_help=False)
            parser = LightningDataModule.add_argparse_args(parser)
            args = parser.parse_args()
            module = LightningDataModule.from_argparse_args(args)
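
        Keyword arguments take precedence over parsed values, for example::

            module = LightningDataModule.from_argparse_args(args, dims=(1, 28, 28))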
"""
if isinstance(args, ArgumentParser):
args = cls.parse_argparser(args)
params = vars(args)
# we only want to pass in valid DataModule args, the rest may be user specific
valid_kwargs = inspect.signature(cls.__init__).parameters
datamodule_kwargs = dict(
(name, params[name]) for name in valid_kwargs if name in params
)
datamodule_kwargs.update(**kwargs)
return cls(**datamodule_kwargs)
    @classmethod
    def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
        r"""Scans the DataModule signature and returns argument names, types and default values.

        Returns:
            List with tuples of 3 values:
            (argument name, set with argument types, argument default value).
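
        For instance, a hypothetical subclass whose ``__init__`` declares ``batch_size: int = 32``
        would contribute an entry roughly like::

            ('batch_size', (<class 'int'>,), 32)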
"""
datamodule_default_params = inspect.signature(cls.__init__).parameters
name_type_default = []
for arg in datamodule_default_params:
arg_type = datamodule_default_params[arg].annotation
arg_default = datamodule_default_params[arg].default
try:
arg_types = tuple(arg_type.__args__)
except AttributeError:
arg_types = (arg_type,)
name_type_default.append((arg, arg_types, arg_default))
return name_type_default