formatting to PL utils (#5713)

* yapf pl base

* over

* dist

* utils

* Apply suggestions from code review

* flake8

* new way
Jirka Borovec 2021-01-30 15:28:59 +01:00 committed by GitHub
parent 21d313edc5
commit eee38d59e7
17 changed files with 122 additions and 77 deletions

View File

@@ -1,7 +1,44 @@
.git/*
# TODO
pytorch_lightning/*
pytorch_lightning/accelerators/legacy/*
# TODO
pytorch_lightning/callbacks/*
# TODO
pytorch_lightning/cluster_environments/*
# TODO
pytorch_lightning/core/*
# TODO
pytorch_lightning/loggers/*
# TODO
pytorch_lightning/metrics/*
# TODO
pytorch_lightning/plugins/legacy/*
# TODO
pytorch_lightning/profiler/*
# TODO
pytorch_lightning/trainer/*
# TODO
pytorch_lightning/tuner/*
# TODO
tests/*
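
The block above expands what appears to be the project's yapf ignore file: every package that has not yet been reformatted is listed explicitly and flagged with a `# TODO`. For readers unfamiliar with such files, they are just one glob pattern per line, with blank lines and `#` comments skipped. A minimal sketch of that matching, assuming a `.yapfignore`-style file name (the file name and the helpers below are illustrative, not part of the commit):

```python
# Minimal sketch: match repository paths against glob-style exclude patterns.
# The ".yapfignore" file name and these helpers are assumptions for illustration.
from fnmatch import fnmatch
from pathlib import Path


def load_exclude_patterns(path: str = ".yapfignore") -> list:
    """Read one glob pattern per line, skipping blank lines and '#' comments."""
    patterns = []
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        if line and not line.startswith("#"):
            patterns.append(line)
    return patterns


def is_excluded(file_path: str, patterns: list) -> bool:
    """Return True if `file_path` matches any of the exclude patterns."""
    return any(fnmatch(file_path, pattern) for pattern in patterns)


# With the patterns above, trainer code is still skipped while the
# utilities formatted in this commit are not.
patterns = ["pytorch_lightning/trainer/*", "tests/*"]
print(is_excluded("pytorch_lightning/trainer/trainer.py", patterns))        # True
print(is_excluded("pytorch_lightning/utilities/distributed.py", patterns))  # False
```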

View File

@@ -74,9 +74,10 @@ class LightningDataParallel(DataParallel):
for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError("module must have its parameters and buffers "
"on device {} (device_ids[0]) but found one of "
"them on device: {}".format(self.src_device_obj, t.device))
raise RuntimeError(
f"module must have its parameters and buffers on device {self.src_device_obj} (device_ids[0])"
f" but found one of them on device: {t.device}"
)
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
@@ -127,6 +128,7 @@ class LightningDataParallel(DataParallel):
r"""
Override the gather method to support python scalars as well.
"""
def gather_map(outputs):
elem = outputs[0]
elem_type = type(elem)
@@ -140,8 +142,7 @@ class LightningDataParallel(DataParallel):
if isinstance(elem, Mapping):
if not all((len(elem) == len(d) for d in outputs)):
raise ValueError('All dicts must have the same number of keys')
return elem_type(((k, gather_map([d[k] for d in outputs]))
for k in elem))
return elem_type(((k, gather_map([d[k] for d in outputs])) for k in elem))
if isinstance(elem, Iterable) and not isinstance(elem, str):
return elem_type(map(gather_map, zip(*outputs)))
@@ -247,10 +248,10 @@ def warn_missing_output(fx_called):
def parallel_apply(
modules: Module,
inputs: Tensor,
kwargs_tup: Optional[tuple] = None,
devices: Optional[list] = None,
modules: Module,
inputs: Tensor,
kwargs_tup: Optional[tuple] = None,
devices: Optional[list] = None,
): # pragma: no-cover
r"""Applies each `module` in :attr:`modules` in parallel on arguments
contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
@@ -270,7 +271,7 @@ def parallel_apply(
if kwargs_tup is not None:
assert len(modules) == len(kwargs_tup)
else:
kwargs_tup = ({},) * len(modules)
kwargs_tup = ({}, ) * len(modules)
if devices is not None:
assert len(modules) == len(devices)
else:
@@ -288,7 +289,7 @@ def parallel_apply(
with torch.cuda.device(device):
# this also avoids accidental slicing of `input` if it is a Tensor
if not isinstance(input, (list, tuple)):
input = (input,)
input = (input, )
module = module.to(device)
@@ -333,10 +334,10 @@ def parallel_apply(
m.testing = root_m.testing
if len(modules) > 1:
threads = [threading.Thread(target=_worker,
args=(i, module, input, kwargs, device))
for i, (module, input, kwargs, device) in
enumerate(zip(modules, inputs, kwargs_tup, devices))]
threads = [
threading.Thread(target=_worker, args=(i, module, input, kwargs, device))
for i, (module, input, kwargs, device) in enumerate(zip(modules, inputs, kwargs_tup, devices))
]
for thread in threads:
thread.start()
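
For readers skimming the hunk: the reformatted list comprehension spawns one worker thread per (module, input, kwargs, device) tuple and then starts them all. A self-contained sketch of that fan-out/join pattern, with toy callables standing in for the replicas (the `_worker` body here is deliberately simplified and is not the one from the diff):

```python
# Illustrative sketch of the thread-per-replica pattern reformatted above.
# `_worker` is simplified: it just calls the module and records the result.
import threading


def _worker(i, module, inp, results, lock):
    output = module(inp)
    with lock:
        results[i] = output


modules = [lambda x: x + 1, lambda x: x * 2]   # stand-ins for model replicas
inputs = [10, 10]
results = {}
lock = threading.Lock()

threads = [
    threading.Thread(target=_worker, args=(i, module, inp, results, lock))
    for i, (module, inp) in enumerate(zip(modules, inputs))
]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()

print([results[i] for i in range(len(modules))])  # [11, 20]
```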

View File

@@ -31,7 +31,7 @@ _DEFAULT_BADGES = [
]
def _load_requirements(path_dir: str , file_name: str = 'requirements.txt', comment_char: str = '#') -> List[str]:
def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comment_char: str = '#') -> List[str]:
"""Load requirements from a file
>>> _load_requirements(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
@@ -72,8 +72,7 @@ def _load_readme_description(path_dir: str, homepage: str = __homepage__, versio
# readthedocs badge
text = text.replace('badge/?version=stable', f'badge/?version={version}')
text = text.replace('pytorch-lightning.readthedocs.io/en/stable/',
f'pytorch-lightning.readthedocs.io/en/{version}')
text = text.replace('pytorch-lightning.readthedocs.io/en/stable/', f'pytorch-lightning.readthedocs.io/en/{version}')
# codecov badge
text = text.replace('/branch/master/graph/badge.svg', f'/release/{version}/graph/badge.svg')
# replace github badges for release ones
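
The two helpers touched in this file read the requirements file and rewrite README badge and documentation URLs to a concrete release version. A rough, self-contained sketch of what such helpers do (the parsing rules in `load_requirements` are assumptions; only the URL replacements mirror the diff):

```python
# Rough sketch of requirement parsing and version-pinned URL rewriting.
# The comment-stripping rule is an assumption; the signatures mirror the diff.
import os
from typing import List


def load_requirements(path_dir: str, file_name: str = 'requirements.txt', comment_char: str = '#') -> List[str]:
    """Return requirement specifiers, dropping blank lines and comments."""
    reqs = []
    with open(os.path.join(path_dir, file_name)) as fp:
        for line in fp:
            # drop anything after the comment character, then surrounding whitespace
            line = line.split(comment_char)[0].strip()
            if line:
                reqs.append(line)
    return reqs


def pin_readme_to_version(text: str, version: str) -> str:
    """Point 'stable' readthedocs badges and links at a specific release."""
    text = text.replace('badge/?version=stable', f'badge/?version={version}')
    text = text.replace('pytorch-lightning.readthedocs.io/en/stable/', f'pytorch-lightning.readthedocs.io/en/{version}')
    return text


print(pin_readme_to_version('https://pytorch-lightning.readthedocs.io/en/stable/', '1.2.0'))
# https://pytorch-lightning.readthedocs.io/en/1.2.0
```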

View File

@@ -30,19 +30,15 @@ else:
Batch = type(None)
def to_dtype_tensor(value, dtype:torch.dtype = None, device: torch.device = None):
def to_dtype_tensor(value, dtype: torch.dtype = None, device: torch.device = None):
if device is None:
raise MisconfigurationException(
"device (torch.device) should be provided."
)
raise MisconfigurationException("device (torch.device) should be provided.")
return torch.tensor(value, dtype=dtype, device=device)
def from_numpy(value, device: torch.device = None):
if device is None:
raise MisconfigurationException(
"device (torch.device) should be provided."
)
raise MisconfigurationException("device (torch.device) should be provided.")
return torch.from_numpy(value).to(device)
@@ -55,8 +51,14 @@ CONVERSION_DTYPES = [
]
def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args,
wrong_dtype: Optional[Union[type, tuple]] = None, **kwargs) -> Any:
def apply_to_collection(
data: Any,
dtype: Union[type, tuple],
function: Callable,
*args,
wrong_dtype: Optional[Union[type, tuple]] = None,
**kwargs
) -> Any:
"""
Recursively applies a function to all elements of a certain dtype.
@@ -80,8 +82,7 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable
# Recursively apply to collection items
if isinstance(data, Mapping):
return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs)
for k, v in data.items()})
return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()})
if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple
return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data))
@@ -159,9 +160,7 @@ def move_data_to_device(batch: Any, device: torch.device):
def convert_to_tensors(data, device: torch.device = None):
if device is None:
raise MisconfigurationException(
"device (torch.device) should be provided."
)
raise MisconfigurationException("device (torch.device) should be provided.")
for src_dtype, conversion_func in CONVERSION_DTYPES:
data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device))
return data
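
For context, `apply_to_collection` walks nested containers and applies a function to every element of the requested dtype; the hunks above only reflow its signature and the `Mapping` branch. A simplified, standalone sketch of that recursion (it mirrors the branches visible in the diff but omits `wrong_dtype` and the tensor-conversion helpers, so it is not the exact library code):

```python
# Simplified standalone version of the recursion reformatted above:
# Mapping, namedtuple and generic sequence handling only.
from collections.abc import Mapping, Sequence
from typing import Any, Callable, Union


def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any:
    """Recursively apply `function` to every element of type `dtype`."""
    elem_type = type(data)

    if isinstance(data, dtype):
        return function(data, *args, **kwargs)

    if isinstance(data, Mapping):
        return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()})

    if isinstance(data, tuple) and hasattr(data, '_fields'):  # named tuple
        return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data))

    if isinstance(data, Sequence) and not isinstance(data, str):
        return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data])

    return data  # data is neither of dtype, nor a supported collection


print(apply_to_collection({'a': [1, 2], 'b': 3.0}, int, lambda x: x * 10))
# {'a': [10, 20], 'b': 3.0}
```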

View File

@@ -57,9 +57,7 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp
"""Parse CLI arguments, required for custom bool types."""
args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser
types_default = {
arg: (arg_types, arg_default) for arg, arg_types, arg_default in get_init_arguments_and_types(cls)
}
types_default = {arg: (arg_types, arg_default) for arg, arg_types, arg_default in get_init_arguments_and_types(cls)}
modified_args = {}
for k, v in vars(args).items():
@@ -130,7 +128,7 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
try:
arg_types = tuple(arg_type.__args__)
except AttributeError:
arg_types = (arg_type,)
arg_types = (arg_type, )
name_type_default.append((arg, arg_types, arg_default))
@@ -156,7 +154,10 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
>>> parser = Trainer.add_argparse_args(parser)
>>> args = parser.parse_args([])
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False,)
parser = ArgumentParser(
parents=[parent_parser],
add_help=False,
)
blacklist = ['kwargs']
depr_arg_names = cls.get_deprecated_arg_names() + blacklist
@@ -164,9 +165,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
allowed_types = (str, int, float, bool)
args_help = parse_args_from_docstring(cls.__init__.__doc__ or cls.__doc__)
for arg, arg_types, arg_default in (
at for at in get_init_arguments_and_types(cls) if at[0] not in depr_arg_names
):
for arg, arg_types, arg_default in (at for at in get_init_arguments_and_types(cls) if at[0] not in depr_arg_names):
arg_types = [at for at in allowed_types if at in arg_types]
if not arg_types:
# skip argument with not supported type
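
The docstring in the hunk already shows the intended call sequence for `add_argparse_args`; the snippet below just spells it out end to end. It assumes a working `pytorch_lightning` installation and uses `Trainer.from_argparse_args`, which is not part of this diff:

```python
# Usage sketch for the argparse helpers touched above (assumes
# pytorch_lightning is installed so Trainer provides these classmethods).
from argparse import ArgumentParser

from pytorch_lightning import Trainer

parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)  # adds --max_epochs, --gpus, ...
args = parser.parse_args([])                # empty argv: keep all defaults
trainer = Trainer.from_argparse_args(args)  # build a Trainer from the namespace
```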

View File

@@ -1,6 +1,7 @@
from warnings import warn
warn("`argparse_utils` package has been renamed to `argparse` since v1.2 and will be removed in v1.4",
DeprecationWarning)
warn(
"`argparse_utils` package has been renamed to `argparse` since v1.2 and will be removed in v1.4", DeprecationWarning
)
from pytorch_lightning.utilities.argparse import * # noqa: F403 E402 F401

View File

@@ -32,8 +32,9 @@ def has_len(dataloader: DataLoader) -> bool:
try:
# try getting the length
if len(dataloader) == 0:
raise ValueError('`Dataloader` returned 0 length.'
' Please make sure that your Dataloader at least returns 1 batch')
raise ValueError(
'`Dataloader` returned 0 length. Please make sure that your Dataloader at least returns 1 batch'
)
has_len = True
except TypeError:
has_len = False
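
The `has_len` check above relies on `len()` raising `TypeError` for iterable-style datasets that have no length. A small standalone sketch of that pattern (simplified; it returns directly instead of setting a flag as the library function does):

```python
# Standalone sketch of the has_len() check reformatted above: an object
# either reports a usable length or raises TypeError (iterable-style data).
def has_len(dataloader) -> bool:
    try:
        if len(dataloader) == 0:
            raise ValueError(
                '`Dataloader` returned 0 length. Please make sure that your Dataloader at least returns 1 batch'
            )
        return True
    except TypeError:
        return False


print(has_len([1, 2, 3]))        # True
print(has_len(iter([1, 2, 3])))  # False: a plain iterator has no __len__
```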

View File

@@ -54,12 +54,12 @@ class InternalDebugger(object):
self.dataloader_sequence_calls = []
def track_event(
self,
evt_type: str,
evt_value: Any = None,
global_rank: Optional[int] = None,
local_rank: Optional[int] = None,
comment: str = ''
self,
evt_type: str,
evt_value: Any = None,
global_rank: Optional[int] = None,
local_rank: Optional[int] = None,
comment: str = ''
) -> None:
self.events.append({
"timestamp": time.time(),

View File

@@ -139,10 +139,9 @@ def _sanitize_gpu_ids(gpus: List[int]) -> List[int]:
all_available_gpus = _get_all_available_gpus()
for gpu in gpus:
if gpu not in all_available_gpus:
raise MisconfigurationException(f"""
You requested GPUs: {gpus}
But your machine only has: {all_available_gpus}
""")
raise MisconfigurationException(
f"You requested GPUs: {gpus}\n But your machine only has: {all_available_gpus}"
)
return gpus
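
A minimal sketch of the GPU-id validation reformatted above, with `torch.cuda.device_count()` standing in for the private `_get_all_available_gpus()` helper and `RuntimeError` replacing `MisconfigurationException` to keep the snippet dependency-free:

```python
# Minimal sketch of the GPU-id validation above; the helper name and the
# use of RuntimeError instead of MisconfigurationException are substitutions.
from typing import List

import torch


def sanitize_gpu_ids(gpus: List[int]) -> List[int]:
    all_available_gpus = list(range(torch.cuda.device_count()))
    for gpu in gpus:
        if gpu not in all_available_gpus:
            raise RuntimeError(
                f"You requested GPUs: {gpus}\n But your machine only has: {all_available_gpus}"
            )
    return gpus


print(sanitize_gpu_ids([]))  # [] -- an empty request is always valid
```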

View File

@@ -24,6 +24,7 @@ from pytorch_lightning import _logger as log
if torch.distributed.is_available():
from torch.distributed import group, ReduceOp
else:
class ReduceOp:
SUM = None
@@ -108,7 +109,9 @@ def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None)
def sync_ddp_if_available(
result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
result: Union[torch.Tensor],
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None
) -> torch.Tensor:
"""
Function to reduce a tensor across worker processes during distributed training
@@ -127,7 +130,9 @@ def sync_ddp_if_available(
def sync_ddp(
result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
result: Union[torch.Tensor],
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None
) -> torch.Tensor:
"""
Function to reduce the tensors from several ddp processes to one master process
@@ -162,13 +167,12 @@ def sync_ddp(
class AllGatherGrad(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, group=group.WORLD):
ctx.group = group
gathered_tensor = [
torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
]
gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(gathered_tensor, tensor, group=group)
gathered_tensor = torch.stack(gathered_tensor, dim=0)
@@ -179,12 +183,7 @@ class AllGatherGrad(torch.autograd.Function):
def backward(ctx, *grad_output):
grad_output = torch.cat(grad_output)
torch.distributed.all_reduce(
grad_output,
op=torch.distributed.ReduceOp.SUM,
async_op=False,
group=ctx.group
)
torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
return grad_output[torch.distributed.get_rank()]
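
The hunk at the top of this file guards the `ReduceOp` import so the module still imports when `torch.distributed` is unavailable (for example on CPU-only macOS/Windows builds). Spelled out as a standalone sketch (simplified to `ReduceOp` only):

```python
# Standalone sketch of the guarded import shown above: fall back to a stub
# ReduceOp when torch.distributed is not available on this platform.
import torch

if torch.distributed.is_available():
    from torch.distributed import ReduceOp
else:

    class ReduceOp:
        SUM = None


print("distributed available:", torch.distributed.is_available())
print("ReduceOp.SUM:", ReduceOp.SUM)
```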

View File

@@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
class MisconfigurationException(Exception):
pass

View File

@@ -1,6 +1,8 @@
from warnings import warn
warn("`model_utils` package has been renamed to `model_helpers` since v1.2 and will be removed in v1.4",
DeprecationWarning)
warn(
"`model_utils` package has been renamed to `model_helpers` since v1.2 and will be removed in v1.4",
DeprecationWarning
)
from pytorch_lightning.utilities.model_helpers import * # noqa: F403 E402 F401

View File

@@ -235,8 +235,10 @@ def lightning_getattr(model, attribute):
elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
attr = getattr(trainer.datamodule, attribute)
else:
raise ValueError(f'{attribute} is neither stored in the model namespace'
' nor the `hparams` namespace/dict, nor the datamodule.')
raise ValueError(
f'The {attribute} is neither stored in the model namespace nor the `hparams` namespace/dict,'
' nor the datamodule.'
)
return attr
@@ -246,8 +248,10 @@ def lightning_setattr(model, attribute, value):
Will also set the attribute on datamodule, if it exists.
"""
if not lightning_hasattr(model, attribute):
raise ValueError(f'{attribute} is neither stored in the model namespace'
' nor the `hparams` namespace/dict, nor the datamodule.')
raise ValueError(
f'The {attribute} is neither stored in the model namespace nor the `hparams` namespace/dict,'
' nor the datamodule.'
)
trainer = getattr(model, 'trainer', None)
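
For context, `lightning_getattr` resolves an attribute by searching the model itself, then its `hparams`, then the trainer's datamodule; the reformatted `ValueError` above fires when none of them carries it. A simplified standalone sketch of that lookup order (the real helper also handles `hparams` stored as a dict and has a matching setter; this only illustrates the search):

```python
# Simplified sketch of the lookup order behind lightning_getattr:
# model attribute -> model.hparams -> trainer.datamodule, else raise.
def lightning_getattr(model, attribute):
    if hasattr(model, attribute):
        return getattr(model, attribute)

    hparams = getattr(model, 'hparams', None)
    if hparams is not None and hasattr(hparams, attribute):
        return getattr(hparams, attribute)

    trainer = getattr(model, 'trainer', None)
    if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
        return getattr(trainer.datamodule, attribute)

    raise ValueError(
        f'The {attribute} is neither stored in the model namespace nor the `hparams` namespace/dict,'
        ' nor the datamodule.'
    )


class DummyModel:
    learning_rate = 0.01
    trainer = None


print(lightning_getattr(DummyModel(), 'learning_rate'))  # 0.01
```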

View File

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions to help with reproducibility of models. """
import os

View File

@@ -1,6 +1,7 @@
from warnings import warn
warn("`warning_utils` package has been renamed to `warnings` since v1.2 and will be removed in v1.4",
DeprecationWarning)
warn(
"`warning_utils` package has been renamed to `warnings` since v1.2 and will be removed in v1.4", DeprecationWarning
)
from pytorch_lightning.utilities.warnings import * # noqa: F403 E402 F401

View File

@@ -34,6 +34,7 @@ def inner_f(queue, func, *args, **kwargs): # pragma: no cover
def pl_multi_process(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
queue = Queue()
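
The `pl_multi_process` decorator above runs the wrapped function in a separate process and ships the result back through a queue, which is a common way to probe XLA devices without risking a hang in the main process. A rough standalone sketch of that queue-based pattern (simplified; the library version likely adds timeout and error handling that is not shown in the hunk):

```python
# Rough sketch of the decorator pattern above: run `func` in a child
# process and pass its return value back through a multiprocessing Queue.
import functools
from multiprocessing import Process, Queue


def inner_f(queue, func, *args, **kwargs):
    queue.put(func(*args, **kwargs))


def pl_multi_process(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        queue = Queue()
        proc = Process(target=inner_f, args=(queue, func) + args, kwargs=kwargs)
        proc.start()
        result = queue.get()  # blocks until the child publishes its result
        proc.join()
        return result

    return wrapper


def add(a, b):
    return a + b


# applied without the @ syntax so the original `add` stays picklable
# under the 'spawn' start method as well
add_in_subprocess = pl_multi_process(add)

if __name__ == "__main__":
    print(add_in_subprocess(2, 3))  # 5
```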

View File

@@ -1,6 +1,8 @@
from warnings import warn
warn("`xla_device_utils` package has been renamed to `xla_device` since v1.2 and will be removed in v1.4",
DeprecationWarning)
warn(
"`xla_device_utils` package has been renamed to `xla_device` since v1.2 and will be removed in v1.4",
DeprecationWarning
)
from pytorch_lightning.utilities.xla_device import * # noqa: F403 E402 F401