From 8095e2452d519167944e9924819b43710f45b1ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 7 Sep 2022 00:35:33 +0200 Subject: [PATCH] Cleanup Lite `apply_func`s utilitites (#14560) --- src/lightning_lite/utilities/apply_func.py | 33 +++++++++---------- src/lightning_lite/utilities/types.py | 18 ++++++++++ src/pytorch_lightning/CHANGELOG.md | 5 +-- src/pytorch_lightning/utilities/apply_func.py | 12 +++---- 4 files changed, 40 insertions(+), 28 deletions(-) create mode 100644 src/lightning_lite/utilities/types.py diff --git a/src/lightning_lite/utilities/apply_func.py b/src/lightning_lite/utilities/apply_func.py index c76fe01985..a3a203776b 100644 --- a/src/lightning_lite/utilities/apply_func.py +++ b/src/lightning_lite/utilities/apply_func.py @@ -21,25 +21,21 @@ import torch from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor +from lightning_lite.utilities.types import _DEVICE + _BLOCKING_DEVICE_TYPES = ("cpu", "mps") -def to_dtype_tensor( - value: Union[int, float, List[Union[int, float]]], dtype: torch.dtype, device: Union[str, torch.device] -) -> Tensor: - return torch.tensor(value, dtype=dtype, device=device) - - -def from_numpy(value: np.ndarray, device: Union[str, torch.device]) -> Tensor: - return torch.from_numpy(value).to(device) +def _from_numpy(value: np.ndarray, device: _DEVICE) -> Tensor: + return torch.from_numpy(value).to(device) # type: ignore[arg-type] CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any, Any], Tensor]]] = [ # bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group - (bool, partial(to_dtype_tensor, dtype=torch.uint8)), - (int, partial(to_dtype_tensor, dtype=torch.int)), - (float, partial(to_dtype_tensor, dtype=torch.float)), - (np.ndarray, from_numpy), + (bool, partial(torch.tensor, dtype=torch.uint8)), + (int, partial(torch.tensor, dtype=torch.int)), + (float, partial(torch.tensor, dtype=torch.float)), + (np.ndarray, _from_numpy), ] @@ -70,7 +66,7 @@ class TransferableDataType(ABC): return NotImplemented -def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any: +def move_data_to_device(batch: Any, device: _DEVICE) -> Any: """Transfers a collection of data to the given device. Any object that defines a method ``to(device)`` will be moved and all other objects in the collection will be left untouched. @@ -105,12 +101,13 @@ def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any: return apply_to_collection(batch, dtype=TransferableDataType, function=batch_to) -def convert_to_tensors(data: Any, device: Union[str, torch.device]) -> Any: +def convert_to_tensors(data: Any, device: _DEVICE) -> Any: + # convert non-tensors for src_dtype, conversion_func in CONVERSION_DTYPES: data = apply_to_collection(data, src_dtype, conversion_func, device=device) - def _move_to_device_and_make_contiguous(t: Tensor, device: Union[str, torch.device]) -> Tensor: - return t.to(device).contiguous() + def _move_to_device_and_make_contiguous(t: Tensor, device: _DEVICE) -> Tensor: + return t.to(device).contiguous() # type: ignore[arg-type] - data = apply_to_collection(data, Tensor, _move_to_device_and_make_contiguous, device=device) - return data + # make sure existing tensors are in the correct device, also contiguous + return apply_to_collection(data, Tensor, _move_to_device_and_make_contiguous, device=device) diff --git a/src/lightning_lite/utilities/types.py b/src/lightning_lite/utilities/types.py new file mode 100644 index 0000000000..900154e69c --- /dev/null +++ b/src/lightning_lite/utilities/types.py @@ -0,0 +1,18 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Union + +import torch + +_DEVICE = Union[torch.device, str, int] diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index c3534aa3d6..04aeed3fb4 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for auto wrapping for `DDPFullyShardedStrategy` ([#14383](https://github.com/Lightning-AI/lightning/issues/14383)) -- Integrate the `lightning_utilities` package ([#14475](https://github.com/Lightning-AI/lightning/issues/14475)) +- Integrate the `lightning_utilities` package ([#14475](https://github.com/Lightning-AI/lightning/issues/14475), [#14537](https://github.com/Lightning-AI/lightning/issues/14537)) @@ -84,15 +84,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated all functions in `pytorch_lightning.utilities.xla_device` in favor of `lightning_lite.utilities.xla_device` ([#14514](https://github.com/Lightning-AI/lightning/pull/14514)) - - Deprecated all functions in `pytorch_lightning.utilities.cloud_io` in favor of `lightning_lite.utilities.cloud_io` ([#14515](https://github.com/Lightning-AI/lightning/pull/14515)) - - Deprecated the functions in `pytorch_lightning.utilities.apply_func` in favor of `lightning_utilities.core.apply_func` ([#14516](https://github.com/Lightning-AI/lightning/pull/14516), [#14537](https://github.com/Lightning-AI/lightning/pull/14537)) - ### Removed - Removed the deprecated `Trainer.training_type_plugin` property in favor of `Trainer.strategy` ([#14011](https://github.com/Lightning-AI/lightning/pull/14011)) diff --git a/src/pytorch_lightning/utilities/apply_func.py b/src/pytorch_lightning/utilities/apply_func.py index f4b1bfedef..e7c8fedb48 100644 --- a/src/pytorch_lightning/utilities/apply_func.py +++ b/src/pytorch_lightning/utilities/apply_func.py @@ -15,13 +15,13 @@ from typing import Any +import torch from lightning_utilities.core.apply_func import apply_to_collection as new_apply_to_collection from lightning_utilities.core.apply_func import apply_to_collections as new_apply_to_collections +from lightning_lite.utilities.apply_func import _from_numpy from lightning_lite.utilities.apply_func import convert_to_tensors as new_convert_to_tensors -from lightning_lite.utilities.apply_func import from_numpy as new_from_numpy from lightning_lite.utilities.apply_func import move_data_to_device as new_move_data_to_device -from lightning_lite.utilities.apply_func import to_dtype_tensor as new_to_dtype_tensor from lightning_lite.utilities.apply_func import TransferableDataType as NewTransferableDataType from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -62,9 +62,9 @@ def convert_to_tensors(*args: Any, **kwargs: Any) -> Any: def from_numpy(*args: Any, **kwargs: Any) -> Any: rank_zero_deprecation( "`pytorch_lightning.utilities.apply_func.from_numpy` has been deprecated in v1.8.0 and will be" - " removed in v1.10.0. Please use `lightning_lite.utilities.apply_func.from_numpy` instead." + " removed in v1.10.0. Please use `torch.from_numpy().to()` instead." ) - return new_from_numpy(*args, **kwargs) + return _from_numpy(*args, **kwargs) def move_data_to_device(*args: Any, **kwargs: Any) -> Any: @@ -78,9 +78,9 @@ def move_data_to_device(*args: Any, **kwargs: Any) -> Any: def to_dtype_tensor(*args: Any, **kwargs: Any) -> Any: rank_zero_deprecation( "`pytorch_lightning.utilities.apply_func.to_dtype_tensor` has been deprecated in v1.8.0 and will be" - " removed in v1.10.0. Please use `lightning_lite.utilities.apply_func.to_dtype_tensor` instead." + " removed in v1.10.0. Please use `torch.tensor` instead." ) - return new_to_dtype_tensor(*args, **kwargs) + return torch.tensor(*args, **kwargs) class TransferableDataType(NewTransferableDataType):