diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 660834b489..5471b537c5 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -19,10 +19,9 @@ from argparse import ArgumentParser, Namespace from typing import Any, List, Optional, Tuple, Union import torch -from torch.utils.data import DataLoader - from pytorch_lightning.core.hooks import DataHooks from pytorch_lightning.utilities import parsing, rank_zero_only +from torch.utils.data import DataLoader class _DataModuleWrapper(type): @@ -67,20 +66,20 @@ def track_data_hook_calls(fn): obj = args[0] # If calling setup, we check the stage and assign stage-specific bool args - if fn.__name__ == 'setup': + if fn.__name__ == "setup": # Get stage either by grabbing from args or checking kwargs. # If not provided, set call status of 'fit' and 'test' to True. # We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test() - stage = args[1] if len(args) > 1 else kwargs.get('stage', None) + stage = args[1] if len(args) > 1 else kwargs.get("stage", None) - if stage == 'fit' or stage is None: + if stage == "fit" or stage is None: obj._has_setup_fit = True - if stage == 'test' or stage is None: + if stage == "test" or stage is None: obj._has_setup_test = True - if fn.__name__ == 'prepare_data': + if fn.__name__ == "prepare_data": obj._has_prepared_data = True return fn(*args, **kwargs) @@ -131,7 +130,11 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper): name: str = ... def __init__( - self, train_transforms=None, val_transforms=None, test_transforms=None, dims=None + self, + train_transforms=None, + val_transforms=None, + test_transforms=None, + dims=None, ): super().__init__() self._train_transforms = train_transforms @@ -254,10 +257,10 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper): def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: r"""Extends existing argparse by default `LightningDataModule` attributes. """ - parser = ArgumentParser(parents=[parent_parser], add_help=False,) + parser = ArgumentParser(parents=[parent_parser], add_help=False) added_args = [x.dest for x in parser._actions] - blacklist = ['kwargs'] + blacklist = ["kwargs"] depr_arg_names = blacklist + added_args depr_arg_names = set(depr_arg_names) @@ -265,7 +268,9 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper): # TODO: get "help" from docstring :) for arg, arg_types, arg_default in ( - at for at in cls.get_init_arguments_and_types() if at[0] not in depr_arg_names + at + for at in cls.get_init_arguments_and_types() + if at[0] not in depr_arg_names ): arg_types = [at for at in allowed_types if at in arg_types] if not arg_types: @@ -290,11 +295,11 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper): arg_default = None parser.add_argument( - f'--{arg}', + f"--{arg}", dest=arg, default=arg_default, type=use_type, - help=f'autogenerated by plb.{cls.__name__}', + help=f"autogenerated by plb.{cls.__name__}", **arg_kwargs, ) @@ -324,7 +329,9 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper): # we only want to pass in valid DataModule args, the rest may be user specific valid_kwargs = inspect.signature(cls.__init__).parameters - datamodule_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params) + datamodule_kwargs = dict( + (name, params[name]) for name in valid_kwargs if name in params + ) datamodule_kwargs.update(**kwargs) return cls(**datamodule_kwargs)