Black format pytorch_lightning/core/datamodule.py (#3574)

Split out changes from #3563 to make that PR easier to review
This commit is contained in:
ananthsub 2020-09-20 20:00:00 -07:00 committed by GitHub
parent 3442b97d1f
commit c346679f81
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 14 deletions

View File

@ -19,10 +19,9 @@ from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional, Tuple, Union
import torch
from torch.utils.data import DataLoader
from pytorch_lightning.core.hooks import DataHooks
from pytorch_lightning.utilities import parsing, rank_zero_only
from torch.utils.data import DataLoader
class _DataModuleWrapper(type):
@ -67,20 +66,20 @@ def track_data_hook_calls(fn):
obj = args[0]
# If calling setup, we check the stage and assign stage-specific bool args
if fn.__name__ == 'setup':
if fn.__name__ == "setup":
# Get stage either by grabbing from args or checking kwargs.
# If not provided, set call status of 'fit' and 'test' to True.
# We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test()
stage = args[1] if len(args) > 1 else kwargs.get('stage', None)
stage = args[1] if len(args) > 1 else kwargs.get("stage", None)
if stage == 'fit' or stage is None:
if stage == "fit" or stage is None:
obj._has_setup_fit = True
if stage == 'test' or stage is None:
if stage == "test" or stage is None:
obj._has_setup_test = True
if fn.__name__ == 'prepare_data':
if fn.__name__ == "prepare_data":
obj._has_prepared_data = True
return fn(*args, **kwargs)
@ -131,7 +130,11 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
name: str = ...
def __init__(
self, train_transforms=None, val_transforms=None, test_transforms=None, dims=None
self,
train_transforms=None,
val_transforms=None,
test_transforms=None,
dims=None,
):
super().__init__()
self._train_transforms = train_transforms
@ -254,10 +257,10 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
r"""Extends existing argparse by default `LightningDataModule` attributes.
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False,)
parser = ArgumentParser(parents=[parent_parser], add_help=False)
added_args = [x.dest for x in parser._actions]
blacklist = ['kwargs']
blacklist = ["kwargs"]
depr_arg_names = blacklist + added_args
depr_arg_names = set(depr_arg_names)
@ -265,7 +268,9 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
# TODO: get "help" from docstring :)
for arg, arg_types, arg_default in (
at for at in cls.get_init_arguments_and_types() if at[0] not in depr_arg_names
at
for at in cls.get_init_arguments_and_types()
if at[0] not in depr_arg_names
):
arg_types = [at for at in allowed_types if at in arg_types]
if not arg_types:
@ -290,11 +295,11 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
arg_default = None
parser.add_argument(
f'--{arg}',
f"--{arg}",
dest=arg,
default=arg_default,
type=use_type,
help=f'autogenerated by plb.{cls.__name__}',
help=f"autogenerated by plb.{cls.__name__}",
**arg_kwargs,
)
@ -324,7 +329,9 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
# we only want to pass in valid DataModule args, the rest may be user specific
valid_kwargs = inspect.signature(cls.__init__).parameters
datamodule_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params)
datamodule_kwargs = dict(
(name, params[name]) for name in valid_kwargs if name in params
)
datamodule_kwargs.update(**kwargs)
return cls(**datamodule_kwargs)