Black format pytorch_lightning/core/datamodule.py (#3574)
Split out changes from #3563 to make that PR easier to review
This commit is contained in:
parent
3442b97d1f
commit
c346679f81
|
@ -19,10 +19,9 @@ from argparse import ArgumentParser, Namespace
|
|||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorch_lightning.core.hooks import DataHooks
|
||||
from pytorch_lightning.utilities import parsing, rank_zero_only
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
class _DataModuleWrapper(type):
|
||||
|
@ -67,20 +66,20 @@ def track_data_hook_calls(fn):
|
|||
obj = args[0]
|
||||
|
||||
# If calling setup, we check the stage and assign stage-specific bool args
|
||||
if fn.__name__ == 'setup':
|
||||
if fn.__name__ == "setup":
|
||||
|
||||
# Get stage either by grabbing from args or checking kwargs.
|
||||
# If not provided, set call status of 'fit' and 'test' to True.
|
||||
# We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test()
|
||||
stage = args[1] if len(args) > 1 else kwargs.get('stage', None)
|
||||
stage = args[1] if len(args) > 1 else kwargs.get("stage", None)
|
||||
|
||||
if stage == 'fit' or stage is None:
|
||||
if stage == "fit" or stage is None:
|
||||
obj._has_setup_fit = True
|
||||
|
||||
if stage == 'test' or stage is None:
|
||||
if stage == "test" or stage is None:
|
||||
obj._has_setup_test = True
|
||||
|
||||
if fn.__name__ == 'prepare_data':
|
||||
if fn.__name__ == "prepare_data":
|
||||
obj._has_prepared_data = True
|
||||
|
||||
return fn(*args, **kwargs)
|
||||
|
@ -131,7 +130,11 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
|
|||
name: str = ...
|
||||
|
||||
def __init__(
|
||||
self, train_transforms=None, val_transforms=None, test_transforms=None, dims=None
|
||||
self,
|
||||
train_transforms=None,
|
||||
val_transforms=None,
|
||||
test_transforms=None,
|
||||
dims=None,
|
||||
):
|
||||
super().__init__()
|
||||
self._train_transforms = train_transforms
|
||||
|
@ -254,10 +257,10 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
|
|||
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
|
||||
r"""Extends existing argparse by default `LightningDataModule` attributes.
|
||||
"""
|
||||
parser = ArgumentParser(parents=[parent_parser], add_help=False,)
|
||||
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
||||
added_args = [x.dest for x in parser._actions]
|
||||
|
||||
blacklist = ['kwargs']
|
||||
blacklist = ["kwargs"]
|
||||
depr_arg_names = blacklist + added_args
|
||||
depr_arg_names = set(depr_arg_names)
|
||||
|
||||
|
@ -265,7 +268,9 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
|
|||
|
||||
# TODO: get "help" from docstring :)
|
||||
for arg, arg_types, arg_default in (
|
||||
at for at in cls.get_init_arguments_and_types() if at[0] not in depr_arg_names
|
||||
at
|
||||
for at in cls.get_init_arguments_and_types()
|
||||
if at[0] not in depr_arg_names
|
||||
):
|
||||
arg_types = [at for at in allowed_types if at in arg_types]
|
||||
if not arg_types:
|
||||
|
@ -290,11 +295,11 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
|
|||
arg_default = None
|
||||
|
||||
parser.add_argument(
|
||||
f'--{arg}',
|
||||
f"--{arg}",
|
||||
dest=arg,
|
||||
default=arg_default,
|
||||
type=use_type,
|
||||
help=f'autogenerated by plb.{cls.__name__}',
|
||||
help=f"autogenerated by plb.{cls.__name__}",
|
||||
**arg_kwargs,
|
||||
)
|
||||
|
||||
|
@ -324,7 +329,9 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
|
|||
|
||||
# we only want to pass in valid DataModule args, the rest may be user specific
|
||||
valid_kwargs = inspect.signature(cls.__init__).parameters
|
||||
datamodule_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params)
|
||||
datamodule_kwargs = dict(
|
||||
(name, params[name]) for name in valid_kwargs if name in params
|
||||
)
|
||||
datamodule_kwargs.update(**kwargs)
|
||||
|
||||
return cls(**datamodule_kwargs)
|
||||
|
|
Loading…
Reference in New Issue