2021-10-29 21:46:39 +00:00
|
|
|
# Copyright The Lightning AI team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
2023-01-09 14:01:11 +00:00
|
|
|
from unittest import mock
|
2022-12-19 21:57:15 +00:00
|
|
|
from unittest.mock import call, Mock
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
2023-01-04 15:57:18 +00:00
|
|
|
from tests_fabric.helpers.runif import RunIf
|
2023-02-23 00:11:29 +00:00
|
|
|
from torch.utils.data import BatchSampler, DistributedSampler
|
2021-11-01 18:33:13 +00:00
|
|
|
from torch.utils.data.dataloader import DataLoader
|
2021-10-29 21:46:39 +00:00
|
|
|
|
2023-02-01 20:34:38 +00:00
|
|
|
from lightning.fabric.fabric import Fabric
|
|
|
|
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
|
|
|
|
from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
|
2023-01-10 15:02:05 +00:00
|
|
|
class EmptyFabric(Fabric):
    """Concrete `Fabric` subclass with a no-op `run()`, so tests can simply instantiate a `Fabric` object."""

    def run(self):
        """Intentionally does nothing; the tests in this module only construct the object."""
|
|
|
|
|
|
|
|
|
2023-01-10 15:02:05 +00:00
|
|
|
def test_fabric_module_wraps():
    """The `.module` property returns the passed-in module, or `original_module` when one is provided."""
    # Without an explicit original module, `.module` is the module that was handed to the wrapper.
    plain = Mock()
    assert _FabricModule(plain, Mock()).module is plain

    # With an explicit original module, `.module` points to it rather than to the forward module.
    forward_module, original = Mock(), Mock()
    wrapper = _FabricModule(forward_module, Mock(), original_module=original)
    assert wrapper.module is original
|
2022-05-11 18:28:08 +00:00
|
|
|
|
|
|
|
|
2023-01-10 15:02:05 +00:00
|
|
|
def test_fabric_module_attribute_lookup():
    """Attribute access on a `_FabricModule` falls through to the original (unwrapped) module when possible."""

    class OriginalModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(2, 3)
            self.attribute = 1

        def method(self):
            return 2

    inner = OriginalModule()

    class ModuleWrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.wrapped = inner

    fabric_module = _FabricModule(ModuleWrapper(), Mock(), original_module=inner)

    # plain attributes, submodules and methods all resolve against the original module
    assert fabric_module.attribute == 1
    assert fabric_module.layer is inner.layer
    assert fabric_module.method() == 2
    # `forward` stays bound to the Fabric wrapper itself rather than to the original module
    assert fabric_module.forward.__self__.__class__ == _FabricModule

    # a name that exists nowhere must still raise
    with pytest.raises(AttributeError):
        _ = fabric_module.not_exists
|
2022-05-11 18:28:08 +00:00
|
|
|
|
2021-10-29 21:46:39 +00:00
|
|
|
|
2023-01-10 15:02:05 +00:00
|
|
|
def test_fabric_module_state_dict_access():
    """`state_dict()` and `load_state_dict()` on a `_FabricModule` use the original module's parameter names."""

    class OriginalModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(2, 3)

    original_module = OriginalModule()

    class ModuleWrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.wrapped = original_module

    fabric_module = _FabricModule(ModuleWrapper(), Mock(), original_module=original_module)

    # keys have no "wrapped." prefix, i.e. the state comes from the original module, not from the wrapper
    assert set(fabric_module.state_dict()) == {"layer.weight", "layer.bias"}

    new_weight = torch.rand(3, 2)
    new_bias = torch.rand(3)
    fabric_module.load_state_dict({"layer.weight": new_weight, "layer.bias": new_bias})
    for actual, expected in ((fabric_module.layer.weight, new_weight), (fabric_module.layer.bias, new_bias)):
        assert torch.equal(actual, expected)
|
2022-09-14 23:29:23 +00:00
|
|
|
|
|
|
|
|
2021-10-29 21:46:39 +00:00
|
|
|
@pytest.mark.parametrize(
    ("precision", "input_type", "expected_type", "accelerator", "device_str"),
    [
        pytest.param(32, torch.float16, torch.float16, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(32, torch.float32, torch.float32, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(32, torch.float64, torch.float64, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(32, torch.int, torch.int, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(16, torch.float32, torch.float16, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(16, torch.float64, torch.float16, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(16, torch.long, torch.long, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(
            "bf16",
            torch.float32,
            torch.bfloat16,
            "gpu",
            "cuda:0",
            marks=RunIf(min_cuda_gpus=1, bf16_cuda=True),
        ),
        pytest.param(
            "bf16",
            torch.float64,
            torch.bfloat16,
            "gpu",
            "cuda:0",
            marks=RunIf(min_cuda_gpus=1, bf16_cuda=True),
        ),
        pytest.param(
            "bf16",
            torch.bool,
            torch.bool,
            "gpu",
            "cuda:0",
            marks=RunIf(min_cuda_gpus=1, bf16_cuda=True),
        ),
        pytest.param(32, torch.float32, torch.float32, "mps", "mps:0", marks=RunIf(mps=True)),
    ],
)
def test_fabric_module_forward_conversion(precision, input_type, expected_type, accelerator, device_str):
    """Test that the FabricModule performs autocasting on the input tensors and during forward().

    With mixed precision (16 or "bf16"), floating-point inputs reach the wrapped module as float16 / bfloat16
    respectively, and the wrapped forward runs with autocast enabled. Integer and boolean inputs, and every input
    at precision 32, keep their original dtype.
    """
    fabric = EmptyFabric(precision=precision, accelerator=accelerator, devices=1)
    device = torch.device(device_str)

    def check_autocast(forward_input):
        # Both mixed-precision settings must run the wrapped forward inside an autocast context.
        # Previously only precision 16 was checked, leaving the "bf16" cases unverified.
        assert precision not in (16, "bf16") or torch.is_autocast_enabled()
        return forward_input

    module = Mock(wraps=torch.nn.Identity(), side_effect=check_autocast)
    fabric_module = _FabricModule(module, fabric._precision).to(device)
    out = fabric_module(torch.tensor([1, 2, 3], dtype=input_type, device=device))
    # the wrapped module received the converted input ...
    assert module.call_args[0][0].dtype == expected_type
    # ... and the output dtype is either the input dtype or the default dtype
    assert out.dtype in (input_type, torch.get_default_dtype())
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
|
2021-11-16 17:26:46 +00:00
|
|
|
@pytest.mark.parametrize(
    "device_str",
    [
        "cpu",
        pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("mps", marks=RunIf(mps=True)),
    ],
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_fabric_module_device_dtype_propagation(device_str, dtype):
    """`.to(device)` / `.to(dtype)` on a FabricModule reach submodules that track device and dtype (e.g. torchmetrics)."""

    class DeviceModule(_DeviceDtypeModuleMixin):
        pass

    device = torch.device(device_str)
    inner = DeviceModule()
    fabric_module = _FabricModule(inner, Mock())

    # after moving, both the inner module and the wrapper report the new device
    fabric_module.to(device)
    for holder in (inner, fabric_module):
        assert holder.device == device

    # after casting, both the inner module and the wrapper report the new dtype
    fabric_module.to(dtype)
    for holder in (inner, fabric_module):
        assert holder.dtype == dtype
|
2021-11-16 17:26:46 +00:00
|
|
|
|
|
|
|
|
2023-01-10 15:02:05 +00:00
|
|
|
def test_fabric_dataloader_iterator():
    """Iterating over a FabricDataLoader yields exactly what the wrapped dataloader yields (no automatic device
    placement)."""
    dataloader = DataLoader(range(5), batch_size=2)
    fabric_dataloader = _FabricDataLoader(dataloader)
    assert len(fabric_dataloader) == len(dataloader) == 3

    reference_iter = iter(dataloader)
    fabric_iter = iter(fabric_dataloader)

    # three batches: [0, 1], [2, 3], [4]
    for _ in range(3):
        assert torch.equal(next(reference_iter), next(fabric_iter))

    # both iterators are exhausted at the same point
    for exhausted in (reference_iter, fabric_iter):
        with pytest.raises(StopIteration):
            next(exhausted)
|
2021-11-02 10:40:35 +00:00
|
|
|
|
|
|
|
|
2021-10-29 21:46:39 +00:00
|
|
|
@pytest.mark.parametrize(
    ("src_device_str", "dest_device_str"),
    [
        ("cpu", "cpu"),
        pytest.param("cpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("cuda:0", "cpu", marks=RunIf(min_cuda_gpus=1)),
        # pytest.param("cpu", "mps", marks=RunIf(mps=True)),  # TODO: Add once torch.equal is supported
        pytest.param("mps", "cpu", marks=RunIf(mps=True)),
    ],
)
def test_fabric_dataloader_device_placement(src_device_str, dest_device_str):
    """The FabricDataLoader moves every batch (plain tensors and dicts of tensors) to its target device."""
    src_device = torch.device(src_device_str)
    dest_device = torch.device(dest_device_str)

    samples = [torch.tensor(i, device=src_device) for i in range(2)]
    samples += [{"data": torch.tensor(i, device=src_device)} for i in range(2, 4)]
    dataloader = DataLoader(samples, batch_size=2)
    fabric_dataloader = _FabricDataLoader(dataloader=dataloader, device=dest_device)
    batches = iter(fabric_dataloader)

    # NOTE: torch.equal was not supported on MPS (torch 1.12); that is why MPS never appears as the
    # destination device in the parametrization above.
    first = next(batches)
    assert torch.equal(first, torch.tensor([0, 1], device=dest_device))

    second = next(batches)
    assert torch.equal(second["data"], torch.tensor([2, 3], device=dest_device))
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
|
2023-02-23 00:11:29 +00:00
|
|
|
@pytest.mark.parametrize("use_batch_sampler", (False, True))
def test_fabric_dataloader_distributed_sampler_set_epoch(use_batch_sampler):
    """Test that the FabricDataLoader calls `set_epoch()` on the wrapped sampler if applicable."""
    dataset = range(3)
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0)
    sampler.set_epoch = Mock()

    if use_batch_sampler:
        dataloader = DataLoader(dataset, batch_sampler=BatchSampler(sampler, batch_size=1, drop_last=False))
    else:
        dataloader = DataLoader(dataset, sampler=sampler)

    fabric_dataloader = _FabricDataLoader(dataloader)

    # creating the iterator alone does not touch the sampler yet
    first_epoch = iter(fabric_dataloader)
    sampler.set_epoch.assert_not_called()

    # the epoch is set right before the first batch is fetched, and only once per iterator
    for _ in range(2):
        next(first_epoch)
        assert sampler.set_epoch.mock_calls == [call(0)]

    # each new iterator advances the epoch, again only once its first batch is requested
    second_epoch = iter(fabric_dataloader)
    assert sampler.set_epoch.mock_calls == [call(0)]
    next(second_epoch)
    assert sampler.set_epoch.mock_calls == [call(0), call(1)]
|
2022-12-19 21:57:15 +00:00
|
|
|
|
|
|
|
|
2023-01-10 15:02:05 +00:00
|
|
|
def test_fabric_optimizer_wraps():
    """A FabricOptimizer exposes the wrapped optimizer and passes `isinstance` checks for the optimizer's class."""
    sgd_mock = Mock(spec=torch.optim.SGD)
    wrapper = _FabricOptimizer(sgd_mock, Mock())
    assert wrapper.optimizer is sgd_mock
    assert isinstance(wrapper, torch.optim.SGD)
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
|
2023-01-10 15:02:05 +00:00
|
|
|
def test_fabric_optimizer_state_dict():
    """`FabricOptimizer.state_dict()` delegates to the strategy to collect the optimizer state."""
    optimizer, strategy = Mock(), Mock()
    fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)
    fabric_optimizer.state_dict()
    # the strategy receives the raw (unwrapped) optimizer object
    strategy.get_optimizer_state.assert_called_with(optimizer)
|
2021-11-27 04:54:45 +00:00
|
|
|
|
|
|
|
|
2023-01-10 15:02:05 +00:00
|
|
|
def test_fabric_optimizer_steps():
    """Test that the FabricOptimizer routes `step()` through the strategy's `optimizer_step()`.

    Three cases are covered:

    - The strategy's return value is handed back to the caller.
    - A `closure` keyword argument is forwarded unchanged.
    - When the strategy exposes a ``model`` attribute, that object is passed to `optimizer_step()` instead of the
      wrapped optimizer.

    (`zero_grad()` forwarding is covered separately by `test_fabric_optimizer_zero_grad_kwargs`.)
    """
    optimizer = Mock()
    strategy = Mock(spec=["optimizer_step"])
    strategy.optimizer_step.return_value = 123
    fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)
    step_output = fabric_optimizer.step()
    assert step_output == 123
    strategy.optimizer_step.assert_called_once_with(optimizer)

    # clear the recorded call so the next `assert_called_once_with` only sees the closure call
    strategy.reset_mock()

    # with closure as input
    closure = Mock()
    fabric_optimizer.step(closure=closure)
    strategy.optimizer_step.assert_called_once_with(optimizer, closure=closure)

    # with a strategy that has a `model` attribute: the model is stepped instead of the optimizer
    strategy = Mock(spec=["optimizer_step", "model"])
    fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)
    fabric_optimizer.step()
    strategy.optimizer_step.assert_called_once_with(strategy.model)
|
2023-01-09 14:01:11 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_fabric_optimizer_zero_grad_kwargs():
    """Test that Fabric can adapt the `.zero_grad()` arguments to the underlying optimizer."""

    # PyTorch's standard signature: keyword arguments pass through as given, and none are added when omitted
    with mock.patch("torch.optim.SGD.zero_grad") as zero_grad_mock:
        sgd = torch.optim.SGD(torch.nn.Linear(1, 1).parameters(), 0.1)
        wrapper = _FabricOptimizer(optimizer=sgd, strategy=Mock())
        wrapper.zero_grad()
        zero_grad_mock.assert_called_with()
        for flag in (False, True):
            wrapper.zero_grad(set_to_none=flag)
            zero_grad_mock.assert_called_with(set_to_none=flag)

    # A non-standard signature from another library: `set_to_none` is mapped onto that optimizer's own argument
    custom_zero_grad = Mock()

    class CustomSGD(torch.optim.SGD):
        def zero_grad(self, set_grads_to_None=False):
            custom_zero_grad(set_grads_to_None=set_grads_to_None)

    custom_optimizer = CustomSGD(torch.nn.Linear(1, 1).parameters(), 0.1)
    custom_wrapper = _FabricOptimizer(optimizer=custom_optimizer, strategy=Mock())
    custom_wrapper.zero_grad()
    custom_zero_grad.assert_called_with(set_grads_to_None=False)
    for flag in (False, True):
        custom_wrapper.zero_grad(set_to_none=flag)
        custom_zero_grad.assert_called_with(set_grads_to_None=flag)
|