diff --git a/.yapfignore b/.yapfignore index 17ed5ee527..b9f7d5cd47 100644 --- a/.yapfignore +++ b/.yapfignore @@ -1,7 +1,44 @@ .git/* + # TODO -pytorch_lightning/* +pytorch_lightning/accelerators/legacy/* + + +# TODO +pytorch_lightning/callbacks/* + + +# TODO +pytorch_lightning/cluster_environments/* + + +# TODO +pytorch_lightning/core/* + + +# TODO +pytorch_lightning/loggers/* + + +# TODO +pytorch_lightning/metrics/* + +# TODO +pytorch_lightning/plugins/legacy/* + + +# TODO +pytorch_lightning/profiler/* + + +# TODO +pytorch_lightning/trainer/* + + +# TODO +pytorch_lightning/tuner/* + # TODO tests/* diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 8ba67aba8b..8bc70f03d3 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -74,9 +74,10 @@ class LightningDataParallel(DataParallel): for t in chain(self.module.parameters(), self.module.buffers()): if t.device != self.src_device_obj: - raise RuntimeError("module must have its parameters and buffers " - "on device {} (device_ids[0]) but found one of " - "them on device: {}".format(self.src_device_obj, t.device)) + raise RuntimeError( + f"module must have its parameters and buffers on device {self.src_device_obj} (device_ids[0])" + f" but found one of them on device: {t.device}" + ) inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) @@ -127,6 +128,7 @@ class LightningDataParallel(DataParallel): r""" Override the gather method to support python scalars as well. """ + def gather_map(outputs): elem = outputs[0] elem_type = type(elem) @@ -140,8 +142,7 @@ class LightningDataParallel(DataParallel): if isinstance(elem, Mapping): if not all((len(elem) == len(d) for d in outputs)): raise ValueError('All dicts must have the same number of keys') - return elem_type(((k, gather_map([d[k] for d in outputs])) - for k in elem)) + return elem_type(((k, gather_map([d[k] for d in outputs])) for k in elem)) if isinstance(elem, Iterable) and not isinstance(elem, str): return elem_type(map(gather_map, zip(*outputs))) @@ -247,10 +248,10 @@ def warn_missing_output(fx_called): def parallel_apply( - modules: Module, - inputs: Tensor, - kwargs_tup: Optional[tuple] = None, - devices: Optional[list] = None, + modules: Module, + inputs: Tensor, + kwargs_tup: Optional[tuple] = None, + devices: Optional[list] = None, ): # pragma: no-cover r"""Applies each `module` in :attr:`modules` in parallel on arguments contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword) @@ -270,7 +271,7 @@ def parallel_apply( if kwargs_tup is not None: assert len(modules) == len(kwargs_tup) else: - kwargs_tup = ({},) * len(modules) + kwargs_tup = ({}, ) * len(modules) if devices is not None: assert len(modules) == len(devices) else: @@ -288,7 +289,7 @@ def parallel_apply( with torch.cuda.device(device): # this also avoids accidental slicing of `input` if it is a Tensor if not isinstance(input, (list, tuple)): - input = (input,) + input = (input, ) module = module.to(device) @@ -333,10 +334,10 @@ def parallel_apply( m.testing = root_m.testing if len(modules) > 1: - threads = [threading.Thread(target=_worker, - args=(i, module, input, kwargs, device)) - for i, (module, input, kwargs, device) in - enumerate(zip(modules, inputs, kwargs_tup, devices))] + threads = [ + threading.Thread(target=_worker, args=(i, module, input, kwargs, device)) + for i, (module, input, kwargs, device) in enumerate(zip(modules, inputs, kwargs_tup, devices)) + ] for thread in threads: thread.start() diff --git a/pytorch_lightning/setup_tools.py b/pytorch_lightning/setup_tools.py index f22466ae96..2c90688b92 100644 --- a/pytorch_lightning/setup_tools.py +++ b/pytorch_lightning/setup_tools.py @@ -31,7 +31,7 @@ _DEFAULT_BADGES = [ ] -def _load_requirements(path_dir: str , file_name: str = 'requirements.txt', comment_char: str = '#') -> List[str]: +def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comment_char: str = '#') -> List[str]: """Load requirements from a file >>> _load_requirements(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE @@ -72,8 +72,7 @@ def _load_readme_description(path_dir: str, homepage: str = __homepage__, versio # readthedocs badge text = text.replace('badge/?version=stable', f'badge/?version={version}') - text = text.replace('pytorch-lightning.readthedocs.io/en/stable/', - f'pytorch-lightning.readthedocs.io/en/{version}') + text = text.replace('pytorch-lightning.readthedocs.io/en/stable/', f'pytorch-lightning.readthedocs.io/en/{version}') # codecov badge text = text.replace('/branch/master/graph/badge.svg', f'/release/{version}/graph/badge.svg') # replace github badges for release ones diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 0fc23d4f80..2f7425bf3b 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -30,19 +30,15 @@ else: Batch = type(None) -def to_dtype_tensor(value, dtype:torch.dtype = None, device: torch.device = None): +def to_dtype_tensor(value, dtype: torch.dtype = None, device: torch.device = None): if device is None: - raise MisconfigurationException( - "device (torch.device) should be provided." - ) + raise MisconfigurationException("device (torch.device) should be provided.") return torch.tensor(value, dtype=dtype, device=device) def from_numpy(value, device: torch.device = None): if device is None: - raise MisconfigurationException( - "device (torch.device) should be provided." - ) + raise MisconfigurationException("device (torch.device) should be provided.") return torch.from_numpy(value).to(device) @@ -55,8 +51,14 @@ CONVERSION_DTYPES = [ ] -def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, - wrong_dtype: Optional[Union[type, tuple]] = None, **kwargs) -> Any: +def apply_to_collection( + data: Any, + dtype: Union[type, tuple], + function: Callable, + *args, + wrong_dtype: Optional[Union[type, tuple]] = None, + **kwargs +) -> Any: """ Recursively applies a function to all elements of a certain dtype. @@ -80,8 +82,7 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable # Recursively apply to collection items if isinstance(data, Mapping): - return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) - for k, v in data.items()}) + return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()}) if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data)) @@ -159,9 +160,7 @@ def move_data_to_device(batch: Any, device: torch.device): def convert_to_tensors(data, device: torch.device = None): if device is None: - raise MisconfigurationException( - "device (torch.device) should be provided." - ) + raise MisconfigurationException("device (torch.device) should be provided.") for src_dtype, conversion_func in CONVERSION_DTYPES: data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device)) return data diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index 36a0739f1d..62626d1b5b 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -57,9 +57,7 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp """Parse CLI arguments, required for custom bool types.""" args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser - types_default = { - arg: (arg_types, arg_default) for arg, arg_types, arg_default in get_init_arguments_and_types(cls) - } + types_default = {arg: (arg_types, arg_default) for arg, arg_types, arg_default in get_init_arguments_and_types(cls)} modified_args = {} for k, v in vars(args).items(): @@ -130,7 +128,7 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: try: arg_types = tuple(arg_type.__args__) except AttributeError: - arg_types = (arg_type,) + arg_types = (arg_type, ) name_type_default.append((arg, arg_types, arg_default)) @@ -156,7 +154,10 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: >>> parser = Trainer.add_argparse_args(parser) >>> args = parser.parse_args([]) """ - parser = ArgumentParser(parents=[parent_parser], add_help=False,) + parser = ArgumentParser( + parents=[parent_parser], + add_help=False, + ) blacklist = ['kwargs'] depr_arg_names = cls.get_deprecated_arg_names() + blacklist @@ -164,9 +165,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: allowed_types = (str, int, float, bool) args_help = parse_args_from_docstring(cls.__init__.__doc__ or cls.__doc__) - for arg, arg_types, arg_default in ( - at for at in get_init_arguments_and_types(cls) if at[0] not in depr_arg_names - ): + for arg, arg_types, arg_default in (at for at in get_init_arguments_and_types(cls) if at[0] not in depr_arg_names): arg_types = [at for at in allowed_types if at in arg_types] if not arg_types: # skip argument with not supported type diff --git a/pytorch_lightning/utilities/argparse_utils.py b/pytorch_lightning/utilities/argparse_utils.py index e4a6fc5cd8..eb53579f94 100644 --- a/pytorch_lightning/utilities/argparse_utils.py +++ b/pytorch_lightning/utilities/argparse_utils.py @@ -1,6 +1,7 @@ from warnings import warn -warn("`argparse_utils` package has been renamed to `argparse` since v1.2 and will be removed in v1.4", - DeprecationWarning) +warn( + "`argparse_utils` package has been renamed to `argparse` since v1.2 and will be removed in v1.4", DeprecationWarning +) from pytorch_lightning.utilities.argparse import * # noqa: F403 E402 F401 diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index aee260e627..20f63e355e 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -32,8 +32,9 @@ def has_len(dataloader: DataLoader) -> bool: try: # try getting the length if len(dataloader) == 0: - raise ValueError('`Dataloader` returned 0 length.' - ' Please make sure that your Dataloader at least returns 1 batch') + raise ValueError( + '`Dataloader` returned 0 length. Please make sure that your Dataloader at least returns 1 batch' + ) has_len = True except TypeError: has_len = False diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py index c80f68213d..fc45b375c7 100644 --- a/pytorch_lightning/utilities/debugging.py +++ b/pytorch_lightning/utilities/debugging.py @@ -54,12 +54,12 @@ class InternalDebugger(object): self.dataloader_sequence_calls = [] def track_event( - self, - evt_type: str, - evt_value: Any = None, - global_rank: Optional[int] = None, - local_rank: Optional[int] = None, - comment: str = '' + self, + evt_type: str, + evt_value: Any = None, + global_rank: Optional[int] = None, + local_rank: Optional[int] = None, + comment: str = '' ) -> None: self.events.append({ "timestamp": time.time(), diff --git a/pytorch_lightning/utilities/device_parser.py b/pytorch_lightning/utilities/device_parser.py index b1bd62277a..fbed98ae2b 100644 --- a/pytorch_lightning/utilities/device_parser.py +++ b/pytorch_lightning/utilities/device_parser.py @@ -139,10 +139,9 @@ def _sanitize_gpu_ids(gpus: List[int]) -> List[int]: all_available_gpus = _get_all_available_gpus() for gpu in gpus: if gpu not in all_available_gpus: - raise MisconfigurationException(f""" - You requested GPUs: {gpus} - But your machine only has: {all_available_gpus} - """) + raise MisconfigurationException( + f"You requested GPUs: {gpus}\n But your machine only has: {all_available_gpus}" + ) return gpus diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 23811e98bd..f283497e5e 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -24,6 +24,7 @@ from pytorch_lightning import _logger as log if torch.distributed.is_available(): from torch.distributed import group, ReduceOp else: + class ReduceOp: SUM = None @@ -108,7 +109,9 @@ def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None) def sync_ddp_if_available( - result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None + result: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None ) -> torch.Tensor: """ Function to reduce a tensor across worker processes during distributed training @@ -127,7 +130,9 @@ def sync_ddp_if_available( def sync_ddp( - result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None + result: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None ) -> torch.Tensor: """ Function to reduce the tensors from several ddp processes to one master process @@ -162,13 +167,12 @@ def sync_ddp( class AllGatherGrad(torch.autograd.Function): + @staticmethod def forward(ctx, tensor, group=group.WORLD): ctx.group = group - gathered_tensor = [ - torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size()) - ] + gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] torch.distributed.all_gather(gathered_tensor, tensor, group=group) gathered_tensor = torch.stack(gathered_tensor, dim=0) @@ -179,12 +183,7 @@ class AllGatherGrad(torch.autograd.Function): def backward(ctx, *grad_output): grad_output = torch.cat(grad_output) - torch.distributed.all_reduce( - grad_output, - op=torch.distributed.ReduceOp.SUM, - async_op=False, - group=ctx.group - ) + torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group) return grad_output[torch.distributed.get_rank()] diff --git a/pytorch_lightning/utilities/exceptions.py b/pytorch_lightning/utilities/exceptions.py index c45425bcbd..01b1e8c053 100644 --- a/pytorch_lightning/utilities/exceptions.py +++ b/pytorch_lightning/utilities/exceptions.py @@ -12,5 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. + class MisconfigurationException(Exception): pass diff --git a/pytorch_lightning/utilities/model_utils.py b/pytorch_lightning/utilities/model_utils.py index a547261449..7fd5b287f7 100644 --- a/pytorch_lightning/utilities/model_utils.py +++ b/pytorch_lightning/utilities/model_utils.py @@ -1,6 +1,8 @@ from warnings import warn -warn("`model_utils` package has been renamed to `model_helpers` since v1.2 and will be removed in v1.4", - DeprecationWarning) +warn( + "`model_utils` package has been renamed to `model_helpers` since v1.2 and will be removed in v1.4", + DeprecationWarning +) from pytorch_lightning.utilities.model_helpers import * # noqa: F403 E402 F401 diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index e631f8715b..57568967ae 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -235,8 +235,10 @@ def lightning_getattr(model, attribute): elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute): attr = getattr(trainer.datamodule, attribute) else: - raise ValueError(f'{attribute} is neither stored in the model namespace' - ' nor the `hparams` namespace/dict, nor the datamodule.') + raise ValueError( + f'The {attribute} is neither stored in the model namespace nor the `hparams` namespace/dict,' + ' nor the datamodule.' + ) return attr @@ -246,8 +248,10 @@ def lightning_setattr(model, attribute, value): Will also set the attribute on datamodule, if it exists. """ if not lightning_hasattr(model, attribute): - raise ValueError(f'{attribute} is neither stored in the model namespace' - ' nor the `hparams` namespace/dict, nor the datamodule.') + raise ValueError( + f'The {attribute} is neither stored in the model namespace nor the `hparams` namespace/dict,' + ' nor the datamodule.' + ) trainer = getattr(model, 'trainer', None) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 353112c186..da98e00b71 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Helper functions to help with reproducibility of models. """ import os diff --git a/pytorch_lightning/utilities/warning_utils.py b/pytorch_lightning/utilities/warning_utils.py index 3ae0ada6f3..c520086f62 100644 --- a/pytorch_lightning/utilities/warning_utils.py +++ b/pytorch_lightning/utilities/warning_utils.py @@ -1,6 +1,7 @@ from warnings import warn -warn("`warning_utils` package has been renamed to `warnings` since v1.2 and will be removed in v1.4", - DeprecationWarning) +warn( + "`warning_utils` package has been renamed to `warnings` since v1.2 and will be removed in v1.4", DeprecationWarning +) from pytorch_lightning.utilities.warnings import * # noqa: F403 E402 F401 diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py index 717cebd68a..210047c466 100644 --- a/pytorch_lightning/utilities/xla_device.py +++ b/pytorch_lightning/utilities/xla_device.py @@ -34,6 +34,7 @@ def inner_f(queue, func, *args, **kwargs): # pragma: no cover def pl_multi_process(func): + @functools.wraps(func) def wrapper(*args, **kwargs): queue = Queue() diff --git a/pytorch_lightning/utilities/xla_device_utils.py b/pytorch_lightning/utilities/xla_device_utils.py index 1401140691..20b7a04e7b 100644 --- a/pytorch_lightning/utilities/xla_device_utils.py +++ b/pytorch_lightning/utilities/xla_device_utils.py @@ -1,6 +1,8 @@ from warnings import warn -warn("`xla_device_utils` package has been renamed to `xla_device` since v1.2 and will be removed in v1.4", - DeprecationWarning) +warn( + "`xla_device_utils` package has been renamed to `xla_device` since v1.2 and will be removed in v1.4", + DeprecationWarning +) from pytorch_lightning.utilities.xla_device import * # noqa: F403 E402 F401