2021-10-29 21:46:39 +00:00
|
|
|
# Copyright The PyTorch Lightning team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
from unittest.mock import ANY, Mock
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
2021-11-01 18:33:13 +00:00
|
|
|
from torch.utils.data.dataloader import DataLoader
|
2021-10-29 21:46:39 +00:00
|
|
|
|
2021-11-16 17:26:46 +00:00
|
|
|
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
|
2021-10-29 21:46:39 +00:00
|
|
|
from pytorch_lightning.lite import LightningLite
|
|
|
|
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
|
|
|
|
from tests.helpers.runif import RunIf
|
|
|
|
|
|
|
|
|
|
|
|
class EmptyLite(LightningLite):
|
|
|
|
def run(self):
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
def test_lite_module_wraps():
|
|
|
|
"""Test that the wrapped module is accessible via the property."""
|
|
|
|
module = Mock()
|
|
|
|
assert _LiteModule(module, Mock()).module is module
|
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_gpus=1)
|
|
|
|
@pytest.mark.parametrize(
|
|
|
|
"precision, input_type, expected_type",
|
|
|
|
[
|
|
|
|
(32, torch.float16, torch.float32),
|
|
|
|
(32, torch.float32, torch.float32),
|
|
|
|
(32, torch.float64, torch.float32),
|
2021-11-09 14:21:00 +00:00
|
|
|
(32, torch.int, torch.int),
|
2021-10-29 21:46:39 +00:00
|
|
|
(16, torch.float32, torch.float16),
|
|
|
|
(16, torch.float64, torch.float16),
|
2021-11-09 14:21:00 +00:00
|
|
|
(16, torch.long, torch.long),
|
2022-04-10 04:23:16 +00:00
|
|
|
pytest.param("bf16", torch.float32, torch.bfloat16, marks=RunIf(min_torch="1.10", bf16_cuda=True)),
|
|
|
|
pytest.param("bf16", torch.float64, torch.bfloat16, marks=RunIf(min_torch="1.10", bf16_cuda=True)),
|
|
|
|
pytest.param("bf16", torch.bool, torch.bool, marks=RunIf(min_torch="1.10", bf16_cuda=True)),
|
2021-10-29 21:46:39 +00:00
|
|
|
],
|
|
|
|
)
|
|
|
|
def test_lite_module_forward_conversion(precision, input_type, expected_type):
|
|
|
|
"""Test that the LiteModule performs autocasting on the input tensors and during forward()."""
|
|
|
|
lite = EmptyLite(precision=precision, accelerator="gpu", devices=1)
|
|
|
|
device = torch.device("cuda", 0)
|
|
|
|
|
|
|
|
def check_autocast(forward_input):
|
|
|
|
assert precision != 16 or torch.is_autocast_enabled()
|
|
|
|
return forward_input
|
|
|
|
|
2021-11-09 14:21:00 +00:00
|
|
|
module = Mock(wraps=torch.nn.Identity(), side_effect=check_autocast)
|
2021-10-29 21:46:39 +00:00
|
|
|
lite_module = _LiteModule(module, lite._precision_plugin).to(device)
|
2021-11-09 14:21:00 +00:00
|
|
|
out = lite_module(torch.tensor([1, 2, 3], dtype=input_type, device=device))
|
2021-10-29 21:46:39 +00:00
|
|
|
assert module.call_args[0][0].dtype == expected_type
|
2021-11-09 14:21:00 +00:00
|
|
|
assert out.dtype == input_type or out.dtype == torch.get_default_dtype()
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
|
2021-11-16 17:26:46 +00:00
|
|
|
@pytest.mark.parametrize(
|
|
|
|
"device", [torch.device("cpu"), pytest.param(torch.device("cuda", 0), marks=RunIf(min_gpus=1))]
|
|
|
|
)
|
|
|
|
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
|
|
|
|
def test_lite_module_device_dtype_propagation(device, dtype):
|
|
|
|
"""Test that the LiteModule propagates device and dtype properties to its submodules (e.g. torchmetrics)."""
|
|
|
|
|
|
|
|
class DeviceModule(DeviceDtypeModuleMixin):
|
|
|
|
pass
|
|
|
|
|
|
|
|
device_module = DeviceModule()
|
|
|
|
lite_module = _LiteModule(device_module, Mock())
|
|
|
|
lite_module.to(device)
|
|
|
|
assert device_module.device == device
|
|
|
|
assert lite_module.device == device
|
|
|
|
|
|
|
|
lite_module.to(dtype)
|
|
|
|
assert device_module.dtype == dtype
|
|
|
|
assert lite_module.dtype == dtype
|
|
|
|
|
|
|
|
|
2021-11-02 10:40:35 +00:00
|
|
|
def test_lite_dataloader_iterator():
|
|
|
|
"""Test that the iteration over a LiteDataLoader wraps the iterator of the underlying dataloader (no automatic
|
|
|
|
device placement)."""
|
|
|
|
dataloader = DataLoader(range(5), batch_size=2)
|
|
|
|
lite_dataloader = _LiteDataLoader(dataloader)
|
|
|
|
assert len(lite_dataloader) == len(dataloader) == 3
|
|
|
|
|
|
|
|
iterator = iter(dataloader)
|
|
|
|
lite_iterator = iter(lite_dataloader)
|
|
|
|
|
|
|
|
assert torch.equal(next(iterator), next(lite_iterator))
|
|
|
|
assert torch.equal(next(iterator), next(lite_iterator))
|
|
|
|
assert torch.equal(next(iterator), next(lite_iterator))
|
|
|
|
|
|
|
|
with pytest.raises(StopIteration):
|
|
|
|
next(iterator)
|
|
|
|
|
|
|
|
with pytest.raises(StopIteration):
|
|
|
|
next(lite_iterator)
|
|
|
|
|
|
|
|
|
2021-10-29 21:46:39 +00:00
|
|
|
@pytest.mark.parametrize(
|
|
|
|
"src_device, dest_device",
|
|
|
|
[
|
|
|
|
(torch.device("cpu"), torch.device("cpu")),
|
|
|
|
pytest.param(torch.device("cpu"), torch.device("cuda", 0), marks=RunIf(min_gpus=1)),
|
|
|
|
pytest.param(torch.device("cuda", 0), torch.device("cpu"), marks=RunIf(min_gpus=1)),
|
|
|
|
],
|
|
|
|
)
|
|
|
|
def test_lite_dataloader_device_placement(src_device, dest_device):
|
|
|
|
"""Test that the LiteDataLoader moves data to the device in its iterator."""
|
|
|
|
sample0 = torch.tensor(0, device=src_device)
|
|
|
|
sample1 = torch.tensor(1, device=src_device)
|
|
|
|
sample2 = {"data": torch.tensor(2, device=src_device)}
|
|
|
|
sample3 = {"data": torch.tensor(3, device=src_device)}
|
2021-11-01 18:33:13 +00:00
|
|
|
dataloader = DataLoader([sample0, sample1, sample2, sample3], batch_size=2)
|
|
|
|
lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=dest_device)
|
2021-10-29 21:46:39 +00:00
|
|
|
iterator = iter(lite_dataloader)
|
|
|
|
|
|
|
|
batch0 = next(iterator)
|
|
|
|
assert torch.equal(batch0, torch.tensor([0, 1], device=dest_device))
|
|
|
|
|
|
|
|
batch1 = next(iterator)
|
|
|
|
assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device))
|
|
|
|
|
|
|
|
|
|
|
|
def test_lite_optimizer_wraps():
|
|
|
|
"""Test that the LiteOptimizer fully wraps the optimizer."""
|
|
|
|
optimizer_cls = torch.optim.SGD
|
|
|
|
optimizer = Mock(spec=optimizer_cls)
|
|
|
|
lite_optimizer = _LiteOptimizer(optimizer, Mock())
|
|
|
|
assert lite_optimizer.optimizer is optimizer
|
|
|
|
assert isinstance(lite_optimizer, optimizer_cls)
|
|
|
|
|
|
|
|
|
2021-11-27 04:54:45 +00:00
|
|
|
def test_lite_optimizer_state_dict():
|
2021-11-30 14:16:59 +00:00
|
|
|
"""Test that the LiteOptimizer calls into the strategy to collect the state."""
|
2021-11-27 04:54:45 +00:00
|
|
|
optimizer = Mock()
|
2021-11-30 14:16:59 +00:00
|
|
|
strategy = Mock()
|
|
|
|
lite_optimizer = _LiteOptimizer(optimizer=optimizer, strategy=strategy)
|
2021-11-27 04:54:45 +00:00
|
|
|
lite_optimizer.state_dict()
|
2021-11-30 14:16:59 +00:00
|
|
|
strategy.optimizer_state.assert_called_with(optimizer)
|
2021-11-27 04:54:45 +00:00
|
|
|
|
|
|
|
|
2021-10-29 21:46:39 +00:00
|
|
|
def test_lite_optimizer_steps():
|
|
|
|
"""Test that the LiteOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer."""
|
|
|
|
optimizer = Mock()
|
2021-11-30 08:31:23 +00:00
|
|
|
strategy = Mock()
|
2022-02-09 09:37:13 +00:00
|
|
|
strategy.optimizer_step.return_value = 123
|
2021-11-30 14:16:59 +00:00
|
|
|
lite_optimizer = _LiteOptimizer(optimizer=optimizer, strategy=strategy)
|
2022-02-09 09:37:13 +00:00
|
|
|
step_output = lite_optimizer.step()
|
|
|
|
assert step_output == 123
|
2021-11-30 08:31:23 +00:00
|
|
|
strategy.optimizer_step.assert_called_once()
|
2021-11-30 14:16:59 +00:00
|
|
|
strategy.optimizer_step.assert_called_with(optimizer, opt_idx=0, closure=ANY, model=strategy.model)
|