formatting to PL utils (#5713)
* yapf pl base * over * dist * utils * Apply suggestions from code review * flake8 * neew way
This commit is contained in:
parent
21d313edc5
commit
eee38d59e7
39
.yapfignore
39
.yapfignore
|
@ -1,7 +1,44 @@
|
|||
.git/*
|
||||
|
||||
|
||||
# TODO
|
||||
pytorch_lightning/*
|
||||
pytorch_lightning/accelerators/legacy/*
|
||||
|
||||
|
||||
# TODO
|
||||
pytorch_lightning/callbacks/*
|
||||
|
||||
|
||||
# TODO
|
||||
pytorch_lightning/cluster_environments/*
|
||||
|
||||
|
||||
# TODO
|
||||
pytorch_lightning/core/*
|
||||
|
||||
|
||||
# TODO
|
||||
pytorch_lightning/loggers/*
|
||||
|
||||
|
||||
# TODO
|
||||
pytorch_lightning/metrics/*
|
||||
|
||||
# TODO
|
||||
pytorch_lightning/plugins/legacy/*
|
||||
|
||||
|
||||
# TODO
|
||||
pytorch_lightning/profiler/*
|
||||
|
||||
|
||||
# TODO
|
||||
pytorch_lightning/trainer/*
|
||||
|
||||
|
||||
# TODO
|
||||
pytorch_lightning/tuner/*
|
||||
|
||||
|
||||
# TODO
|
||||
tests/*
|
||||
|
|
|
@ -74,9 +74,10 @@ class LightningDataParallel(DataParallel):
|
|||
|
||||
for t in chain(self.module.parameters(), self.module.buffers()):
|
||||
if t.device != self.src_device_obj:
|
||||
raise RuntimeError("module must have its parameters and buffers "
|
||||
"on device {} (device_ids[0]) but found one of "
|
||||
"them on device: {}".format(self.src_device_obj, t.device))
|
||||
raise RuntimeError(
|
||||
f"module must have its parameters and buffers on device {self.src_device_obj} (device_ids[0])"
|
||||
f" but found one of them on device: {t.device}"
|
||||
)
|
||||
|
||||
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
|
||||
|
||||
|
@ -127,6 +128,7 @@ class LightningDataParallel(DataParallel):
|
|||
r"""
|
||||
Override the gather method to support python scalars as well.
|
||||
"""
|
||||
|
||||
def gather_map(outputs):
|
||||
elem = outputs[0]
|
||||
elem_type = type(elem)
|
||||
|
@ -140,8 +142,7 @@ class LightningDataParallel(DataParallel):
|
|||
if isinstance(elem, Mapping):
|
||||
if not all((len(elem) == len(d) for d in outputs)):
|
||||
raise ValueError('All dicts must have the same number of keys')
|
||||
return elem_type(((k, gather_map([d[k] for d in outputs]))
|
||||
for k in elem))
|
||||
return elem_type(((k, gather_map([d[k] for d in outputs])) for k in elem))
|
||||
|
||||
if isinstance(elem, Iterable) and not isinstance(elem, str):
|
||||
return elem_type(map(gather_map, zip(*outputs)))
|
||||
|
@ -247,10 +248,10 @@ def warn_missing_output(fx_called):
|
|||
|
||||
|
||||
def parallel_apply(
|
||||
modules: Module,
|
||||
inputs: Tensor,
|
||||
kwargs_tup: Optional[tuple] = None,
|
||||
devices: Optional[list] = None,
|
||||
modules: Module,
|
||||
inputs: Tensor,
|
||||
kwargs_tup: Optional[tuple] = None,
|
||||
devices: Optional[list] = None,
|
||||
): # pragma: no-cover
|
||||
r"""Applies each `module` in :attr:`modules` in parallel on arguments
|
||||
contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
|
||||
|
@ -270,7 +271,7 @@ def parallel_apply(
|
|||
if kwargs_tup is not None:
|
||||
assert len(modules) == len(kwargs_tup)
|
||||
else:
|
||||
kwargs_tup = ({},) * len(modules)
|
||||
kwargs_tup = ({}, ) * len(modules)
|
||||
if devices is not None:
|
||||
assert len(modules) == len(devices)
|
||||
else:
|
||||
|
@ -288,7 +289,7 @@ def parallel_apply(
|
|||
with torch.cuda.device(device):
|
||||
# this also avoids accidental slicing of `input` if it is a Tensor
|
||||
if not isinstance(input, (list, tuple)):
|
||||
input = (input,)
|
||||
input = (input, )
|
||||
|
||||
module = module.to(device)
|
||||
|
||||
|
@ -333,10 +334,10 @@ def parallel_apply(
|
|||
m.testing = root_m.testing
|
||||
|
||||
if len(modules) > 1:
|
||||
threads = [threading.Thread(target=_worker,
|
||||
args=(i, module, input, kwargs, device))
|
||||
for i, (module, input, kwargs, device) in
|
||||
enumerate(zip(modules, inputs, kwargs_tup, devices))]
|
||||
threads = [
|
||||
threading.Thread(target=_worker, args=(i, module, input, kwargs, device))
|
||||
for i, (module, input, kwargs, device) in enumerate(zip(modules, inputs, kwargs_tup, devices))
|
||||
]
|
||||
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
|
|
|
@ -31,7 +31,7 @@ _DEFAULT_BADGES = [
|
|||
]
|
||||
|
||||
|
||||
def _load_requirements(path_dir: str , file_name: str = 'requirements.txt', comment_char: str = '#') -> List[str]:
|
||||
def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comment_char: str = '#') -> List[str]:
|
||||
"""Load requirements from a file
|
||||
|
||||
>>> _load_requirements(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
|
||||
|
@ -72,8 +72,7 @@ def _load_readme_description(path_dir: str, homepage: str = __homepage__, versio
|
|||
|
||||
# readthedocs badge
|
||||
text = text.replace('badge/?version=stable', f'badge/?version={version}')
|
||||
text = text.replace('pytorch-lightning.readthedocs.io/en/stable/',
|
||||
f'pytorch-lightning.readthedocs.io/en/{version}')
|
||||
text = text.replace('pytorch-lightning.readthedocs.io/en/stable/', f'pytorch-lightning.readthedocs.io/en/{version}')
|
||||
# codecov badge
|
||||
text = text.replace('/branch/master/graph/badge.svg', f'/release/{version}/graph/badge.svg')
|
||||
# replace github badges for release ones
|
||||
|
|
|
@ -30,19 +30,15 @@ else:
|
|||
Batch = type(None)
|
||||
|
||||
|
||||
def to_dtype_tensor(value, dtype:torch.dtype = None, device: torch.device = None):
|
||||
def to_dtype_tensor(value, dtype: torch.dtype = None, device: torch.device = None):
|
||||
if device is None:
|
||||
raise MisconfigurationException(
|
||||
"device (torch.device) should be provided."
|
||||
)
|
||||
raise MisconfigurationException("device (torch.device) should be provided.")
|
||||
return torch.tensor(value, dtype=dtype, device=device)
|
||||
|
||||
|
||||
def from_numpy(value, device: torch.device = None):
|
||||
if device is None:
|
||||
raise MisconfigurationException(
|
||||
"device (torch.device) should be provided."
|
||||
)
|
||||
raise MisconfigurationException("device (torch.device) should be provided.")
|
||||
return torch.from_numpy(value).to(device)
|
||||
|
||||
|
||||
|
@ -55,8 +51,14 @@ CONVERSION_DTYPES = [
|
|||
]
|
||||
|
||||
|
||||
def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args,
|
||||
wrong_dtype: Optional[Union[type, tuple]] = None, **kwargs) -> Any:
|
||||
def apply_to_collection(
|
||||
data: Any,
|
||||
dtype: Union[type, tuple],
|
||||
function: Callable,
|
||||
*args,
|
||||
wrong_dtype: Optional[Union[type, tuple]] = None,
|
||||
**kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
Recursively applies a function to all elements of a certain dtype.
|
||||
|
||||
|
@ -80,8 +82,7 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable
|
|||
|
||||
# Recursively apply to collection items
|
||||
if isinstance(data, Mapping):
|
||||
return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs)
|
||||
for k, v in data.items()})
|
||||
return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()})
|
||||
|
||||
if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple
|
||||
return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data))
|
||||
|
@ -159,9 +160,7 @@ def move_data_to_device(batch: Any, device: torch.device):
|
|||
|
||||
def convert_to_tensors(data, device: torch.device = None):
|
||||
if device is None:
|
||||
raise MisconfigurationException(
|
||||
"device (torch.device) should be provided."
|
||||
)
|
||||
raise MisconfigurationException("device (torch.device) should be provided.")
|
||||
for src_dtype, conversion_func in CONVERSION_DTYPES:
|
||||
data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device))
|
||||
return data
|
||||
|
|
|
@ -57,9 +57,7 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp
|
|||
"""Parse CLI arguments, required for custom bool types."""
|
||||
args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser
|
||||
|
||||
types_default = {
|
||||
arg: (arg_types, arg_default) for arg, arg_types, arg_default in get_init_arguments_and_types(cls)
|
||||
}
|
||||
types_default = {arg: (arg_types, arg_default) for arg, arg_types, arg_default in get_init_arguments_and_types(cls)}
|
||||
|
||||
modified_args = {}
|
||||
for k, v in vars(args).items():
|
||||
|
@ -130,7 +128,7 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
|
|||
try:
|
||||
arg_types = tuple(arg_type.__args__)
|
||||
except AttributeError:
|
||||
arg_types = (arg_type,)
|
||||
arg_types = (arg_type, )
|
||||
|
||||
name_type_default.append((arg, arg_types, arg_default))
|
||||
|
||||
|
@ -156,7 +154,10 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
|
|||
>>> parser = Trainer.add_argparse_args(parser)
|
||||
>>> args = parser.parse_args([])
|
||||
"""
|
||||
parser = ArgumentParser(parents=[parent_parser], add_help=False,)
|
||||
parser = ArgumentParser(
|
||||
parents=[parent_parser],
|
||||
add_help=False,
|
||||
)
|
||||
|
||||
blacklist = ['kwargs']
|
||||
depr_arg_names = cls.get_deprecated_arg_names() + blacklist
|
||||
|
@ -164,9 +165,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
|
|||
allowed_types = (str, int, float, bool)
|
||||
|
||||
args_help = parse_args_from_docstring(cls.__init__.__doc__ or cls.__doc__)
|
||||
for arg, arg_types, arg_default in (
|
||||
at for at in get_init_arguments_and_types(cls) if at[0] not in depr_arg_names
|
||||
):
|
||||
for arg, arg_types, arg_default in (at for at in get_init_arguments_and_types(cls) if at[0] not in depr_arg_names):
|
||||
arg_types = [at for at in allowed_types if at in arg_types]
|
||||
if not arg_types:
|
||||
# skip argument with not supported type
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from warnings import warn
|
||||
|
||||
warn("`argparse_utils` package has been renamed to `argparse` since v1.2 and will be removed in v1.4",
|
||||
DeprecationWarning)
|
||||
warn(
|
||||
"`argparse_utils` package has been renamed to `argparse` since v1.2 and will be removed in v1.4", DeprecationWarning
|
||||
)
|
||||
|
||||
from pytorch_lightning.utilities.argparse import * # noqa: F403 E402 F401
|
||||
|
|
|
@ -32,8 +32,9 @@ def has_len(dataloader: DataLoader) -> bool:
|
|||
try:
|
||||
# try getting the length
|
||||
if len(dataloader) == 0:
|
||||
raise ValueError('`Dataloader` returned 0 length.'
|
||||
' Please make sure that your Dataloader at least returns 1 batch')
|
||||
raise ValueError(
|
||||
'`Dataloader` returned 0 length. Please make sure that your Dataloader at least returns 1 batch'
|
||||
)
|
||||
has_len = True
|
||||
except TypeError:
|
||||
has_len = False
|
||||
|
|
|
@ -54,12 +54,12 @@ class InternalDebugger(object):
|
|||
self.dataloader_sequence_calls = []
|
||||
|
||||
def track_event(
|
||||
self,
|
||||
evt_type: str,
|
||||
evt_value: Any = None,
|
||||
global_rank: Optional[int] = None,
|
||||
local_rank: Optional[int] = None,
|
||||
comment: str = ''
|
||||
self,
|
||||
evt_type: str,
|
||||
evt_value: Any = None,
|
||||
global_rank: Optional[int] = None,
|
||||
local_rank: Optional[int] = None,
|
||||
comment: str = ''
|
||||
) -> None:
|
||||
self.events.append({
|
||||
"timestamp": time.time(),
|
||||
|
|
|
@ -139,10 +139,9 @@ def _sanitize_gpu_ids(gpus: List[int]) -> List[int]:
|
|||
all_available_gpus = _get_all_available_gpus()
|
||||
for gpu in gpus:
|
||||
if gpu not in all_available_gpus:
|
||||
raise MisconfigurationException(f"""
|
||||
You requested GPUs: {gpus}
|
||||
But your machine only has: {all_available_gpus}
|
||||
""")
|
||||
raise MisconfigurationException(
|
||||
f"You requested GPUs: {gpus}\n But your machine only has: {all_available_gpus}"
|
||||
)
|
||||
return gpus
|
||||
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ from pytorch_lightning import _logger as log
|
|||
if torch.distributed.is_available():
|
||||
from torch.distributed import group, ReduceOp
|
||||
else:
|
||||
|
||||
class ReduceOp:
|
||||
SUM = None
|
||||
|
||||
|
@ -108,7 +109,9 @@ def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None)
|
|||
|
||||
|
||||
def sync_ddp_if_available(
|
||||
result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
|
||||
result: Union[torch.Tensor],
|
||||
group: Optional[Any] = None,
|
||||
reduce_op: Optional[Union[ReduceOp, str]] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Function to reduce a tensor across worker processes during distributed training
|
||||
|
@ -127,7 +130,9 @@ def sync_ddp_if_available(
|
|||
|
||||
|
||||
def sync_ddp(
|
||||
result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
|
||||
result: Union[torch.Tensor],
|
||||
group: Optional[Any] = None,
|
||||
reduce_op: Optional[Union[ReduceOp, str]] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Function to reduce the tensors from several ddp processes to one master process
|
||||
|
@ -162,13 +167,12 @@ def sync_ddp(
|
|||
|
||||
|
||||
class AllGatherGrad(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, tensor, group=group.WORLD):
|
||||
ctx.group = group
|
||||
|
||||
gathered_tensor = [
|
||||
torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
|
||||
]
|
||||
gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
|
||||
|
||||
torch.distributed.all_gather(gathered_tensor, tensor, group=group)
|
||||
gathered_tensor = torch.stack(gathered_tensor, dim=0)
|
||||
|
@ -179,12 +183,7 @@ class AllGatherGrad(torch.autograd.Function):
|
|||
def backward(ctx, *grad_output):
|
||||
grad_output = torch.cat(grad_output)
|
||||
|
||||
torch.distributed.all_reduce(
|
||||
grad_output,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
async_op=False,
|
||||
group=ctx.group
|
||||
)
|
||||
torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
|
||||
|
||||
return grad_output[torch.distributed.get_rank()]
|
||||
|
||||
|
|
|
@ -12,5 +12,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
class MisconfigurationException(Exception):
|
||||
pass
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
from warnings import warn
|
||||
|
||||
warn("`model_utils` package has been renamed to `model_helpers` since v1.2 and will be removed in v1.4",
|
||||
DeprecationWarning)
|
||||
warn(
|
||||
"`model_utils` package has been renamed to `model_helpers` since v1.2 and will be removed in v1.4",
|
||||
DeprecationWarning
|
||||
)
|
||||
|
||||
from pytorch_lightning.utilities.model_helpers import * # noqa: F403 E402 F401
|
||||
|
|
|
@ -235,8 +235,10 @@ def lightning_getattr(model, attribute):
|
|||
elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
|
||||
attr = getattr(trainer.datamodule, attribute)
|
||||
else:
|
||||
raise ValueError(f'{attribute} is neither stored in the model namespace'
|
||||
' nor the `hparams` namespace/dict, nor the datamodule.')
|
||||
raise ValueError(
|
||||
f'The {attribute} is neither stored in the model namespace nor the `hparams` namespace/dict,'
|
||||
' nor the datamodule.'
|
||||
)
|
||||
return attr
|
||||
|
||||
|
||||
|
@ -246,8 +248,10 @@ def lightning_setattr(model, attribute, value):
|
|||
Will also set the attribute on datamodule, if it exists.
|
||||
"""
|
||||
if not lightning_hasattr(model, attribute):
|
||||
raise ValueError(f'{attribute} is neither stored in the model namespace'
|
||||
' nor the `hparams` namespace/dict, nor the datamodule.')
|
||||
raise ValueError(
|
||||
f'The {attribute} is neither stored in the model namespace nor the `hparams` namespace/dict,'
|
||||
' nor the datamodule.'
|
||||
)
|
||||
|
||||
trainer = getattr(model, 'trainer', None)
|
||||
|
||||
|
|
|
@ -11,7 +11,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Helper functions to help with reproducibility of models. """
|
||||
|
||||
import os
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from warnings import warn
|
||||
|
||||
warn("`warning_utils` package has been renamed to `warnings` since v1.2 and will be removed in v1.4",
|
||||
DeprecationWarning)
|
||||
warn(
|
||||
"`warning_utils` package has been renamed to `warnings` since v1.2 and will be removed in v1.4", DeprecationWarning
|
||||
)
|
||||
|
||||
from pytorch_lightning.utilities.warnings import * # noqa: F403 E402 F401
|
||||
|
|
|
@ -34,6 +34,7 @@ def inner_f(queue, func, *args, **kwargs): # pragma: no cover
|
|||
|
||||
|
||||
def pl_multi_process(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
queue = Queue()
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
from warnings import warn
|
||||
|
||||
warn("`xla_device_utils` package has been renamed to `xla_device` since v1.2 and will be removed in v1.4",
|
||||
DeprecationWarning)
|
||||
warn(
|
||||
"`xla_device_utils` package has been renamed to `xla_device` since v1.2 and will be removed in v1.4",
|
||||
DeprecationWarning
|
||||
)
|
||||
|
||||
from pytorch_lightning.utilities.xla_device import * # noqa: F403 E402 F401
|
||||
|
|
Loading…
Reference in New Issue