Clean-up dtype management (#14823)

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Author: Carlos Mocholí, 2022-09-22 02:07:36 +02:00, committed by GitHub
Parent: 364dcba382
Commit: 7e803ba53e
9 changed files with 90 additions and 161 deletions


@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
import torch
@@ -18,10 +19,12 @@ from lightning_lite.utilities.enums import PrecisionType
def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor:
if torch.is_floating_point(tensor):
if precision == PrecisionType.HALF:
return tensor.half()
if precision == PrecisionType.BFLOAT:
return tensor.bfloat16()
if precision == PrecisionType.HALF:
return _convert_fp_tensor(tensor, torch.half)
if precision == PrecisionType.BFLOAT:
return _convert_fp_tensor(tensor, torch.bfloat16)
return tensor
def _convert_fp_tensor(tensor: torch.Tensor, dst_type: Union[str, torch.dtype]) -> torch.Tensor:
return tensor.to(dst_type) if torch.is_floating_point(tensor) else tensor
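For reference, a minimal sketch of how the consolidated helpers from this hunk behave; the tensors and assertions below are illustrative only and not part of the change:

```python
import torch

from lightning_lite.plugins.precision.utils import _convert_fp_tensor, _fp_to_half
from lightning_lite.utilities.enums import PrecisionType

float_batch = torch.randn(2, 2)
int_batch = torch.ones(2, 2, dtype=torch.long)

# Floating-point tensors are cast to the requested dtype ...
assert _convert_fp_tensor(float_batch, torch.half).dtype is torch.float16
# ... while integral tensors pass through unchanged.
assert _convert_fp_tensor(int_batch, torch.half).dtype is torch.long

# `_fp_to_half` dispatches on the precision flag and reuses the same helper.
assert _fp_to_half(float_batch, PrecisionType.HALF).dtype is torch.float16
assert _fp_to_half(float_batch, PrecisionType.BFLOAT).dtype is torch.bfloat16
```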


@@ -47,67 +47,10 @@ class _DeviceDtypeModuleMixin(Module):
return device
def to(self, *args: Any, **kwargs: Any) -> Self: # type: ignore[valid-type]
"""Moves and/or casts the parameters and buffers.
This can be called as
.. function:: to(device=None, dtype=None, non_blocking=False)
.. function:: to(dtype, non_blocking=False)
.. function:: to(tensor, non_blocking=False)
Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
floating point desired :attr:`dtype` s. In addition, this method will
only cast the floating point parameters and buffers to :attr:`dtype`
(if given). The integral parameters and buffers will be moved
:attr:`device`, if that is given, but with dtypes unchanged. When
:attr:`non_blocking` is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices.
See below for examples.
Note:
This method modifies the module in-place.
Args:
device: the desired device of the parameters
and buffers in this module
dtype: the desired floating point type of
the floating point parameters and buffers in this module
tensor: Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module
Returns:
Module: self
Example::
>>> from torch import Tensor
>>> class ExampleModule(_DeviceDtypeModuleMixin):
... def __init__(self, weight: Tensor):
... super().__init__()
... self.register_buffer('weight', weight)
>>> _ = torch.manual_seed(0)
>>> module = ExampleModule(torch.rand(3, 4))
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]])
>>> module.to(torch.double)
ExampleModule()
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]], dtype=torch.float64)
>>> cpu = torch.device('cpu')
>>> module.to(cpu, dtype=torch.half, non_blocking=True)
ExampleModule()
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]], dtype=torch.float16)
>>> module.to(cpu)
ExampleModule()
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]], dtype=torch.float16)
>>> module.device
device(type='cpu')
>>> module.dtype
torch.float16
"""
# there is diff nb vars in PT 1.5
out = torch._C._nn._parse_to(*args, **kwargs)
self.__update_properties(device=out[0], dtype=out[1])
"""See :meth:`torch.nn.Module.to`."""
# this converts `str` device to `torch.device`
device, dtype = torch._C._nn._parse_to(*args, **kwargs)[:2]
self.__update_properties(device=device, dtype=dtype)
return super().to(*args, **kwargs)
def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: # type: ignore[valid-type]
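The slice to the first two values exists because `_parse_to` is a private PyTorch helper whose return arity has changed across releases (the deleted comment about PT 1.5 referred to this). A rough illustration, assuming a recent PyTorch where it returns `(device, dtype, non_blocking, memory_format)`:

```python
import torch

# `_parse_to` accepts the same overloads as `Tensor.to` / `Module.to`, so it
# also converts a `str` device into a `torch.device` for us.
parsed = torch._C._nn._parse_to("cpu", torch.half, non_blocking=True)
device, dtype = parsed[:2]

assert device == torch.device("cpu")
assert dtype is torch.float16  # torch.half
```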
@@ -130,50 +73,27 @@ class _DeviceDtypeModuleMixin(Module):
return super().cuda(device=device)
def cpu(self) -> Self: # type: ignore[valid-type]
"""Moves all model parameters and buffers to the CPU.
Returns:
Module: self
"""
"""See :meth:`torch.nn.Module.cpu`."""
self.__update_properties(device=torch.device("cpu"))
return super().cpu()
def type(self, dst_type: Union[str, torch.dtype]) -> Self: # type: ignore[valid-type]
"""Casts all parameters and buffers to :attr:`dst_type`.
Arguments:
dst_type (type or string): the desired type
Returns:
Module: self
"""
"""See :meth:`torch.nn.Module.type`."""
self.__update_properties(dtype=dst_type)
return super().type(dst_type=dst_type)
def float(self) -> Self: # type: ignore[valid-type]
"""Casts all floating point parameters and buffers to ``float`` datatype.
Returns:
Module: self
"""
"""See :meth:`torch.nn.Module.float`."""
self.__update_properties(dtype=torch.float)
return super().float()
def double(self) -> Self: # type: ignore[valid-type]
"""Casts all floating point parameters and buffers to ``double`` datatype.
Returns:
Module: self
"""
"""See :meth:`torch.nn.Module.double`."""
self.__update_properties(dtype=torch.double)
return super().double()
def half(self) -> Self: # type: ignore[valid-type]
"""Casts all floating point parameters and buffers to ``half`` datatype.
Returns:
Module: self
"""
"""See :meth:`torch.nn.Module.half`."""
self.__update_properties(dtype=torch.half)
return super().half()


@@ -22,6 +22,7 @@ from torch.optim import Optimizer
from torch.utils.data import DataLoader
from lightning_lite.plugins import Precision
from lightning_lite.plugins.precision.utils import _convert_fp_tensor
from lightning_lite.strategies import Strategy
from lightning_lite.utilities import move_data_to_device
from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
@@ -104,18 +105,16 @@ class _LiteModule(_DeviceDtypeModuleMixin):
64: torch.float64,
}
# TODO: let the precision plugin handle the conversion
to_type = precision_to_type[precision]
def _convert_float_tensor(t: Tensor) -> Tensor:
return t.to(to_type) if torch.is_floating_point(t) else t
args, kwargs = apply_to_collection([args, kwargs], function=_convert_float_tensor, dtype=Tensor)
args, kwargs = apply_to_collection(
[args, kwargs], dtype=Tensor, function=_convert_fp_tensor, dst_type=precision_to_type[precision]
)
with self._precision_plugin.forward_context():
output = self._forward_module(*args, **kwargs)
to_type = torch.get_default_dtype()
output = apply_to_collection(output, function=_convert_float_tensor, dtype=Tensor)
output = apply_to_collection(
output, dtype=Tensor, function=_convert_fp_tensor, dst_type=torch.get_default_dtype()
)
return output
@overload
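A small sketch of the `apply_to_collection` pattern used above: extra keyword arguments such as `dst_type` are forwarded to the conversion function for every tensor in the (possibly nested) inputs, so one call handles positional and keyword arguments alike. The sample data is made up for illustration:

```python
import torch
from lightning_utilities.core.apply_func import apply_to_collection

from lightning_lite.plugins.precision.utils import _convert_fp_tensor

args = (torch.randn(2), {"mask": torch.ones(2, dtype=torch.bool)})
kwargs = {"x": [torch.randn(3)]}

# `dst_type` is passed through to `_convert_fp_tensor` for every `Tensor` leaf.
args, kwargs = apply_to_collection(
    [args, kwargs], dtype=torch.Tensor, function=_convert_fp_tensor, dst_type=torch.half
)

assert args[0].dtype is torch.float16       # floating point: converted
assert args[1]["mask"].dtype is torch.bool  # non-floating point: untouched
assert kwargs["x"][0].dtype is torch.float16
```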


@@ -602,7 +602,7 @@ class LightningModule(
def forward(self, *args: Any, **kwargs: Any) -> Any:
r"""
Same as :meth:`torch.nn.Module.forward()`.
Same as :meth:`torch.nn.Module.forward`.
Args:
*args: Whatever you decide to pass into the forward method.


@@ -21,6 +21,7 @@ from torch import FloatTensor, Tensor
from torch.optim import Optimizer
import pytorch_lightning as pl
from lightning_lite.plugins.precision.utils import _convert_fp_tensor
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
@@ -33,15 +34,9 @@ class LightningDoublePrecisionModule(_LightningPrecisionModuleWrapperBase):
pl_module: the model to wrap
"""
@staticmethod
def _to_double_precision(data: Tensor) -> Tensor:
if data.is_floating_point():
return data.double()
return data
@staticmethod
def _move_float_tensors_to_double(collection: Any) -> Any:
return apply_to_collection(collection, Tensor, LightningDoublePrecisionModule._to_double_precision)
return apply_to_collection(collection, Tensor, function=_convert_fp_tensor, dst_type=torch.double)
def training_step(self, *args: Any, **kwargs: Any) -> Any:
return self.module.training_step(


@@ -85,17 +85,9 @@ class LightningDeepSpeedModule(_LightningModuleWrapperBase):
self.precision = precision
def forward(self, *inputs: Any, **kwargs: Any) -> Any:
inputs = apply_to_collection(inputs, Tensor, function=self._batch_to)
inputs = apply_to_collection(inputs, Tensor, function=_fp_to_half, precision=self.precision)
return super().forward(*inputs, **kwargs)
def _batch_to(self, batch: Tensor) -> Tensor:
if torch.is_floating_point(batch):
if self.precision == PrecisionType.HALF:
return batch.half()
elif self.precision == PrecisionType.BFLOAT:
return batch.bfloat16()
return batch
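The removed `_batch_to` duplicated the shared `_fp_to_half` helper introduced earlier in this commit. A rough sketch of the equivalence, assuming `PrecisionType` is the string-backed Lightning enum whose equality also accepts plain ints and strings (which is why comparing against the wrapper's `precision` attribute works):

```python
import torch

from lightning_lite.plugins.precision.utils import _fp_to_half
from lightning_lite.utilities.enums import PrecisionType

batch = torch.randn(4, 8)

# Value-based equality: 16 and "bf16" select the same branches as the enum members.
assert PrecisionType.HALF == 16
assert PrecisionType.BFLOAT == "bf16"

assert _fp_to_half(batch, PrecisionType.HALF).dtype is torch.float16
assert _fp_to_half(batch, PrecisionType.BFLOAT).dtype is torch.bfloat16
# Any other precision leaves the tensor unchanged.
assert _fp_to_half(batch, PrecisionType.FLOAT).dtype is torch.float32
```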
class DeepSpeedStrategy(DDPStrategy):
strategy_name = "deepspeed"


@@ -17,14 +17,13 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import FloatTensor, Tensor
from torch import Tensor
from torch.utils.data import DataLoader, Sampler
import pytorch_lightning as pl
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
from lightning_lite.plugins.precision.utils import _fp_to_half
from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.enums import PrecisionType
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
@@ -61,21 +60,9 @@ class LightningIPUModule(_LightningModuleWrapperBase):
self.precision = precision
def forward(self, *inputs: Any, **kwargs: Any) -> Any:
if self.precision == PrecisionType.HALF:
inputs = self._move_float_tensors_to_half(inputs)
inputs = apply_to_collection(inputs, Tensor, function=_fp_to_half, precision=self.precision)
return super().forward(*inputs, **kwargs)
@staticmethod
def batch_to(data: Tensor) -> Tensor:
if torch.is_floating_point(data):
return data.half()
return data
def _move_float_tensors_to_half(self, batch: Any) -> Any:
batch = apply_to_collection(batch, (FloatTensor, torch.cuda.FloatTensor), function=self.batch_to)
return batch
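For context on the removed filter: the legacy `FloatTensor` classes only match float32 tensors on their respective device, so the old `apply_to_collection` call silently skipped other floating-point dtypes. A brief illustration, assuming the usual dtype-sensitive `isinstance` behaviour of these legacy tensor types:

```python
import torch

double_batch = torch.randn(2, dtype=torch.float64)

# The old filter `(FloatTensor, torch.cuda.FloatTensor)` would not match this tensor ...
assert not isinstance(double_batch, torch.FloatTensor)
# ... whereas filtering on `Tensor` and checking inside `_fp_to_half` covers it.
assert torch.is_floating_point(double_batch)
```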
class IPUStrategy(ParallelStrategy):
"""Plugin for training on IPU devices."""


@@ -23,7 +23,7 @@ class TopModule(_DeviceDtypeModuleMixin):
@pytest.mark.parametrize(
"dst_device_str,dst_dtype",
"dst_device_str,dst_type",
[
("cpu", torch.half),
("cpu", torch.float),
@@ -35,21 +35,19 @@ class TopModule(_DeviceDtypeModuleMixin):
],
)
@RunIf(min_cuda_gpus=1)
def test_submodules_device_and_dtype(dst_device_str, dst_dtype):
def test_submodules_device_and_dtype(dst_device_str, dst_type):
"""Test that the device and dtype property updates propagate through mixed nesting of regular nn.Modules and
the special modules of type DeviceDtypeModuleMixin (e.g. Metric or LightningModule)."""
dst_device = torch.device(dst_device_str)
model = TopModule()
assert model.device == torch.device("cpu")
model = model.to(device=dst_device, dtype=dst_dtype)
model = model.to(device=dst_device, dtype=dst_type)
# nn.Module does not have these attributes
assert not hasattr(model.module, "_device")
assert not hasattr(model.module, "_dtype")
# device and dtype change should propagate down into all children
assert model.device == model.module.module.device == dst_device
assert model.dtype == model.module.module.dtype == dst_dtype
assert model.dtype == model.module.module.dtype == dst_type
@pytest.mark.parametrize(
@@ -72,6 +70,16 @@ def test_cuda_device(device):
assert device.index == torch.cuda.current_device()
@RunIf(min_cuda_gpus=1)
def test_cpu_device():
model = SubSubModule().cuda()
assert model.device.type == "cuda"
assert model.device.index == 0
model.cpu()
assert model.device.type == "cpu"
assert model.device.index is None
@RunIf(min_cuda_gpus=2)
def test_cuda_current_device():
"""Test that calling .cuda() moves the model to the correct device and respects current cuda device setting."""
@@ -92,3 +100,46 @@ def test_cuda_current_device():
model.cuda() # model is already on device 1, and calling .cuda() without device index should not move model
assert model.device == torch.device("cuda", 1)
assert model.layer.weight.device == torch.device("cuda", 1)
class ExampleModule(_DeviceDtypeModuleMixin):
def __init__(self, weight):
super().__init__()
self.register_buffer("weight", weight)
def test_to_combinations():
module = ExampleModule(torch.rand(3, 4))
# sanity check
assert module.weight.shape == (3, 4)
assert module.weight.dtype is torch.float32
# positional dtype
module.to(torch.double)
assert module.weight.dtype is torch.float64
# positional device
module.to("cpu", dtype=torch.half, non_blocking=True)
assert module.weight.dtype is torch.float16
assert module.device == torch.device("cpu")
assert module.dtype is torch.float16
def test_dtype_conversions():
module = ExampleModule(torch.tensor(1))
# different dtypes
assert module.weight.dtype is torch.int64
assert module.dtype is torch.float32
# `.double()` skips non floating points
module.double()
assert module.weight.dtype is torch.int64
assert module.dtype is torch.float64
# but `type` doesn't
module.type(torch.float)
assert module.weight.dtype is torch.float32
assert module.dtype is torch.float32
# now, test the rest
module.float()
assert module.weight.dtype is torch.float32
assert module.dtype is torch.float32
module.half()
assert module.weight.dtype is torch.float16
assert module.dtype is torch.float16


@@ -83,38 +83,20 @@ class ModelParallelBoringModelManualOptim(BoringModel):
return False
def test_deepspeed_lightning_module():
"""Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves types and device correctly."""
model = BoringModel()
with pytest.deprecated_call(match="`LightningDeepSpeedModule` has been deprecated in v1.7.1"):
module = LightningDeepSpeedModule(model, precision=16)
module.half()
assert module.dtype == torch.half
assert model.dtype == torch.half
module.to(torch.double)
assert module.dtype == torch.double
assert model.dtype == torch.double
@RunIf(min_cuda_gpus=1)
def test_deepspeed_lightning_module_precision():
"""Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves tensors to half when precision
16."""
model = BoringModel()
with pytest.deprecated_call(match="`LightningDeepSpeedModule` has been deprecated in v1.7.1"):
module = LightningDeepSpeedModule(model, precision=16)
module.cuda().half()
module.to(device="cuda", dtype=torch.half)
assert module.dtype == torch.half
assert model.dtype == torch.half
x = torch.randn((1, 32), dtype=torch.float).cuda()
x = torch.randn((1, 32), device="cuda", dtype=torch.float)
out = module(x)
assert out.dtype == torch.half
module.to(torch.double)