Clean-up dtype management (#14823)
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
parent 364dcba382
commit 7e803ba53e
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Union

 import torch
@@ -18,10 +19,12 @@ from lightning_lite.utilities.enums import PrecisionType


 def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor:
-    if torch.is_floating_point(tensor):
-        if precision == PrecisionType.HALF:
-            return tensor.half()
-        if precision == PrecisionType.BFLOAT:
-            return tensor.bfloat16()
+    if precision == PrecisionType.HALF:
+        return _convert_fp_tensor(tensor, torch.half)
+    if precision == PrecisionType.BFLOAT:
+        return _convert_fp_tensor(tensor, torch.bfloat16)
     return tensor
+
+
+def _convert_fp_tensor(tensor: torch.Tensor, dst_type: Union[str, torch.dtype]) -> torch.Tensor:
+    return tensor.to(dst_type) if torch.is_floating_point(tensor) else tensor
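A minimal, self-contained sketch of what the new helper does (re-implemented here so it runs without lightning_lite installed): floating-point tensors are cast to the requested dtype, everything else passes through untouched.

import torch

# stand-in mirroring _convert_fp_tensor from the hunk above
def convert_fp_tensor(tensor: torch.Tensor, dst_type: torch.dtype) -> torch.Tensor:
    return tensor.to(dst_type) if torch.is_floating_point(tensor) else tensor

assert convert_fp_tensor(torch.rand(2), torch.half).dtype is torch.float16
assert convert_fp_tensor(torch.arange(2), torch.half).dtype is torch.int64  # integral input, unchanged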
@@ -47,67 +47,10 @@ class _DeviceDtypeModuleMixin(Module):
         return device

     def to(self, *args: Any, **kwargs: Any) -> Self:  # type: ignore[valid-type]
-        """Moves and/or casts the parameters and buffers.
-
-        This can be called as
-        .. function:: to(device=None, dtype=None, non_blocking=False)
-        .. function:: to(dtype, non_blocking=False)
-        .. function:: to(tensor, non_blocking=False)
-        Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
-        floating point desired :attr:`dtype` s. In addition, this method will
-        only cast the floating point parameters and buffers to :attr:`dtype`
-        (if given). The integral parameters and buffers will be moved
-        :attr:`device`, if that is given, but with dtypes unchanged. When
-        :attr:`non_blocking` is set, it tries to convert/move asynchronously
-        with respect to the host if possible, e.g., moving CPU Tensors with
-        pinned memory to CUDA devices.
-        See below for examples.
-
-        Note:
-            This method modifies the module in-place.
-
-        Args:
-            device: the desired device of the parameters
-                and buffers in this module
-            dtype: the desired floating point type of
-                the floating point parameters and buffers in this module
-            tensor: Tensor whose dtype and device are the desired
-                dtype and device for all parameters and buffers in this module
-
-        Returns:
-            Module: self
-
-        Example::
-            >>> from torch import Tensor
-            >>> class ExampleModule(_DeviceDtypeModuleMixin):
-            ...     def __init__(self, weight: Tensor):
-            ...         super().__init__()
-            ...         self.register_buffer('weight', weight)
-            >>> _ = torch.manual_seed(0)
-            >>> module = ExampleModule(torch.rand(3, 4))
-            >>> module.weight  # doctest: +ELLIPSIS
-            tensor([[...]])
-            >>> module.to(torch.double)
-            ExampleModule()
-            >>> module.weight  # doctest: +ELLIPSIS
-            tensor([[...]], dtype=torch.float64)
-            >>> cpu = torch.device('cpu')
-            >>> module.to(cpu, dtype=torch.half, non_blocking=True)
-            ExampleModule()
-            >>> module.weight  # doctest: +ELLIPSIS
-            tensor([[...]], dtype=torch.float16)
-            >>> module.to(cpu)
-            ExampleModule()
-            >>> module.weight  # doctest: +ELLIPSIS
-            tensor([[...]], dtype=torch.float16)
-            >>> module.device
-            device(type='cpu')
-            >>> module.dtype
-            torch.float16
-        """
-        # there is diff nb vars in PT 1.5
-        out = torch._C._nn._parse_to(*args, **kwargs)
-        self.__update_properties(device=out[0], dtype=out[1])
+        """See :meth:`torch.nn.Module.to`."""
+        # this converts `str` device to `torch.device`
+        device, dtype = torch._C._nn._parse_to(*args, **kwargs)[:2]
+        self.__update_properties(device=device, dtype=dtype)
         return super().to(*args, **kwargs)

     def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:  # type: ignore[valid-type]
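The rewritten override leans on the same private argument parser that torch.Tensor.to uses. A quick sketch of the normalization it provides (this assumes the behavior of the private torch._C._nn._parse_to helper; only the first two entries of its return tuple are relied on, which is why the diff slices [:2]):

import torch

out = torch._C._nn._parse_to("cpu", dtype=torch.half, non_blocking=True)
device, dtype = out[:2]
assert device == torch.device("cpu")  # the "cpu" string was converted to a torch.device
assert dtype is torch.float16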
@@ -130,50 +73,27 @@ class _DeviceDtypeModuleMixin(Module):
         return super().cuda(device=device)

     def cpu(self) -> Self:  # type: ignore[valid-type]
-        """Moves all model parameters and buffers to the CPU.
-
-        Returns:
-            Module: self
-        """
+        """See :meth:`torch.nn.Module.cpu`."""
         self.__update_properties(device=torch.device("cpu"))
         return super().cpu()

     def type(self, dst_type: Union[str, torch.dtype]) -> Self:  # type: ignore[valid-type]
-        """Casts all parameters and buffers to :attr:`dst_type`.
-
-        Arguments:
-            dst_type (type or string): the desired type
-
-        Returns:
-            Module: self
-        """
+        """See :meth:`torch.nn.Module.type`."""
         self.__update_properties(dtype=dst_type)
         return super().type(dst_type=dst_type)

     def float(self) -> Self:  # type: ignore[valid-type]
-        """Casts all floating point parameters and buffers to ``float`` datatype.
-
-        Returns:
-            Module: self
-        """
+        """See :meth:`torch.nn.Module.float`."""
         self.__update_properties(dtype=torch.float)
         return super().float()

     def double(self) -> Self:  # type: ignore[valid-type]
-        """Casts all floating point parameters and buffers to ``double`` datatype.
-
-        Returns:
-            Module: self
-        """
+        """See :meth:`torch.nn.Module.double`."""
         self.__update_properties(dtype=torch.double)
         return super().double()

     def half(self) -> Self:  # type: ignore[valid-type]
-        """Casts all floating point parameters and buffers to ``half`` datatype.
-
-        Returns:
-            Module: self
-        """
+        """See :meth:`torch.nn.Module.half`."""
         self.__update_properties(dtype=torch.half)
         return super().half()
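The one-line docstrings now defer to the underlying torch.nn.Module semantics, which are easy to check directly. A small sketch with a plain nn.Module (no Lightning required, mirroring the new test_dtype_conversions test further down): `.double()` only touches floating-point parameters and buffers, while `.type()` casts integral ones as well.

import torch
from torch import nn

m = nn.Linear(2, 2)
m.register_buffer("indices", torch.arange(3))  # integral buffer

m.double()
assert m.weight.dtype is torch.float64
assert m.indices.dtype is torch.int64  # left untouched by .double()

m.type(torch.float)
assert m.weight.dtype is torch.float32
assert m.indices.dtype is torch.float32  # .type() casts everything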
@@ -22,6 +22,7 @@ from torch.optim import Optimizer
 from torch.utils.data import DataLoader

 from lightning_lite.plugins import Precision
+from lightning_lite.plugins.precision.utils import _convert_fp_tensor
 from lightning_lite.strategies import Strategy
 from lightning_lite.utilities import move_data_to_device
 from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
@@ -104,18 +105,16 @@ class _LiteModule(_DeviceDtypeModuleMixin):
             64: torch.float64,
         }
         # TODO: let the precision plugin handle the conversion
-        to_type = precision_to_type[precision]
-
-        def _convert_float_tensor(t: Tensor) -> Tensor:
-            return t.to(to_type) if torch.is_floating_point(t) else t
-
-        args, kwargs = apply_to_collection([args, kwargs], function=_convert_float_tensor, dtype=Tensor)
+        args, kwargs = apply_to_collection(
+            [args, kwargs], dtype=Tensor, function=_convert_fp_tensor, dst_type=precision_to_type[precision]
+        )

         with self._precision_plugin.forward_context():
             output = self._forward_module(*args, **kwargs)

-        to_type = torch.get_default_dtype()
-        output = apply_to_collection(output, function=_convert_float_tensor, dtype=Tensor)
+        output = apply_to_collection(
+            output, dtype=Tensor, function=_convert_fp_tensor, dst_type=torch.get_default_dtype()
+        )
         return output

     @overload
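The local closure could be dropped because apply_to_collection forwards any extra keyword arguments (here dst_type) to the given function for every matching element. A short sketch of that pattern, assuming lightning_utilities is installed and using a hypothetical convert_fp_tensor stand-in rather than the private lightning_lite helper:

import torch
from torch import Tensor
from lightning_utilities.core.apply_func import apply_to_collection

def convert_fp_tensor(t: Tensor, dst_type: torch.dtype) -> Tensor:  # stand-in for _convert_fp_tensor
    return t.to(dst_type) if torch.is_floating_point(t) else t

batch = {"x": torch.rand(2), "idx": torch.arange(2)}
converted = apply_to_collection(batch, dtype=Tensor, function=convert_fp_tensor, dst_type=torch.half)
assert converted["x"].dtype is torch.float16
assert converted["idx"].dtype is torch.int64  # non-floating tensors pass through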
@@ -602,7 +602,7 @@ class LightningModule(

     def forward(self, *args: Any, **kwargs: Any) -> Any:
         r"""
-        Same as :meth:`torch.nn.Module.forward()`.
+        Same as :meth:`torch.nn.Module.forward`.

         Args:
             *args: Whatever you decide to pass into the forward method.
@@ -21,6 +21,7 @@ from torch import FloatTensor, Tensor
 from torch.optim import Optimizer

 import pytorch_lightning as pl
+from lightning_lite.plugins.precision.utils import _convert_fp_tensor
 from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
@@ -33,15 +34,9 @@ class LightningDoublePrecisionModule(_LightningPrecisionModuleWrapperBase):
         pl_module: the model to wrap
     """

-    @staticmethod
-    def _to_double_precision(data: Tensor) -> Tensor:
-        if data.is_floating_point():
-            return data.double()
-        return data
-
     @staticmethod
     def _move_float_tensors_to_double(collection: Any) -> Any:
-        return apply_to_collection(collection, Tensor, LightningDoublePrecisionModule._to_double_precision)
+        return apply_to_collection(collection, Tensor, function=_convert_fp_tensor, dst_type=torch.double)

     def training_step(self, *args: Any, **kwargs: Any) -> Any:
         return self.module.training_step(
@@ -85,17 +85,9 @@ class LightningDeepSpeedModule(_LightningModuleWrapperBase):
         self.precision = precision

     def forward(self, *inputs: Any, **kwargs: Any) -> Any:
-        inputs = apply_to_collection(inputs, Tensor, function=self._batch_to)
+        inputs = apply_to_collection(inputs, Tensor, function=_fp_to_half, precision=self.precision)
         return super().forward(*inputs, **kwargs)

-    def _batch_to(self, batch: Tensor) -> Tensor:
-        if torch.is_floating_point(batch):
-            if self.precision == PrecisionType.HALF:
-                return batch.half()
-            elif self.precision == PrecisionType.BFLOAT:
-                return batch.bfloat16()
-        return batch
-

 class DeepSpeedStrategy(DDPStrategy):
     strategy_name = "deepspeed"
@@ -17,14 +17,13 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
-from torch import FloatTensor, Tensor
+from torch import Tensor
 from torch.utils.data import DataLoader, Sampler

 import pytorch_lightning as pl
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
+from lightning_lite.plugins.precision.utils import _fp_to_half
 from lightning_lite.utilities.cloud_io import get_filesystem
 from lightning_lite.utilities.enums import PrecisionType
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.parallel import ParallelStrategy
@@ -61,21 +60,9 @@ class LightningIPUModule(_LightningModuleWrapperBase):
         self.precision = precision

     def forward(self, *inputs: Any, **kwargs: Any) -> Any:
-        if self.precision == PrecisionType.HALF:
-            inputs = self._move_float_tensors_to_half(inputs)
-
+        inputs = apply_to_collection(inputs, Tensor, function=_fp_to_half, precision=self.precision)
         return super().forward(*inputs, **kwargs)

-    @staticmethod
-    def batch_to(data: Tensor) -> Tensor:
-        if torch.is_floating_point(data):
-            return data.half()
-        return data
-
-    def _move_float_tensors_to_half(self, batch: Any) -> Any:
-        batch = apply_to_collection(batch, (FloatTensor, torch.cuda.FloatTensor), function=self.batch_to)
-        return batch
-

 class IPUStrategy(ParallelStrategy):
     """Plugin for training on IPU devices."""
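Both wrappers above now delegate the per-tensor cast to the shared _fp_to_half helper introduced in this commit. A brief sketch of its behavior, using the import paths and the PrecisionType members exactly as they appear in the diff (requires lightning_lite at this revision):

import torch
from lightning_lite.plugins.precision.utils import _fp_to_half
from lightning_lite.utilities.enums import PrecisionType

x = torch.rand(2)
assert _fp_to_half(x, PrecisionType.HALF).dtype is torch.float16
assert _fp_to_half(x, PrecisionType.BFLOAT).dtype is torch.bfloat16
assert _fp_to_half(torch.arange(2), PrecisionType.HALF).dtype is torch.int64  # integral tensors unchanged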
@@ -23,7 +23,7 @@ class TopModule(_DeviceDtypeModuleMixin):


 @pytest.mark.parametrize(
-    "dst_device_str,dst_dtype",
+    "dst_device_str,dst_type",
     [
         ("cpu", torch.half),
         ("cpu", torch.float),
@@ -35,21 +35,19 @@ class TopModule(_DeviceDtypeModuleMixin):
     ],
 )
 @RunIf(min_cuda_gpus=1)
-def test_submodules_device_and_dtype(dst_device_str, dst_dtype):
+def test_submodules_device_and_dtype(dst_device_str, dst_type):
     """Test that the device and dtype property updates propagate through mixed nesting of regular nn.Modules and
     the special modules of type DeviceDtypeModuleMixin (e.g. Metric or LightningModule)."""
-
     dst_device = torch.device(dst_device_str)
-
     model = TopModule()
     assert model.device == torch.device("cpu")
-    model = model.to(device=dst_device, dtype=dst_dtype)
+    model = model.to(device=dst_device, dtype=dst_type)
     # nn.Module does not have these attributes
     assert not hasattr(model.module, "_device")
     assert not hasattr(model.module, "_dtype")
     # device and dtype change should propagate down into all children
     assert model.device == model.module.module.device == dst_device
-    assert model.dtype == model.module.module.dtype == dst_dtype
+    assert model.dtype == model.module.module.dtype == dst_type


 @pytest.mark.parametrize(
@@ -72,6 +70,16 @@ def test_cuda_device(device):
     assert device.index == torch.cuda.current_device()


+@RunIf(min_cuda_gpus=1)
+def test_cpu_device():
+    model = SubSubModule().cuda()
+    assert model.device.type == "cuda"
+    assert model.device.index == 0
+    model.cpu()
+    assert model.device.type == "cpu"
+    assert model.device.index is None
+
+
 @RunIf(min_cuda_gpus=2)
 def test_cuda_current_device():
     """Test that calling .cuda() moves the model to the correct device and respects current cuda device setting."""
@@ -92,3 +100,46 @@ def test_cuda_current_device():
     model.cuda()  # model is already on device 1, and calling .cuda() without device index should not move model
     assert model.device == torch.device("cuda", 1)
     assert model.layer.weight.device == torch.device("cuda", 1)
+
+
+class ExampleModule(_DeviceDtypeModuleMixin):
+    def __init__(self, weight):
+        super().__init__()
+        self.register_buffer("weight", weight)
+
+
+def test_to_combinations():
+    module = ExampleModule(torch.rand(3, 4))
+    # sanity check
+    assert module.weight.shape == (3, 4)
+    assert module.weight.dtype is torch.float32
+    # positional dtype
+    module.to(torch.double)
+    assert module.weight.dtype is torch.float64
+    # positional device
+    module.to("cpu", dtype=torch.half, non_blocking=True)
+    assert module.weight.dtype is torch.float16
+    assert module.device == torch.device("cpu")
+    assert module.dtype is torch.float16
+
+
+def test_dtype_conversions():
+    module = ExampleModule(torch.tensor(1))
+    # different dtypes
+    assert module.weight.dtype is torch.int64
+    assert module.dtype is torch.float32
+    # `.double()` skips non floating points
+    module.double()
+    assert module.weight.dtype is torch.int64
+    assert module.dtype is torch.float64
+    # but `type` doesn't
+    module.type(torch.float)
+    assert module.weight.dtype is torch.float32
+    assert module.dtype is torch.float32
+    # now, test the rest
+    module.float()
+    assert module.weight.dtype is torch.float32
+    assert module.dtype is torch.float32
+    module.half()
+    assert module.weight.dtype is torch.float16
+    assert module.dtype is torch.float16
@@ -83,38 +83,20 @@ class ModelParallelBoringModelManualOptim(BoringModel):
         return False


-def test_deepspeed_lightning_module():
-    """Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves types and device correctly."""
-
-    model = BoringModel()
-    with pytest.deprecated_call(match="`LightningDeepSpeedModule` has been deprecated in v1.7.1"):
-        module = LightningDeepSpeedModule(model, precision=16)
-
-    module.half()
-    assert module.dtype == torch.half
-    assert model.dtype == torch.half
-
-    module.to(torch.double)
-    assert module.dtype == torch.double
-    assert model.dtype == torch.double
-
-
 @RunIf(min_cuda_gpus=1)
 def test_deepspeed_lightning_module_precision():
     """Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves tensors to half when precision
     16."""
-
     model = BoringModel()
     with pytest.deprecated_call(match="`LightningDeepSpeedModule` has been deprecated in v1.7.1"):
         module = LightningDeepSpeedModule(model, precision=16)

-    module.cuda().half()
+    module.to(device="cuda", dtype=torch.half)
     assert module.dtype == torch.half
     assert model.dtype == torch.half

-    x = torch.randn((1, 32), dtype=torch.float).cuda()
+    x = torch.randn((1, 32), device="cuda", dtype=torch.float)
     out = module(x)

     assert out.dtype == torch.half

     module.to(torch.double)