Cleanup Lite `apply_func`s utilitites (#14560)
This commit is contained in:
parent
273a9ed8c1
commit
8095e2452d
|
@ -21,25 +21,21 @@ import torch
|
|||
from lightning_utilities.core.apply_func import apply_to_collection
|
||||
from torch import Tensor
|
||||
|
||||
from lightning_lite.utilities.types import _DEVICE
|
||||
|
||||
_BLOCKING_DEVICE_TYPES = ("cpu", "mps")
|
||||
|
||||
|
||||
def to_dtype_tensor(
|
||||
value: Union[int, float, List[Union[int, float]]], dtype: torch.dtype, device: Union[str, torch.device]
|
||||
) -> Tensor:
|
||||
return torch.tensor(value, dtype=dtype, device=device)
|
||||
|
||||
|
||||
def from_numpy(value: np.ndarray, device: Union[str, torch.device]) -> Tensor:
|
||||
return torch.from_numpy(value).to(device)
|
||||
def _from_numpy(value: np.ndarray, device: _DEVICE) -> Tensor:
|
||||
return torch.from_numpy(value).to(device) # type: ignore[arg-type]
|
||||
|
||||
|
||||
CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any, Any], Tensor]]] = [
|
||||
# bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group
|
||||
(bool, partial(to_dtype_tensor, dtype=torch.uint8)),
|
||||
(int, partial(to_dtype_tensor, dtype=torch.int)),
|
||||
(float, partial(to_dtype_tensor, dtype=torch.float)),
|
||||
(np.ndarray, from_numpy),
|
||||
(bool, partial(torch.tensor, dtype=torch.uint8)),
|
||||
(int, partial(torch.tensor, dtype=torch.int)),
|
||||
(float, partial(torch.tensor, dtype=torch.float)),
|
||||
(np.ndarray, _from_numpy),
|
||||
]
|
||||
|
||||
|
||||
|
@ -70,7 +66,7 @@ class TransferableDataType(ABC):
|
|||
return NotImplemented
|
||||
|
||||
|
||||
def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
|
||||
def move_data_to_device(batch: Any, device: _DEVICE) -> Any:
|
||||
"""Transfers a collection of data to the given device. Any object that defines a method ``to(device)`` will be
|
||||
moved and all other objects in the collection will be left untouched.
|
||||
|
||||
|
@ -105,12 +101,13 @@ def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
|
|||
return apply_to_collection(batch, dtype=TransferableDataType, function=batch_to)
|
||||
|
||||
|
||||
def convert_to_tensors(data: Any, device: Union[str, torch.device]) -> Any:
|
||||
def convert_to_tensors(data: Any, device: _DEVICE) -> Any:
|
||||
# convert non-tensors
|
||||
for src_dtype, conversion_func in CONVERSION_DTYPES:
|
||||
data = apply_to_collection(data, src_dtype, conversion_func, device=device)
|
||||
|
||||
def _move_to_device_and_make_contiguous(t: Tensor, device: Union[str, torch.device]) -> Tensor:
|
||||
return t.to(device).contiguous()
|
||||
def _move_to_device_and_make_contiguous(t: Tensor, device: _DEVICE) -> Tensor:
|
||||
return t.to(device).contiguous() # type: ignore[arg-type]
|
||||
|
||||
data = apply_to_collection(data, Tensor, _move_to_device_and_make_contiguous, device=device)
|
||||
return data
|
||||
# make sure existing tensors are in the correct device, also contiguous
|
||||
return apply_to_collection(data, Tensor, _move_to_device_and_make_contiguous, device=device)
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
_DEVICE = Union[torch.device, str, int]
|
|
@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added support for auto wrapping for `DDPFullyShardedStrategy` ([#14383](https://github.com/Lightning-AI/lightning/issues/14383))
|
||||
|
||||
|
||||
- Integrate the `lightning_utilities` package ([#14475](https://github.com/Lightning-AI/lightning/issues/14475))
|
||||
- Integrate the `lightning_utilities` package ([#14475](https://github.com/Lightning-AI/lightning/issues/14475), [#14537](https://github.com/Lightning-AI/lightning/issues/14537))
|
||||
|
||||
|
||||
|
||||
|
@ -84,15 +84,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Deprecated all functions in `pytorch_lightning.utilities.xla_device` in favor of `lightning_lite.utilities.xla_device` ([#14514](https://github.com/Lightning-AI/lightning/pull/14514))
|
||||
|
||||
|
||||
|
||||
- Deprecated all functions in `pytorch_lightning.utilities.cloud_io` in favor of `lightning_lite.utilities.cloud_io` ([#14515](https://github.com/Lightning-AI/lightning/pull/14515))
|
||||
|
||||
|
||||
|
||||
- Deprecated the functions in `pytorch_lightning.utilities.apply_func` in favor of `lightning_utilities.core.apply_func` ([#14516](https://github.com/Lightning-AI/lightning/pull/14516), [#14537](https://github.com/Lightning-AI/lightning/pull/14537))
|
||||
|
||||
|
||||
|
||||
### Removed
|
||||
|
||||
- Removed the deprecated `Trainer.training_type_plugin` property in favor of `Trainer.strategy` ([#14011](https://github.com/Lightning-AI/lightning/pull/14011))
|
||||
|
|
|
@ -15,13 +15,13 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from lightning_utilities.core.apply_func import apply_to_collection as new_apply_to_collection
|
||||
from lightning_utilities.core.apply_func import apply_to_collections as new_apply_to_collections
|
||||
|
||||
from lightning_lite.utilities.apply_func import _from_numpy
|
||||
from lightning_lite.utilities.apply_func import convert_to_tensors as new_convert_to_tensors
|
||||
from lightning_lite.utilities.apply_func import from_numpy as new_from_numpy
|
||||
from lightning_lite.utilities.apply_func import move_data_to_device as new_move_data_to_device
|
||||
from lightning_lite.utilities.apply_func import to_dtype_tensor as new_to_dtype_tensor
|
||||
from lightning_lite.utilities.apply_func import TransferableDataType as NewTransferableDataType
|
||||
from pytorch_lightning.utilities import rank_zero_deprecation
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
@ -62,9 +62,9 @@ def convert_to_tensors(*args: Any, **kwargs: Any) -> Any:
|
|||
def from_numpy(*args: Any, **kwargs: Any) -> Any:
|
||||
rank_zero_deprecation(
|
||||
"`pytorch_lightning.utilities.apply_func.from_numpy` has been deprecated in v1.8.0 and will be"
|
||||
" removed in v1.10.0. Please use `lightning_lite.utilities.apply_func.from_numpy` instead."
|
||||
" removed in v1.10.0. Please use `torch.from_numpy().to()` instead."
|
||||
)
|
||||
return new_from_numpy(*args, **kwargs)
|
||||
return _from_numpy(*args, **kwargs)
|
||||
|
||||
|
||||
def move_data_to_device(*args: Any, **kwargs: Any) -> Any:
|
||||
|
@ -78,9 +78,9 @@ def move_data_to_device(*args: Any, **kwargs: Any) -> Any:
|
|||
def to_dtype_tensor(*args: Any, **kwargs: Any) -> Any:
|
||||
rank_zero_deprecation(
|
||||
"`pytorch_lightning.utilities.apply_func.to_dtype_tensor` has been deprecated in v1.8.0 and will be"
|
||||
" removed in v1.10.0. Please use `lightning_lite.utilities.apply_func.to_dtype_tensor` instead."
|
||||
" removed in v1.10.0. Please use `torch.tensor` instead."
|
||||
)
|
||||
return new_to_dtype_tensor(*args, **kwargs)
|
||||
return torch.tensor(*args, **kwargs)
|
||||
|
||||
|
||||
class TransferableDataType(NewTransferableDataType):
|
||||
|
|
Loading…
Reference in New Issue