Cleanup Lite `apply_func`s utilitites (#14560)

This commit is contained in:
Carlos Mocholí 2022-09-07 00:35:33 +02:00 committed by GitHub
parent 273a9ed8c1
commit 8095e2452d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 40 additions and 28 deletions

View File

@ -21,25 +21,21 @@ import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from lightning_lite.utilities.types import _DEVICE
_BLOCKING_DEVICE_TYPES = ("cpu", "mps")
def to_dtype_tensor(
value: Union[int, float, List[Union[int, float]]], dtype: torch.dtype, device: Union[str, torch.device]
) -> Tensor:
return torch.tensor(value, dtype=dtype, device=device)
def from_numpy(value: np.ndarray, device: Union[str, torch.device]) -> Tensor:
return torch.from_numpy(value).to(device)
def _from_numpy(value: np.ndarray, device: _DEVICE) -> Tensor:
return torch.from_numpy(value).to(device) # type: ignore[arg-type]
CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any, Any], Tensor]]] = [
# bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group
(bool, partial(to_dtype_tensor, dtype=torch.uint8)),
(int, partial(to_dtype_tensor, dtype=torch.int)),
(float, partial(to_dtype_tensor, dtype=torch.float)),
(np.ndarray, from_numpy),
(bool, partial(torch.tensor, dtype=torch.uint8)),
(int, partial(torch.tensor, dtype=torch.int)),
(float, partial(torch.tensor, dtype=torch.float)),
(np.ndarray, _from_numpy),
]
@ -70,7 +66,7 @@ class TransferableDataType(ABC):
return NotImplemented
def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
def move_data_to_device(batch: Any, device: _DEVICE) -> Any:
"""Transfers a collection of data to the given device. Any object that defines a method ``to(device)`` will be
moved and all other objects in the collection will be left untouched.
@ -105,12 +101,13 @@ def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
return apply_to_collection(batch, dtype=TransferableDataType, function=batch_to)
def convert_to_tensors(data: Any, device: Union[str, torch.device]) -> Any:
def convert_to_tensors(data: Any, device: _DEVICE) -> Any:
# convert non-tensors
for src_dtype, conversion_func in CONVERSION_DTYPES:
data = apply_to_collection(data, src_dtype, conversion_func, device=device)
def _move_to_device_and_make_contiguous(t: Tensor, device: Union[str, torch.device]) -> Tensor:
return t.to(device).contiguous()
def _move_to_device_and_make_contiguous(t: Tensor, device: _DEVICE) -> Tensor:
return t.to(device).contiguous() # type: ignore[arg-type]
data = apply_to_collection(data, Tensor, _move_to_device_and_make_contiguous, device=device)
return data
# make sure existing tensors are in the correct device, also contiguous
return apply_to_collection(data, Tensor, _move_to_device_and_make_contiguous, device=device)

View File

@ -0,0 +1,18 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
import torch
_DEVICE = Union[torch.device, str, int]

View File

@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for auto wrapping for `DDPFullyShardedStrategy` ([#14383](https://github.com/Lightning-AI/lightning/issues/14383))
- Integrate the `lightning_utilities` package ([#14475](https://github.com/Lightning-AI/lightning/issues/14475))
- Integrate the `lightning_utilities` package ([#14475](https://github.com/Lightning-AI/lightning/issues/14475), [#14537](https://github.com/Lightning-AI/lightning/issues/14537))
@ -84,15 +84,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated all functions in `pytorch_lightning.utilities.xla_device` in favor of `lightning_lite.utilities.xla_device` ([#14514](https://github.com/Lightning-AI/lightning/pull/14514))
- Deprecated all functions in `pytorch_lightning.utilities.cloud_io` in favor of `lightning_lite.utilities.cloud_io` ([#14515](https://github.com/Lightning-AI/lightning/pull/14515))
- Deprecated the functions in `pytorch_lightning.utilities.apply_func` in favor of `lightning_utilities.core.apply_func` ([#14516](https://github.com/Lightning-AI/lightning/pull/14516), [#14537](https://github.com/Lightning-AI/lightning/pull/14537))
### Removed
- Removed the deprecated `Trainer.training_type_plugin` property in favor of `Trainer.strategy` ([#14011](https://github.com/Lightning-AI/lightning/pull/14011))

View File

@ -15,13 +15,13 @@
from typing import Any
import torch
from lightning_utilities.core.apply_func import apply_to_collection as new_apply_to_collection
from lightning_utilities.core.apply_func import apply_to_collections as new_apply_to_collections
from lightning_lite.utilities.apply_func import _from_numpy
from lightning_lite.utilities.apply_func import convert_to_tensors as new_convert_to_tensors
from lightning_lite.utilities.apply_func import from_numpy as new_from_numpy
from lightning_lite.utilities.apply_func import move_data_to_device as new_move_data_to_device
from lightning_lite.utilities.apply_func import to_dtype_tensor as new_to_dtype_tensor
from lightning_lite.utilities.apply_func import TransferableDataType as NewTransferableDataType
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -62,9 +62,9 @@ def convert_to_tensors(*args: Any, **kwargs: Any) -> Any:
def from_numpy(*args: Any, **kwargs: Any) -> Any:
rank_zero_deprecation(
"`pytorch_lightning.utilities.apply_func.from_numpy` has been deprecated in v1.8.0 and will be"
" removed in v1.10.0. Please use `lightning_lite.utilities.apply_func.from_numpy` instead."
" removed in v1.10.0. Please use `torch.from_numpy().to()` instead."
)
return new_from_numpy(*args, **kwargs)
return _from_numpy(*args, **kwargs)
def move_data_to_device(*args: Any, **kwargs: Any) -> Any:
@ -78,9 +78,9 @@ def move_data_to_device(*args: Any, **kwargs: Any) -> Any:
def to_dtype_tensor(*args: Any, **kwargs: Any) -> Any:
rank_zero_deprecation(
"`pytorch_lightning.utilities.apply_func.to_dtype_tensor` has been deprecated in v1.8.0 and will be"
" removed in v1.10.0. Please use `lightning_lite.utilities.apply_func.to_dtype_tensor` instead."
" removed in v1.10.0. Please use `torch.tensor` instead."
)
return new_to_dtype_tensor(*args, **kwargs)
return torch.tensor(*args, **kwargs)
class TransferableDataType(NewTransferableDataType):