2021-10-29 21:46:39 +00:00
|
|
|
# Copyright The PyTorch Lightning team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
import os
|
2022-11-02 14:56:22 +00:00
|
|
|
from re import escape
|
2021-10-29 21:46:39 +00:00
|
|
|
from unittest import mock
|
2022-05-11 18:28:08 +00:00
|
|
|
from unittest.mock import ANY, MagicMock, Mock, PropertyMock
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
import torch.distributed
|
|
|
|
import torch.nn.functional
|
2022-12-13 13:13:51 +00:00
|
|
|
from lightning_utilities.test.warning import no_warning_call
|
2023-01-04 15:57:18 +00:00
|
|
|
from tests_fabric.helpers.runif import RunIf
|
2021-10-29 21:46:39 +00:00
|
|
|
from torch import nn
|
2022-12-08 12:50:52 +00:00
|
|
|
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler, TensorDataset
|
2021-10-29 21:46:39 +00:00
|
|
|
|
2023-01-04 15:57:18 +00:00
|
|
|
from lightning_fabric.fabric import Fabric
|
|
|
|
from lightning_fabric.plugins import Precision
|
|
|
|
from lightning_fabric.strategies import (
|
2022-11-11 13:36:59 +00:00
|
|
|
DDPStrategy,
|
|
|
|
DeepSpeedStrategy,
|
|
|
|
ParallelStrategy,
|
|
|
|
SingleDeviceStrategy,
|
|
|
|
Strategy,
|
|
|
|
XLAStrategy,
|
|
|
|
)
|
2023-01-04 15:57:18 +00:00
|
|
|
from lightning_fabric.strategies.strategy import _Sharded
|
|
|
|
from lightning_fabric.utilities.exceptions import MisconfigurationException
|
|
|
|
from lightning_fabric.utilities.seed import pl_worker_init_function, seed_everything
|
|
|
|
from lightning_fabric.utilities.warnings import PossibleUserWarning
|
|
|
|
from lightning_fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
|
2023-01-10 15:02:05 +00:00
|
|
|
class EmptyFabric(Fabric):
    """Minimal concrete ``Fabric`` whose ``run`` entry point does nothing, so tests can call the API directly."""

    def run(self):
        return None
|
|
|
|
|
|
|
|
|
|
|
|
class BoringModel(nn.Module):
    """Tiny module: a bias-free ``Linear(32, 2)`` whose forward pass returns the MSE against an all-ones target."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2, bias=False)

    def forward(self, x):
        prediction = self.layer(x)
        target = torch.ones_like(prediction)
        return torch.nn.functional.mse_loss(prediction, target)
|
|
|
|
|
|
|
|
|
|
|
|
def test_run_input_output():
    """Check that the patched ``run()`` forwards positional and keyword arguments and hands back the return value."""

    class ArgsRecordingFabric(Fabric):
        # defaults, overwritten by whatever ``run`` is called with
        run_args = ()
        run_kwargs = {}

        def run(self, *args, **kwargs):
            self.run_args, self.run_kwargs = args, kwargs
            return "result"

    recorder = ArgsRecordingFabric()
    assert recorder.run(1, 2, three=3) == "result"
    assert recorder.run_args == (1, 2)
    assert recorder.run_kwargs == {"three": 3}
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
|
2023-01-04 15:57:18 +00:00
|
|
|
@mock.patch("lightning_fabric.strategies.ddp.DistributedDataParallel")
@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
def test_setup_module(ddp_mock, setup_method):
    """Test that the setup method lets the strategy wrap the model, but keeps a reference to the original model."""
    fabric = EmptyFabric(accelerator="cpu", strategy="ddp", devices=2)
    model = nn.Linear(1, 2)
    # resolve the parametrized method name into a bound method without shadowing the parameter
    setup_fn = getattr(fabric, setup_method)
    fabric_model = setup_fn(model)

    # the (mocked) DistributedDataParallel wrapper receives the original module
    ddp_mock.assert_called_with(module=model, device_ids=ANY)
    # identity rather than equality: the wrapper must hold on to the very same module object
    assert fabric_model.module is model
    assert fabric_model.weight is model.weight
    # calls go through the wrapper's forward, not the raw module's
    assert fabric_model.forward != model.forward
|
2022-05-11 18:28:08 +00:00
|
|
|
|
|
|
|
|
2022-09-21 19:06:10 +00:00
|
|
|
@pytest.mark.parametrize(
    "accelerator, initial_device, target_device",
    [
        ("cpu", "cpu", "cpu"),
        pytest.param("cpu", "cuda:0", "cpu", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("cpu", "mps:0", "cpu", marks=RunIf(mps=True)),
        pytest.param("cuda", "cpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("cuda", "cuda:1", "cuda:0", marks=RunIf(min_cuda_gpus=2)),
        pytest.param("mps", "cpu", "mps:0", marks=RunIf(mps=True)),
    ],
)
@pytest.mark.parametrize("move_to_device", [True, False])
@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
def test_setup_module_move_to_device(setup_method, move_to_device, accelerator, initial_device, target_device):
    """Check that ``move_to_device`` decides where the parameters end up, and that the wrapper's ``device``
    attribute reflects that placement."""
    start, destination = torch.device(initial_device), torch.device(target_device)
    # without moving, the parameters stay wherever the user placed them
    if move_to_device:
        expected = destination
    else:
        expected = start

    fabric = EmptyFabric(accelerator=accelerator, devices=1)
    model = nn.Linear(1, 2).to(start)
    wrapped = getattr(fabric, setup_method)(model, move_to_device=move_to_device)

    # every parameter, seen both through the raw module and through the wrapper, is on the expected device
    for params in (model.parameters(), wrapped.parameters()):
        assert all(p.device == expected for p in params)

    assert wrapped.device == expected
    assert fabric.device == destination
|
2022-09-21 19:06:10 +00:00
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_cuda_gpus=1)
@pytest.mark.parametrize("move_to_device", [True, False])
@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
def test_setup_module_parameters_on_different_devices(setup_method, move_to_device):
    """Check that setting up a model whose parameters are spread over several devices warns only when Fabric is
    asked to move it."""
    cpu, cuda0 = torch.device("cpu"), torch.device("cuda", 0)
    fabric = EmptyFabric(accelerator="cuda", devices=1)

    first = nn.Linear(1, 2).to(cpu)
    second = nn.Linear(1, 2).to(cuda0)
    model = nn.Sequential(first, second)
    setup_fn = getattr(fabric, setup_method)

    if not move_to_device:
        # parameters are left where they are, so no warning is expected
        with no_warning_call(expected_warning=PossibleUserWarning, match="has parameters on different devices"):
            setup_fn(model, move_to_device=move_to_device)
        return

    with pytest.warns(PossibleUserWarning, match="has parameters on different devices"):
        wrapped = setup_fn(model, move_to_device=move_to_device)

    # after moving, every parameter of both submodules sits on the CUDA device
    assert wrapped.device == cuda0
    for submodule in (first, second):
        assert submodule.weight.device == submodule.bias.device == cuda0
|
2022-09-21 19:06:10 +00:00
|
|
|
|
|
|
|
|
2022-11-11 13:36:59 +00:00
|
|
|
def test_setup_module_and_optimizers():
    """Check ``setup()`` with zero, one and two optimizers: everything comes back wrapped around the originals."""
    fabric = EmptyFabric()
    model = nn.Linear(1, 2)
    sgd = torch.optim.SGD(model.parameters(), lr=0.1)
    adam = torch.optim.Adam(model.parameters(), lr=0.1)

    # model only: a single wrapper is returned, not a tuple
    wrapped_model = fabric.setup(model)
    assert isinstance(wrapped_model, _FabricModule)
    assert wrapped_model.module is model

    # one optimizer
    wrapped_model, wrapped_sgd = fabric.setup(model, sgd)
    assert isinstance(wrapped_model, _FabricModule)
    assert isinstance(wrapped_sgd, _FabricOptimizer)
    assert wrapped_model.module is model
    assert wrapped_sgd.optimizer is sgd

    # two optimizers, returned in the order they were passed
    wrapped_model, wrapped_sgd, wrapped_adam = fabric.setup(model, sgd, adam)
    assert isinstance(wrapped_model, _FabricModule)
    for wrapper in (wrapped_sgd, wrapped_adam):
        assert isinstance(wrapper, _FabricOptimizer)
    assert wrapped_model.module is model
    assert wrapped_sgd.optimizer is sgd
    assert wrapped_adam.optimizer is adam
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
|
2022-11-11 13:36:59 +00:00
|
|
|
def test_setup_optimizers():
    """Check that ``setup_optimizers()`` wraps a single optimizer as well as several at once."""
    fabric = EmptyFabric()
    model = nn.Linear(1, 2)
    sgd = torch.optim.SGD(model.parameters(), lr=0.1)
    adam = torch.optim.Adam(model.parameters(), lr=0.1)

    # a single optimizer comes back unpacked
    wrapped = fabric.setup_optimizers(sgd)
    assert isinstance(wrapped, _FabricOptimizer)
    assert wrapped.optimizer is sgd

    # several optimizers come back in the same order they were passed
    wrapped_sgd, wrapped_adam = fabric.setup_optimizers(sgd, adam)
    for wrapper, original in ((wrapped_sgd, sgd), (wrapped_adam, adam)):
        assert isinstance(wrapper, _FabricOptimizer)
        assert wrapper.optimizer is original
|
2022-11-11 13:36:59 +00:00
|
|
|
|
|
|
|
|
2021-10-29 21:46:39 +00:00
|
|
|
def test_setup_twice_fails():
    """Check that ``setup`` refuses a model or an optimizer that is already a Fabric wrapper."""
    fabric = EmptyFabric()
    model = nn.Linear(1, 2)
    optimizer = torch.optim.Adam(model.parameters())

    # re-submitting the wrapped model
    wrapped_model, _ = fabric.setup(model, optimizer)
    with pytest.raises(ValueError, match="A model should be passed only once to the"):
        fabric.setup(wrapped_model, optimizer)

    # re-submitting the wrapped optimizer
    _, wrapped_optimizer = fabric.setup(model, optimizer)
    with pytest.raises(ValueError, match="An optimizer should be passed only once to the"):
        fabric.setup(model, wrapped_optimizer)
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
|
2022-11-11 13:36:59 +00:00
|
|
|
def test_setup_module_twice_fails():
    """Check that ``setup_module`` refuses a model that has already been wrapped by Fabric."""
    fabric = EmptyFabric()
    wrapped = fabric.setup_module(nn.Linear(1, 2))
    with pytest.raises(ValueError, match="A model should be passed only once to the"):
        fabric.setup_module(wrapped)
|
2022-11-11 13:36:59 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_setup_optimizers_twice_fails():
    """Test that calling `setup_optimizers` with an optimizer that is already wrapped fails."""
    fabric = EmptyFabric()
    model = nn.Linear(1, 2)
    optimizer = torch.optim.Adam(model.parameters())

    fabric_optimizer = fabric.setup_optimizers(optimizer)
    # passing the returned wrapper back in must be rejected
    with pytest.raises(ValueError, match="An optimizer should be passed only once to"):
        fabric.setup_optimizers(fabric_optimizer)
|
2022-11-11 13:36:59 +00:00
|
|
|
|
|
|
|
|
2023-01-11 17:08:18 +00:00
|
|
|
@pytest.mark.parametrize("strategy_cls", [DeepSpeedStrategy, XLAStrategy])
def test_setup_optimizers_not_supported(strategy_cls):
    """Check that ``setup_optimizers`` raises for strategies that can only set up model and optimizers
    together."""
    fabric = EmptyFabric()
    optimizer = torch.optim.Adam(nn.Linear(1, 2).parameters())
    # Mock(spec=...) makes isinstance checks against the parametrized strategy class succeed
    fabric._strategy = Mock(spec=strategy_cls)
    expected_message = escape("requires the model and optimizer(s) to be set up jointly through")
    with pytest.raises(RuntimeError, match=expected_message):
        fabric.setup_optimizers(optimizer)
|
2022-11-11 13:36:59 +00:00
|
|
|
|
|
|
|
|
2021-10-29 21:46:39 +00:00
|
|
|
def test_setup_tracks_num_models():
    """Check that every ``setup()`` / ``setup_module()`` call increments the count of models set up."""
    fabric = EmptyFabric()
    model = nn.Linear(1, 2)
    optimizer = torch.optim.Adam(model.parameters())

    assert fabric._models_setup == 0
    for expected_count in (1, 2):
        fabric.setup(model, optimizer)
        assert fabric._models_setup == expected_count

    fabric.setup_module(model)
    assert fabric._models_setup == 3
|
2022-11-11 13:36:59 +00:00
|
|
|
|
2021-10-29 21:46:39 +00:00
|
|
|
|
2022-11-11 13:36:59 +00:00
|
|
|
def test_setup_dataloaders_unsupported_input():
    """Check that ``setup_dataloaders`` rejects an empty call and inputs that are not PyTorch DataLoaders."""
    fabric = EmptyFabric()

    # nothing passed at all
    with pytest.raises(ValueError, match="`setup_dataloaders` requires at least one dataloader"):
        fabric.setup_dataloaders()

    # an iterable that is not a DataLoader
    not_a_dataloader = range(2)
    with pytest.raises(TypeError, match="Only PyTorch DataLoader are currently supported"):
        fabric.setup_dataloaders(not_a_dataloader)  # type: ignore
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_setup_dataloaders_return_type():
    """Check that ``setup_dataloaders`` returns FabricDataLoader wrappers in the same order as its inputs."""
    fabric = EmptyFabric()

    # one input -> one unpacked wrapper
    assert isinstance(fabric.setup_dataloaders(DataLoader(range(2))), _FabricDataLoader)

    # two inputs -> two wrappers, each pointing at its own dataset
    first_dataset, second_dataset = Mock(), Mock()
    first, second = fabric.setup_dataloaders(DataLoader(first_dataset), DataLoader(second_dataset))
    for wrapper, dataset in ((first, first_dataset), (second, second_dataset)):
        assert isinstance(wrapper, _FabricDataLoader)
        assert wrapper.dataset is dataset
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
|
2023-01-04 15:57:18 +00:00
|
|
|
@mock.patch("lightning_fabric.fabric._replace_dunder_methods")
def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager):
    """Check that Fabric's ``run`` executes inside the context managers that intercept DataLoader constructor
    arguments."""

    class CheckingFabric(Fabric):
        def run(self):
            # two patches are active here: one for BatchSampler, one for DataLoader
            assert ctx_manager().__enter__.call_count == 2

    CheckingFabric().run()
    # both context managers were exited once run() returned
    assert ctx_manager().__exit__.call_count == 2
|
2021-11-05 17:31:45 +00:00
|
|
|
|
|
|
|
|
2021-11-23 15:35:07 +00:00
|
|
|
def test_setup_dataloaders_raises_for_unknown_custom_args():
    """Check that a DataLoader subclass with an extra constructor argument cannot be re-created when it was
    instantiated outside Fabric's run method."""
    fabric = EmptyFabric()

    class CustomDataLoader(DataLoader):
        def __init__(self, new_arg, *args, **kwargs):
            super().__init__(range(5), *args, **kwargs)

    expected_message = (
        r"Trying to inject custom `Sampler` into the `CustomDataLoader` instance.*"
        r"The missing attributes are \['new_arg'\]"
    )
    with pytest.raises(MisconfigurationException, match=expected_message):
        # built outside the run function, so its init arguments were never intercepted
        dataloader = CustomDataLoader(2, batch_size=2)
        fabric.setup_dataloaders(dataloader)
|
2021-11-23 15:35:07 +00:00
|
|
|
|
|
|
|
|
2021-10-29 21:46:39 +00:00
|
|
|
def test_setup_dataloaders_twice_fails():
    """Check that handing an already-wrapped dataloader back to ``setup_dataloaders`` raises."""
    fabric = EmptyFabric()
    wrapped = fabric.setup_dataloaders(DataLoader(range(2)))
    with pytest.raises(ValueError, match="A dataloader should be passed only once to the"):
        fabric.setup_dataloaders(wrapped)
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
@mock.patch(
    "lightning_fabric.fabric.Fabric.device",
    new_callable=PropertyMock,
    return_value=torch.device("cuda", 1),
)
def test_setup_dataloaders_move_to_device(fabric_device_mock):
    """Check that ``move_to_device`` controls whether the FabricDataLoader wrappers move batches to Fabric's
    device."""
    # opting out: no target device is attached and Fabric.device is never read
    fabric = EmptyFabric()
    loaders = fabric.setup_dataloaders(DataLoader(Mock()), DataLoader(Mock()), move_to_device=False)
    assert all(loader.device is None for loader in loaders)
    fabric_device_mock.assert_not_called()

    # opting in: every wrapper targets the (mocked) Fabric.device
    fabric = EmptyFabric()
    loaders = fabric.setup_dataloaders(DataLoader(Mock()), DataLoader(Mock()), move_to_device=True)
    assert all(loader.device == torch.device("cuda", 1) for loader in loaders)
    fabric_device_mock.assert_called()
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_setup_dataloaders_distributed_sampler_not_needed():
    """Check that ``replace_sampler=True`` leaves a custom sampler alone when no distributed sampler is needed."""
    custom_sampler = Mock(spec=Sampler)
    dataloader = DataLoader(Mock(), sampler=custom_sampler)

    # replacement is requested, but with a default-configured Fabric there is nothing to replace it with
    wrapped = EmptyFabric().setup_dataloaders(dataloader, replace_sampler=True)
    assert wrapped.sampler is custom_sampler
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
|
2022-12-08 12:50:52 +00:00
|
|
|
def test_setup_dataloaders_distributed_sampler_shuffle():
    """Test that the DataLoader(shuffle=True|False) setting gets carried over correctly into the distributed
    sampler."""
    fabric = Fabric(accelerator="cpu", strategy="ddp_spawn", devices=2)
    # no fabric.launch(): pretend we are on rank 0 now

    dataset = TensorDataset(torch.arange(8))

    def _values(loader):
        # each batch is the TensorDataset tuple; take its single tensor and convert it to a Python number
        return [batch[0].item() for batch in loader]

    # shuffling turned off: rank 0 sees every other sample, in order
    no_shuffle_dataloaders = [
        DataLoader(dataset),
        DataLoader(dataset, shuffle=False),
        DataLoader(dataset, sampler=SequentialSampler(dataset)),
    ]
    for dataloader in no_shuffle_dataloaders:
        fabric_dataloader = fabric.setup_dataloaders(dataloader)
        assert _values(fabric_dataloader) == [0, 2, 4, 6]

    # shuffling turned on: reseeding before each setup makes the rank-0 order reproducible
    shuffle_dataloaders = [DataLoader(dataset, shuffle=True), DataLoader(dataset, sampler=RandomSampler(dataset))]
    for dataloader in shuffle_dataloaders:
        seed_everything(1)
        fabric_dataloader = fabric.setup_dataloaders(dataloader)
        assert _values(fabric_dataloader) == [5, 2, 7, 1]
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("batch_size", [1, 2, 3])
def test_setup_dataloaders_distributed_sampler_parity(shuffle, batch_size):
    """Check that Fabric's automatic distributed sampler yields exactly the batches a hand-built PyTorch
    DistributedSampler yields, epoch after epoch."""
    torch.manual_seed(1)
    fabric = Fabric(accelerator="cpu", strategy="ddp", devices=2)
    # no fabric.launch(): pretend we are on rank 0 now

    dataset = torch.arange(10)
    reference_sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=shuffle)
    torch_dataloader = DataLoader(dataset, sampler=reference_sampler, batch_size=batch_size)
    fabric_dataloader = fabric.setup_dataloaders(DataLoader(dataset, shuffle=shuffle, batch_size=batch_size))

    def fetch_epoch(loader):
        # start a fresh pass over `loader` and keep only its first two batches
        batches = iter(loader)
        first, second = next(batches), next(batches)
        return torch.cat((first, second))

    for epoch in range(2):
        # plain PyTorch needs an explicit set_epoch; Fabric's wrapper handles it automatically
        torch_dataloader.sampler.set_epoch(epoch)
        expected = fetch_epoch(torch_dataloader)
        actual = fetch_epoch(fabric_dataloader)
        assert torch.equal(expected, actual)

    assert torch_dataloader.sampler.epoch == 1
    assert fabric_dataloader._dataloader.sampler.epoch == 1
|
2022-12-08 12:50:52 +00:00
|
|
|
|
|
|
|
|
2021-10-29 21:46:39 +00:00
|
|
|
@mock.patch.dict(os.environ, {}, clear=True)
def test_seed_everything():
    """Check that ``seed_everything`` is callable on the class and that dataloaders set up afterwards receive the
    Lightning worker init function."""
    EmptyFabric.seed_everything(3)

    wrapped = EmptyFabric().setup_dataloaders(DataLoader(Mock()))
    # the installed worker_init_fn wraps pl_worker_init_function, reachable through its `.func` attribute
    assert wrapped.worker_init_fn.func is pl_worker_init_function
    assert os.environ == {"PL_GLOBAL_SEED": "3", "PL_SEED_WORKERS": "1"}
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "strategy",
    [
        "dp",
        "ddp",
        "ddp_spawn",
        pytest.param("ddp_fork", marks=RunIf(skip_windows=True)),
        pytest.param("deepspeed", marks=RunIf(deepspeed=True)),
    ],
)
def test_setup_dataloaders_replace_custom_sampler(strategy):
    """Check that asking to replace a user-supplied sampler raises when a distributed sampler would be required,
    and that opting out keeps the user's sampler."""
    custom_sampler = Mock(spec=Sampler)
    dataloader = DataLoader(Mock(), sampler=custom_sampler)

    fabric = EmptyFabric(accelerator="cpu", strategy=strategy, devices=2)
    needs_distributed_sampler = hasattr(fabric.strategy, "distributed_sampler_kwargs")
    if needs_distributed_sampler:
        # an explicit replacement request conflicts with the already-configured sampler
        with pytest.raises(TypeError, match="You seem to have configured a sampler in your DataLoader"):
            fabric.setup_dataloaders(dataloader, replace_sampler=True)

    # with replacement disabled the custom sampler survives untouched
    wrapped = fabric.setup_dataloaders(dataloader, replace_sampler=False)
    assert wrapped.sampler is custom_sampler
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "strategy",
    [
        "dp",
        "ddp",
        "ddp_spawn",
        pytest.param("ddp_fork", marks=RunIf(skip_windows=True)),
        pytest.param("deepspeed", marks=RunIf(deepspeed=True)),
    ],
)
@pytest.mark.parametrize("shuffle", [True, False])
def test_setup_dataloaders_replace_standard_sampler(shuffle, strategy):
    """Check that Fabric swaps the default samplers for a DistributedSampler whenever the strategy is
    distributed."""
    fabric = EmptyFabric(accelerator="cpu", strategy=strategy, devices=2)
    is_distributed = hasattr(fabric.strategy, "distributed_sampler_kwargs")
    wrapped = fabric.setup_dataloaders(DataLoader(range(3), shuffle=shuffle))
    if is_distributed:
        assert isinstance(wrapped.sampler, DistributedSampler)
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "accelerator, expected",
    [
        ("cpu", "cpu"),
        pytest.param("cuda", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("tpu", "xla:0", marks=RunIf(tpu=True, standalone=True)),
        pytest.param("mps", "mps:0", marks=RunIf(mps=True)),
        pytest.param("gpu", "mps:0", marks=RunIf(mps=True)),
    ],
)
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_to_device(accelerator, expected):
    """Test that the to_device method can move various objects to the device determined by the accelerator."""

    class RunFabric(Fabric):
        def run(self):
            # Use `self` rather than the enclosing test's `fabric` variable: that name is only bound after
            # this class is defined, so reading it from here would depend on a late-bound closure.
            expected_device = torch.device(expected)

            # module
            module = torch.nn.Linear(2, 3)
            module = self.to_device(module)
            assert all(param.device == expected_device for param in module.parameters())

            # tensor
            tensor = torch.rand(2, 2)
            tensor = self.to_device(tensor)
            assert tensor.device == expected_device

            # collection: tensors inside are moved, non-tensor entries are passed through
            collection = {"data": torch.rand(2, 2), "int": 1}
            collection = self.to_device(collection)
            assert collection["data"].device == expected_device

    fabric = RunFabric(accelerator=accelerator, devices=1)
    fabric.run()
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_rank_properties():
    """Check that ``world_size`` and the rank properties on Fabric are read straight from the strategy."""
    fabric = EmptyFabric()
    fabric._strategy = Mock(spec=Strategy)
    # set each value on the mocked strategy, then read it back through Fabric
    for name, value in (("world_size", 1000), ("global_rank", 100), ("local_rank", 10), ("node_rank", 1)):
        setattr(fabric._strategy, name, value)
        assert getattr(fabric, name) == value
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_backward():
    """Check that ``Fabric.backward`` hands the loss and any extra arguments to the precision plugin."""
    fabric = EmptyFabric()
    precision = Mock(spec=Precision)
    fabric._precision = precision
    loss = Mock()
    fabric.backward(loss, "arg", keyword="kwarg")
    # the `None` in second position is presumably the (not supplied) model argument
    precision.backward.assert_called_with(loss, None, "arg", keyword="kwarg")
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
@RunIf(deepspeed=True)
def test_backward_model_input_required():
    """Check that with DeepSpeed and several models set up, ``backward()`` insists on being told which model
    produced the loss."""
    fabric = EmptyFabric(strategy="deepspeed")

    first_model, second_model = nn.Linear(1, 2), nn.Linear(1, 2)
    first_optimizer = torch.optim.Adam(first_model.parameters())
    second_optimizer = torch.optim.Adam(second_model.parameters())

    # skip the real strategy setup: hand the inputs back unchanged
    fabric._strategy.setup_module_and_optimizers = lambda *args: args

    for model, optimizer in ((first_model, first_optimizer), (second_model, second_optimizer)):
        fabric.setup(model, optimizer)

    loss = first_model(torch.randn(1, 1)).sum()
    with pytest.raises(ValueError, match="please provide the model used to perform"):
        fabric.backward(loss)
|
2021-10-29 21:46:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_autocast():
    """`Fabric.autocast()` should delegate to the precision plugin's `forward_context()` context manager."""
    fabric = EmptyFabric()
    fabric._precision.forward_context = MagicMock()
    # a MagicMock returns the same `return_value` object on every call
    precision_context = fabric._precision.forward_context()

    precision_context.__enter__.assert_not_called()
    with fabric.autocast():
        # entering Fabric's context must enter the plugin's context
        precision_context.__enter__.assert_called()
    # and leaving it must exit the plugin's context
    precision_context.__exit__.assert_called()
|
2022-10-19 19:55:12 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_no_backward_sync():
    """Test that `Fabric.no_backward_sync()` validates that the strategy and model are compatible."""
    fabric = EmptyFabric()
    model = nn.Linear(3, 3)

    # a raw module that was never passed through `fabric.setup()` is rejected
    with pytest.raises(TypeError, match="You need to set up the model first"):
        with fabric.no_backward_sync(model):
            pass

    model = fabric.setup(model)

    # pretend that the strategy does not support skipping backward sync
    fabric._strategy = Mock(spec=ParallelStrategy, _backward_sync_control=None)
    with pytest.warns(PossibleUserWarning, match="The `ParallelStrategy` does not support skipping the"):
        with fabric.no_backward_sync(model):
            pass

    # for single-device strategies, it becomes a no-op without warning
    fabric._strategy = Mock(spec=SingleDeviceStrategy, _backward_sync_control=MagicMock())
    with no_warning_call(PossibleUserWarning):
        with fabric.no_backward_sync(model):
            pass
    fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called()

    # pretend that the strategy supports skipping backward sync
    fabric._strategy = Mock(_backward_sync_control=MagicMock())

    # disabling the context manager makes it a no-op
    with fabric.no_backward_sync(model, enabled=False):
        pass
    fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called()

    # when enabled, the wrapped module gets passed down
    with fabric.no_backward_sync(model):
        pass
    fabric._strategy._backward_sync_control.no_backward_sync.assert_called_once_with(model._forward_module)
|
2022-11-02 14:56:22 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_launch_without_function():
    """Exercise `Fabric.launch()` when it is called without a function."""
    # no launcher configured (single process): the placeholder `_do_nothing` gets run
    fabric = Fabric()
    with mock.patch("lightning_fabric.fabric._do_nothing") as placeholder:
        fabric.launch()
    placeholder.assert_called()

    # a launcher attached to the strategy takes over the launch
    fabric = Fabric()
    strategy_launcher = Mock()
    fabric._strategy._launcher = strategy_launcher
    fabric.launch()
    strategy_launcher.launch.assert_called()
|
2022-11-02 14:56:22 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_launch_with_function():
    """Test the various ways `Fabric.launch(function)` can be called."""

    def fn_without_args():
        pass

    # the launched function must accept at least one argument
    fabric = Fabric()
    with pytest.raises(TypeError, match="The function passed to .* needs to take at least one argument"):
        fabric.launch(fn_without_args)

    # record every invocation so a missing call shows up as an assertion failure,
    # not as an AttributeError on an unset function attribute
    received = []

    def fn_with_one_arg(arg):
        received.append(arg)

    fabric = Fabric()
    fabric.launch(fn_with_one_arg)
    # the function runs exactly once and receives a Fabric instance as its first argument
    assert len(received) == 1
    assert isinstance(received[0], Fabric)
|
|
|
|
|
|
|
|
|
|
|
|
@mock.patch.dict(os.environ, {"LT_CLI_USED": "1"})  # pretend we are using the CLI
def test_launch_and_cli_not_allowed():
    """Calling `launch()` must fail when the CLI is marked as in use via `LT_CLI_USED`."""
    fabric = Fabric()
    expected_error = escape("Calling `.launch()` again is not allowed")
    with pytest.raises(RuntimeError, match=expected_error):
        fabric.launch()
|
2022-11-02 14:56:22 +00:00
|
|
|
|
|
|
|
|
|
|
|
@mock.patch.dict(os.environ, {"LT_CLI_USED": "1"})  # pretend we are using the CLI
def test_overridden_run_and_cli_not_allowed():
    """A `Fabric` subclass overriding `run()` cannot be instantiated while the CLI is marked as in use."""

    class FabricWithRun(Fabric):
        def run(self):
            pass

    expected_error = escape("Overriding `Fabric.run()` and launching from the CLI is not allowed")
    with pytest.raises(TypeError, match=expected_error):
        FabricWithRun()
|
2022-11-10 01:06:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_module_sharding_context():
    """`Fabric.sharded_model()` must use the strategy's `module_sharded_context` only for `_Sharded`
    strategies and do nothing for any other strategy."""
    fabric = EmptyFabric()

    # a regular (non-sharded) strategy: its context must never be requested
    plain_strategy = MagicMock(spec=DDPStrategy, module_sharded_context=Mock())
    fabric._strategy = plain_strategy
    with fabric.sharded_model():
        pass
    plain_strategy.module_sharded_context.assert_not_called()

    # a sharded strategy: its context gets requested exactly once
    sharded_strategy = MagicMock(spec=_Sharded)
    fabric._strategy = sharded_strategy
    with fabric.sharded_model():
        pass
    sharded_strategy.module_sharded_context.assert_called_once()
|
2023-01-04 15:57:18 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_callbacks_input():
    """Callbacks can be handed to `Fabric` either as a single object or as a list."""
    first, second = Mock(), Mock()

    # a lone callback ends up wrapped in a list
    assert Fabric(callbacks=first)._callbacks == [first]

    # a list of callbacks is stored in the given order
    assert Fabric(callbacks=[first, second])._callbacks == [first, second]
|
|
|
|
|
|
|
|
|
2023-01-09 18:33:18 +00:00
|
|
|
def test_call():
    """Test that `fabric.call` triggers the callback implementations."""
    callback0 = Mock()
    callback1 = Mock()
    fabric = Fabric(callbacks=[callback0, callback1])

    # No arguments
    fabric.call("on_train_end")
    callback0.on_train_end.assert_called_once_with()
    callback1.on_train_end.assert_called_once_with()

    # Optional arguments are forwarded to every callback
    fabric.call("on_train_end", "positional", keyword="keyword")
    callback0.on_train_end.assert_called_with("positional", keyword="keyword")
    callback1.on_train_end.assert_called_with("positional", keyword="keyword")

    # Some callbacks don't implement the requested hook
    callback0 = Mock()
    callback1 = Mock(spec_set={})  # `on_train_end` not defined for this callback
    fabric = Fabric(callbacks=[callback0, callback1])
    fabric.call("on_train_end")
    callback0.on_train_end.assert_called_once()
    assert not callback1.mock_calls  # no methods were called on callback1

    # Skip callback attributes that are not callable
    callback = Mock(not_a_method=1)
    fabric = Fabric(callbacks=[callback])
    with pytest.warns(UserWarning, match="Skipping the callback `Mock.not_a_method`"):
        fabric.call("not_a_method")
    # check the callback registered in THIS section (previously asserted on the stale `callback1`)
    assert not callback.mock_calls
|
2023-01-09 18:33:18 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_loggers_input():
    """`Fabric` accepts no logger, a single logger, or a list of loggers."""
    logger0 = Mock()
    logger1 = Mock()

    # both `None` and an empty list mean "no loggers"
    for no_loggers in (None, []):
        assert Fabric(loggers=no_loggers)._loggers == []

    # a single logger ends up wrapped in a list
    assert Fabric(loggers=logger0)._loggers == [logger0]

    # multiple loggers are stored in the given order
    assert Fabric(loggers=[logger0, logger1])._loggers == [logger0, logger1]
|
|
|
|
|
|
|
|
|
|
|
|
def test_log():
    """`fabric.log` must forward a single metric to each registered logger."""
    all_loggers = [Mock(), Mock()]
    fabric = Fabric(loggers=all_loggers)

    # without an explicit step, `step=None` is passed through
    fabric.log("test", 1)
    for lg in all_loggers:
        lg.log_metrics.assert_called_with(metrics={"test": 1}, step=None)

    # an explicit step is passed through unchanged
    fabric.log("test", 2, step=15)
    for lg in all_loggers:
        lg.log_metrics.assert_called_with(metrics={"test": 2}, step=15)
|
|
|
|
|
|
|
|
|
|
|
|
def test_log_dict():
    """`fabric.log_dict` must forward the whole metrics dictionary to each registered logger."""
    all_loggers = [Mock(), Mock()]
    fabric = Fabric(loggers=all_loggers)

    cases = (
        ({"foo": 1, "bar": 2}, None),
        ({"foo": 3, "bar": 4}, 15),
    )
    for metrics, step in cases:
        # compare against an independent copy so in-place changes to the input would be caught
        expected = dict(metrics)
        fabric.log_dict(metrics, step=step)
        for lg in all_loggers:
            lg.log_metrics.assert_called_with(metrics=expected, step=step)
|
2023-01-10 16:11:47 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_log_dict_input_parsing():
    """Tensor values given to `log`/`log_dict` are validated and turned into Python scalars."""
    logger = Mock()
    fabric = Fabric(loggers=[logger])

    # single-element tensors (0 dims and 1 dim) are unwrapped into plain numbers
    for make_tensor, scalar in ((lambda: torch.tensor(1), 1), (lambda: torch.tensor([2]), 2)):
        fabric.log("log", make_tensor())
        logger.log_metrics.assert_called_with(metrics={"log": scalar}, step=None)
        fabric.log_dict({"log_dict": make_tensor()})
        logger.log_metrics.assert_called_with(metrics={"log_dict": scalar}, step=None)

    # tensors holding several elements cannot be reduced to a scalar
    with pytest.raises(ValueError, match="it cannot be converted to a scalar."):
        fabric.log("log", torch.tensor([3, 4]))

    with pytest.raises(ValueError, match="it cannot be converted to a scalar."):
        fabric.log_dict({"log_dict": torch.tensor([3, 4])})
|