# lightning/tests/tests_fabric/test_fabric.py
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from re import escape
from unittest import mock
from unittest.mock import ANY, MagicMock, Mock, PropertyMock, call
import pytest
import torch
import torch.distributed
import torch.nn.functional
from lightning.fabric.accelerators.xla import _using_pjrt
from lightning.fabric.fabric import Fabric
from lightning.fabric.plugins import Precision
from lightning.fabric.strategies import (
DataParallelStrategy,
DDPStrategy,
DeepSpeedStrategy,
ParallelStrategy,
SingleDeviceStrategy,
Strategy,
XLAStrategy,
)
from lightning.fabric.strategies.strategy import _Sharded
from lightning.fabric.utilities.exceptions import MisconfigurationException
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.utilities.seed import pl_worker_init_function, seed_everything
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer
from lightning_utilities.test.warning import no_warning_call
from torch import nn
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler, TensorDataset
from tests_fabric.helpers.runif import RunIf
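# Minimal model used as a fixture throughout these tests: the forward pass returns the
# MSE loss of the linear layer's output against a tensor of ones.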
class BoringModel(nn.Module):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(32, 2, bias=False)
def forward(self, x):
x = self.layer(x)
return torch.nn.functional.mse_loss(x, torch.ones_like(x))
def test_run_input_output():
"""Test that the dynamically patched run() method receives the input arguments and returns the result."""
class RunFabric(Fabric):
run_args = ()
run_kwargs = {}
def run(self, *args, **kwargs):
self.run_args = args
self.run_kwargs = kwargs
return "result"
fabric = RunFabric()
result = fabric.run(1, 2, three=3)
assert result == "result"
assert fabric.run_args == (1, 2)
assert fabric.run_kwargs == {"three": 3}
@mock.patch("lightning.fabric.strategies.ddp.DistributedDataParallel")
@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
def test_setup_module(ddp_mock, setup_method):
"""Test that the setup method lets the strategy wrap the model, but keeps a reference to the original model."""
fabric = Fabric(accelerator="cpu", strategy="ddp", devices=2)
fabric._launched = True # pretend we have launched multiple processes
model = nn.Linear(1, 2)
setup_method = getattr(fabric, setup_method)
fabric_model = setup_method(model)
ddp_mock.assert_called_with(module=model, device_ids=ANY)
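    # attribute access on the wrapper is forwarded to the original module, while its
    # forward call is routed through the strategy-wrapped (DDP) module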
assert fabric_model.module == model
assert fabric_model.weight is model.weight
assert fabric_model.forward != model.forward
@RunIf(skip_windows=True, dynamo=True)
@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
def test_setup_compiled_module(setup_method):
"""Test that an `OptimizedModule` can be passed to the setup method."""
from torch._dynamo.eval_frame import OptimizedModule
fabric = Fabric(devices=1)
model = nn.Linear(1, 2)
compiled_model = torch.compile(model)
assert isinstance(compiled_model, OptimizedModule)
setup_method = getattr(fabric, setup_method)
fabric_model = setup_method(compiled_model)
assert fabric_model.module == compiled_model
# Attributes get passed through
assert fabric_model.weight is model.weight
@pytest.mark.parametrize(
("accelerator", "initial_device", "target_device"),
[
("cpu", "cpu", "cpu"),
pytest.param("cpu", "cuda:0", "cpu", marks=RunIf(min_cuda_gpus=1)),
pytest.param("cpu", "mps:0", "cpu", marks=RunIf(mps=True)),
pytest.param("cuda", "cpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
pytest.param("cuda", "cuda:1", "cuda:0", marks=RunIf(min_cuda_gpus=2)),
pytest.param("mps", "cpu", "mps:0", marks=RunIf(mps=True)),
],
)
@pytest.mark.parametrize("move_to_device", [True, False])
@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
def test_setup_module_move_to_device(setup_method, move_to_device, accelerator, initial_device, target_device):
"""Test that `move_to_device` leads to parameters being moved to the correct device and that the device attributes
on the wrapper are updated."""
initial_device = torch.device(initial_device)
target_device = torch.device(target_device)
expected_device = target_device if move_to_device else initial_device
fabric = Fabric(accelerator=accelerator, devices=1)
model = nn.Linear(1, 2)
model.to(initial_device)
setup_method = getattr(fabric, setup_method)
fabric_model = setup_method(model, move_to_device=move_to_device)
# all parameters on the expected device
assert all(param.device == expected_device for param in model.parameters())
assert all(param.device == expected_device for param in fabric_model.parameters())
assert fabric_model.device == expected_device
assert fabric.device == target_device
# edge case: model has no parameters
model = nn.Sequential()
fabric_model = setup_method(model, move_to_device=move_to_device)
    assert fabric_model.device == (target_device if move_to_device else torch.device("cpu"))
@RunIf(min_cuda_gpus=1)
@pytest.mark.parametrize("move_to_device", [True, False])
@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
def test_setup_module_parameters_on_different_devices(setup_method, move_to_device):
"""Test that a warning is emitted when model parameters are on a different device prior to calling `setup()`."""
device0 = torch.device("cpu")
device1 = torch.device("cuda", 0)
fabric = Fabric(accelerator="cuda", devices=1)
module0 = nn.Linear(1, 2).to(device0)
module1 = nn.Linear(1, 2).to(device1)
model = nn.Sequential(module0, module1)
setup_method = getattr(fabric, setup_method)
match = r"has 2 parameters on different devices \(for example '1.weight' on cuda:0 and '0.weight' on cpu\)"
if move_to_device:
with pytest.warns(PossibleUserWarning, match=match):
fabric_model = setup_method(model, move_to_device=move_to_device)
# both have the same device now
assert fabric_model.device == device1
assert module0.weight.device == module0.bias.device == device1
assert module1.weight.device == module1.bias.device == device1
else:
with no_warning_call(expected_warning=PossibleUserWarning, match=match):
setup_method(model, move_to_device=move_to_device)
def test_setup_module_and_optimizers():
"""Test that `setup()` can handle no optimizers, one optimizer, or multiple optimizers."""
fabric = Fabric(devices=1)
model = nn.Linear(1, 2)
optimizer0 = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer1 = torch.optim.Adam(model.parameters(), lr=0.1)
# no optimizer
fabric_model = fabric.setup(model)
assert isinstance(fabric_model, _FabricModule)
assert fabric_model.module is model
# single optimizer
fabric_model, fabric_optimizer = fabric.setup(model, optimizer0)
assert isinstance(fabric_model, _FabricModule)
assert isinstance(fabric_optimizer, _FabricOptimizer)
assert fabric_model.module is model
assert fabric_optimizer.optimizer is optimizer0
# multiple optimizers
fabric_model, fabric_optimizer0, fabric_optimizer1 = fabric.setup(model, optimizer0, optimizer1)
assert isinstance(fabric_model, _FabricModule)
assert isinstance(fabric_optimizer0, _FabricOptimizer)
assert isinstance(fabric_optimizer1, _FabricOptimizer)
assert fabric_model.module is model
assert fabric_optimizer0.optimizer is optimizer0
assert fabric_optimizer1.optimizer is optimizer1
def test_setup_optimizers():
"""Test that `setup_optimizers()` can handle one or more optimizers."""
fabric = Fabric()
model = nn.Linear(1, 2)
optimizer0 = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer1 = torch.optim.Adam(model.parameters(), lr=0.1)
# single optimizer
fabric_optimizer = fabric.setup_optimizers(optimizer0)
assert isinstance(fabric_optimizer, _FabricOptimizer)
assert fabric_optimizer.optimizer is optimizer0
# multiple optimizers
fabric_optimizer0, fabric_optimizer1 = fabric.setup_optimizers(optimizer0, optimizer1)
assert isinstance(fabric_optimizer0, _FabricOptimizer)
assert isinstance(fabric_optimizer1, _FabricOptimizer)
assert fabric_optimizer0.optimizer is optimizer0
assert fabric_optimizer1.optimizer is optimizer1
def test_setup_twice_fails():
"""Test that calling `setup` with a model or optimizer that is already wrapped fails."""
fabric = Fabric(devices=1)
model = nn.Linear(1, 2)
optimizer = torch.optim.Adam(model.parameters())
fabric_model, fabric_optimizer = fabric.setup(model, optimizer)
with pytest.raises(ValueError, match="A model should be passed only once to the"):
fabric.setup(fabric_model, optimizer)
fabric_model, fabric_optimizer = fabric.setup(model, optimizer)
with pytest.raises(ValueError, match="An optimizer should be passed only once to the"):
fabric.setup(model, fabric_optimizer)
def test_setup_module_twice_fails():
"""Test that calling `setup_module` with a model that is already wrapped fails."""
fabric = Fabric(devices=1)
model = nn.Linear(1, 2)
fabric_model = fabric.setup_module(model)
with pytest.raises(ValueError, match="A model should be passed only once to the"):
fabric.setup_module(fabric_model)
def test_setup_optimizers_twice_fails():
"""Test that calling `setup_module` with a model that is already wrapped fails."""
fabric = Fabric()
model = nn.Linear(1, 2)
optimizer = torch.optim.Adam(model.parameters())
fabric_optimizer = fabric.setup_optimizers(optimizer)
with pytest.raises(ValueError, match="An optimizer should be passed only once to"):
fabric.setup_optimizers(fabric_optimizer)
@pytest.mark.parametrize("strategy_cls", [DeepSpeedStrategy, XLAStrategy])
def test_setup_optimizers_not_supported(strategy_cls):
"""Test that `setup_optimizers` validates the strategy supports setting up model and optimizers independently."""
fabric = Fabric()
fabric._launched = True # pretend we have launched multiple processes
model = nn.Linear(1, 2)
optimizer = torch.optim.Adam(model.parameters())
fabric._strategy = Mock(spec=strategy_cls)
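    # these strategies require the model and optimizer(s) to be set up together, so a
    # standalone `setup_optimizers()` call must be rejected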
with pytest.raises(RuntimeError, match=escape("requires the model and optimizer(s) to be set up jointly through")):
fabric.setup_optimizers(optimizer)
@RunIf(min_cuda_gpus=1, min_torch="2.1")
def test_setup_optimizer_on_meta_device():
"""Test that the setup-methods validate that the optimizer doesn't have references to meta-device parameters."""
fabric = Fabric(strategy="fsdp", devices=1)
fabric._launched = True # pretend we have launched multiple processes
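    # with `empty_init=True` under FSDP, parameters are created on the meta device, so an
    # optimizer built from them only holds references to meta tensors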
with fabric.init_module(empty_init=True):
model = nn.Linear(1, 2)
assert model.weight.is_meta
optimizer = torch.optim.Adam(model.parameters()) # optimizer references meta device params
with pytest.raises(RuntimeError, match="The optimizer has references to the model's meta-device parameters"):
fabric.setup(model, optimizer)
with pytest.raises(RuntimeError, match="The optimizer has references to the model's meta-device parameters"):
fabric.setup_optimizers(optimizer)
def test_setup_tracks_num_models():
"""Test that setup() tracks how many times it has setup a model."""
fabric = Fabric(devices=1)
model = nn.Linear(1, 2)
optimizer = torch.optim.Adam(model.parameters())
assert fabric._models_setup == 0
fabric.setup(model, optimizer)
assert fabric._models_setup == 1
fabric.setup(model, optimizer)
assert fabric._models_setup == 2
fabric.setup_module(model)
assert fabric._models_setup == 3
def test_setup_dataloaders_unsupported_input():
"""Test that the setup_dataloaders method fails when provided with non-DataLoader objects."""
fabric = Fabric()
with pytest.raises(ValueError, match="`setup_dataloaders` requires at least one dataloader"):
fabric.setup_dataloaders()
with pytest.raises(TypeError, match="Only PyTorch DataLoader are currently supported"):
fabric.setup_dataloaders(range(2)) # type: ignore
def test_setup_dataloaders_return_type():
"""Test that the setup method returns the dataloaders wrapped as FabricDataLoader and in the right order."""
fabric = Fabric(devices=1)
# single dataloader
fabric_dataloader = fabric.setup_dataloaders(DataLoader(range(2)))
assert isinstance(fabric_dataloader, _FabricDataLoader)
# multiple dataloaders
dataset0 = Mock()
dataset1 = Mock()
dataloader0 = DataLoader(dataset0)
dataloader1 = DataLoader(dataset1)
fabric_dataloader0, fabric_dataloader1 = fabric.setup_dataloaders(dataloader0, dataloader1)
assert isinstance(fabric_dataloader0, _FabricDataLoader)
assert isinstance(fabric_dataloader1, _FabricDataLoader)
assert fabric_dataloader0.dataset is dataset0
assert fabric_dataloader1.dataset is dataset1
@mock.patch("lightning.fabric.fabric._replace_dunder_methods")
def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager):
"""Test that Fabric intercepts the DataLoader constructor arguments with a context manager when launching a
function."""
def run(_):
# One for BatchSampler, another for DataLoader
assert ctx_manager().__enter__.call_count == 2
fabric = Fabric()
fabric.launch(run)
assert ctx_manager().__exit__.call_count == 2
def test_setup_dataloaders_raises_for_unknown_custom_args():
"""Test that an error raises when custom dataloaders with unknown arguments are created from outside Fabric's run
method."""
class CustomDataLoader(DataLoader):
def __init__(self, new_arg, *args, **kwargs):
super().__init__(range(5), *args, **kwargs)
dataloader = CustomDataLoader(2, batch_size=2)
# If no distributed sampler is required, reinstantiation is not necessary
fabric = Fabric(devices=1)
fabric_dataloader = fabric.setup_dataloaders(dataloader)
assert fabric_dataloader._dataloader is dataloader
    # If a distributed sampler is required, the sampler needs to be reinstantiated
fabric = Fabric(devices=2, accelerator="cpu")
fabric._launched = True
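    # Fabric cannot re-instantiate the dataloader because it has no way to infer a value
    # for the custom `new_arg` argument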
with pytest.raises(
MisconfigurationException,
match=(
r"Trying to inject custom `Sampler` into the `CustomDataLoader` instance.*"
r"The missing attributes are \['new_arg'\]"
),
):
fabric.setup_dataloaders(dataloader)
def test_setup_dataloaders_twice_fails():
"""Test that calling setup_dataloaders with a dataloader that is already wrapped fails."""
fabric = Fabric()
dataloader = DataLoader(range(2))
fabric_dataloader = fabric.setup_dataloaders(dataloader)
with pytest.raises(ValueError, match="A dataloader should be passed only once to the"):
fabric.setup_dataloaders(fabric_dataloader)
@mock.patch(
"lightning.fabric.fabric.Fabric.device",
new_callable=PropertyMock,
return_value=torch.device("cuda", 1),
)
def test_setup_dataloaders_move_to_device(fabric_device_mock):
"""Test that the setup configures FabricDataLoader to move the data to the device automatically."""
fabric = Fabric(devices=1)
fabric_dataloaders = fabric.setup_dataloaders(DataLoader(Mock()), DataLoader(Mock()), move_to_device=False)
assert all(dl.device is None for dl in fabric_dataloaders)
fabric_device_mock.assert_not_called()
fabric = Fabric(devices=1)
fabric_dataloaders = fabric.setup_dataloaders(DataLoader(Mock()), DataLoader(Mock()), move_to_device=True)
assert all(dl.device == torch.device("cuda", 1) for dl in fabric_dataloaders)
fabric_device_mock.assert_called()
def test_setup_dataloaders_distributed_sampler_not_needed():
"""Test that `use_distributed_sampler` option has no effect when no distributed sampler is needed."""
custom_sampler = Mock(spec=Sampler)
dataloader = DataLoader(Mock(), sampler=custom_sampler)
# if no distributed sampler is required, dataloader reinstantiation is not necessary
fabric = Fabric(devices=1)
fabric_dataloader = fabric.setup_dataloaders(dataloader, use_distributed_sampler=True)
assert fabric_dataloader._dataloader is dataloader
assert fabric_dataloader.sampler is custom_sampler
def test_setup_dataloaders_distributed_sampler_shuffle():
"""Test that the DataLoader(shuffle=True|False) setting gets carried over correctly into the distributed
sampler."""
fabric = Fabric(accelerator="cpu", strategy="ddp_spawn", devices=2)
# no fabric.launch(): pretend we are on rank 0 now
fabric._launched = True
dataset = TensorDataset(torch.arange(8))
# shuffling turned off
no_shuffle_dataloaders = [
DataLoader(dataset),
DataLoader(dataset, shuffle=False),
DataLoader(dataset, sampler=SequentialSampler(dataset)),
]
for dataloader in no_shuffle_dataloaders:
dataloader = fabric.setup_dataloaders(dataloader)
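        # rank 0 of 2 processes receives every second sample of the 8-element dataset, in order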
assert [t[0].item() for t in iter(dataloader)] == [0, 2, 4, 6]
# shuffling turned on
shuffle_dataloaders = [DataLoader(dataset, shuffle=True), DataLoader(dataset, sampler=RandomSampler(dataset))]
for dataloader in shuffle_dataloaders:
seed_everything(1)
dataloader = fabric.setup_dataloaders(dataloader)
assert [t[0].item() for t in iter(dataloader)] == [5, 2, 7, 1]
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("batch_size", [1, 2, 3])
def test_setup_dataloaders_distributed_sampler_parity(shuffle, batch_size):
"""Test that the distributed sampler setup in Fabric leads to the same sequence of data as in raw PyTorch."""
torch.manual_seed(1)
fabric = Fabric(accelerator="cpu", strategy="ddp", devices=2)
# no fabric.launch(): pretend we are on rank 0 now
fabric._launched = True
dataset = torch.arange(10)
torch_dataloader = DataLoader(
dataset,
sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=shuffle),
batch_size=batch_size,
)
fabric_dataloader = DataLoader(dataset, shuffle=shuffle, batch_size=batch_size)
fabric_dataloader = fabric.setup_dataloaders(fabric_dataloader)
def fetch_epoch(loader):
iterator = iter(loader)
# we fetch 2 batches per epoch
return torch.cat((next(iterator), next(iterator)))
# 1st epoch
    # PyTorch users need to set the epoch manually, while Fabric handles it automatically
torch_dataloader.sampler.set_epoch(0)
torch_data = fetch_epoch(torch_dataloader)
fabric_data = fetch_epoch(fabric_dataloader)
assert torch.equal(torch_data, fabric_data)
# 2nd epoch
    # PyTorch users need to set the epoch manually, while Fabric handles it automatically
torch_dataloader.sampler.set_epoch(1)
torch_data = fetch_epoch(torch_dataloader)
fabric_data = fetch_epoch(fabric_dataloader)
assert torch.equal(torch_data, fabric_data)
assert torch_dataloader.sampler.epoch == 1
assert fabric_dataloader._dataloader.sampler.epoch == 1
@mock.patch.dict(os.environ, {}, clear=True)
def test_seed_everything():
"""Test that seed everything is static and sets the worker init function on the dataloader."""
Fabric.seed_everything(3)
fabric = Fabric(devices=1)
fabric_dataloader = fabric.setup_dataloaders(DataLoader(Mock()))
assert fabric_dataloader.worker_init_fn.func is pl_worker_init_function
assert os.environ == {"PL_GLOBAL_SEED": "3", "PL_SEED_WORKERS": "1"}
@pytest.mark.parametrize(
"strategy",
[
"dp",
"ddp",
"ddp_spawn",
pytest.param("ddp_fork", marks=RunIf(skip_windows=True)),
pytest.param("deepspeed", marks=RunIf(deepspeed=True)),
],
)
def test_setup_dataloaders_replace_custom_sampler(strategy):
"""Test that asking to replace a custom sampler results in an error when a distributed sampler would be needed."""
custom_sampler = Mock(spec=Sampler)
dataloader = DataLoader(Mock(), sampler=custom_sampler)
# explicitly asking to replace when a custom sampler is already configured raises an exception
fabric = Fabric(accelerator="cpu", strategy=strategy, devices=2)
fabric._launched = True # pretend we have launched multiple processes
if hasattr(fabric.strategy, "distributed_sampler_kwargs"):
with pytest.raises(TypeError, match="You seem to have configured a sampler in your DataLoader"):
fabric.setup_dataloaders(dataloader, use_distributed_sampler=True)
# setting `use_distributed_sampler=False` leaves the sampler untouched
fabric_dataloader = fabric.setup_dataloaders(dataloader, use_distributed_sampler=False)
assert fabric_dataloader.sampler is custom_sampler
@pytest.mark.parametrize(
"strategy",
[
"dp",
"ddp",
"ddp_spawn",
pytest.param("ddp_fork", marks=RunIf(skip_windows=True)),
pytest.param("deepspeed", marks=RunIf(deepspeed=True)),
],
)
@pytest.mark.parametrize("shuffle", [True, False])
def test_setup_dataloaders_replace_standard_sampler(shuffle, strategy):
"""Test that Fabric replaces the default samplers with DistributedSampler automatically."""
fabric = Fabric(accelerator="cpu", strategy=strategy, devices=2)
fabric._launched = True # pretend we have launched multiple processes
is_distributed = hasattr(fabric.strategy, "distributed_sampler_kwargs")
fabric_dataloader = fabric.setup_dataloaders(DataLoader(range(3), shuffle=shuffle))
assert not is_distributed or isinstance(fabric_dataloader.sampler, DistributedSampler)
@pytest.mark.parametrize(
("accelerator", "expected"),
[
("cpu", "cpu"),
pytest.param("cuda", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
pytest.param("gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
pytest.param("tpu", "xla:0", marks=RunIf(tpu=True, standalone=True)),
pytest.param("mps", "mps:0", marks=RunIf(mps=True)),
pytest.param("gpu", "mps:0", marks=RunIf(mps=True)),
],
)
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_to_device(accelerator, expected):
"""Test that the to_device method can move various objects to the device determined by the accelerator."""
if accelerator == "tpu" and not _using_pjrt():
expected = "xla:1"
fabric = Fabric(accelerator=accelerator, devices=1)
fabric.launch()
expected_device = torch.device(expected)
# module
module = torch.nn.Linear(2, 3)
module = fabric.to_device(module)
assert all(param.device == expected_device for param in module.parameters())
# tensor
tensor = torch.rand(2, 2)
tensor = fabric.to_device(tensor)
assert tensor.device == expected_device
# collection
collection = {"data": torch.rand(2, 2), "int": 1}
collection = fabric.to_device(collection)
assert collection["data"].device == expected_device
def test_rank_properties():
"""Test that the rank properties are determined by the strategy."""
fabric = Fabric()
fabric._strategy = Mock(spec=Strategy)
fabric._strategy.world_size = 1000
assert fabric.world_size == 1000
fabric._strategy.global_rank = 100
assert fabric.global_rank == 100
fabric._strategy.local_rank = 10
assert fabric.local_rank == 10
fabric._strategy.node_rank = 1
assert fabric.node_rank == 1
def test_backward():
"""Test that backward() calls into the precision plugin."""
fabric = Fabric()
fabric._strategy = Mock(spec=Precision)
loss = Mock()
fabric.backward(loss, "arg", keyword="kwarg")
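    # the second positional argument is the module; no module is associated with this loss, so Fabric passes None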
fabric._strategy.backward.assert_called_with(loss, None, "arg", keyword="kwarg")
@RunIf(deepspeed=True, mps=False)
def test_backward_model_input_required():
"""Test that when using deepspeed and multiple models, backward() requires the model as input."""
fabric = Fabric(strategy="deepspeed", devices=1)
fabric._launched = True # pretend we have launched
model0 = nn.Linear(1, 2)
model1 = nn.Linear(1, 2)
optimizer0 = torch.optim.Adam(model0.parameters())
optimizer1 = torch.optim.Adam(model1.parameters())
fabric._strategy.setup_module_and_optimizers = lambda *args: args
fabric.setup(model0, optimizer0)
fabric.setup(model1, optimizer1)
loss = model0(torch.randn(1, 1, device=fabric.device)).sum()
with pytest.raises(ValueError, match="please provide the model used to perform"):
fabric.backward(loss)
def test_autocast():
"""Test that the Fabric autocast context manager lets the precision plugin handle casting."""
fabric = Fabric()
fabric._precision.forward_context = MagicMock()
fabric._precision.forward_context().__enter__.assert_not_called()
with fabric.autocast():
fabric._precision.forward_context().__enter__.assert_called()
fabric._precision.forward_context().__exit__.assert_called()
def test_no_backward_sync():
"""Test that `Fabric.no_backward_sync()` validates the strategy and model is compatible."""
fabric = Fabric(devices=1)
model = nn.Linear(3, 3)
with pytest.raises(TypeError, match="You need to set up the model first"), fabric.no_backward_sync(model):
pass
model = fabric.setup(model)
# pretend that the strategy does not support skipping backward sync
fabric._strategy = Mock(spec=ParallelStrategy, _backward_sync_control=None)
with pytest.warns(
PossibleUserWarning, match="The `ParallelStrategy` does not support skipping the"
), fabric.no_backward_sync(model):
pass
# for single-device strategies, it becomes a no-op without warning
fabric._strategy = Mock(spec=SingleDeviceStrategy, _backward_sync_control=MagicMock())
with fabric.no_backward_sync(model):
pass
fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called()
# same for XLA
fabric._strategy = Mock(spec=XLAStrategy, _backward_sync_control=MagicMock())
with fabric.no_backward_sync(model):
pass
fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called()
# pretend that the strategy supports skipping backward sync
fabric._strategy = Mock(_backward_sync_control=MagicMock())
# disabling the context manager makes it a no-op
with fabric.no_backward_sync(model, enabled=False):
pass
fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called()
# when enabled, the wrapped module gets passed down
with fabric.no_backward_sync(model):
pass
fabric._strategy._backward_sync_control.no_backward_sync.assert_called_once_with(model._forward_module)
def test_launch_without_function():
"""Test the various ways `Fabric.launch()` can be called."""
# default: no launcher, single process
fabric = Fabric()
nothing = Mock()
fabric.launch(nothing)
nothing.assert_called()
# with a launcher on the strategy
fabric = Fabric()
fabric._strategy._launcher = Mock()
fabric.launch()
fabric._strategy._launcher.launch.assert_called()
def test_launch_with_function():
"""Test the various ways `Fabric.launch(function)` can be called."""
def fn_without_args():
pass
fabric = Fabric()
with pytest.raises(TypeError, match="needs to take at least one argument"):
fabric.launch(fn_without_args)
def fn_with_one_arg(arg):
assert isinstance(arg, Fabric)
fn_with_one_arg.called = True
fabric = Fabric()
fabric.launch(fn_with_one_arg)
assert fn_with_one_arg.called
# common user mistake
fabric = Fabric()
with pytest.raises(TypeError, match="needs to be a callable"):
fabric.launch(fn_with_one_arg(fabric))
@mock.patch.dict(os.environ, {"LT_CLI_USED": "1"}) # pretend we are using the CLI
def test_launch_and_cli_not_allowed():
fabric = Fabric(devices=1)
with pytest.raises(RuntimeError, match=escape("Calling `.launch()` again is not allowed")):
fabric.launch()
@RunIf(mps=False)
@pytest.mark.parametrize("strategy", ["xla", "ddp_spawn"])
def test_launch_and_strategies_unsupported_combinations(strategy, xla_available):
fabric = Fabric(strategy=strategy)
with pytest.raises(TypeError, match=r"launch\(\)` needs to be called with a function"):
fabric.launch()
@mock.patch.dict(os.environ, {"LT_CLI_USED": "1"}) # pretend we are using the CLI
def test_overridden_run_and_cli_not_allowed():
class FabricWithRun(Fabric):
def run(self):
pass
with pytest.raises(TypeError, match=escape("Overriding `Fabric.run()` and launching from the CLI is not allowed")):
FabricWithRun()
def test_module_sharding_context():
"""Test that the sharding context manager gets applied when the strategy supports it and is a no-op otherwise."""
fabric = Fabric()
fabric._launched = True
fabric._strategy = MagicMock(spec=DDPStrategy, module_sharded_context=Mock())
with pytest.warns(DeprecationWarning, match="sharded_model"), fabric.sharded_model():
pass
fabric._strategy.module_sharded_context.assert_not_called()
fabric._strategy = MagicMock(spec=_Sharded)
with pytest.warns(DeprecationWarning, match="sharded_model"), fabric.sharded_model():
pass
fabric._strategy.module_sharded_context.assert_called_once()
def test_init_module_context(monkeypatch):
"""Test that the strategy returns the context manager for initializing the module."""
import lightning.fabric
fabric = Fabric(accelerator="cpu")
strategy = SingleDeviceStrategy(device=torch.device("cuda"))
strategy.module_init_context = Mock(wraps=strategy.module_init_context)
fabric._strategy = strategy
with fabric.init_module():
pass
strategy.module_init_context.assert_called_once_with(empty_init=None)
strategy.module_init_context.reset_mock()
# Pretend we are using PyTorch < 2.0
monkeypatch.setattr(lightning.fabric.fabric, "_TORCH_GREATER_EQUAL_2_0", False)
with pytest.warns(PossibleUserWarning, match="can't place the model parameters on the device"): # noqa: SIM117
with fabric.init_module():
pass
strategy.module_init_context.assert_called_once()
def test_init_tensor_context(monkeypatch):
"""Test that `.init_tensor()` warns if using PyTorch < 2.0."""
import lightning.fabric
fabric = Fabric(accelerator="cpu")
strategy = SingleDeviceStrategy(device=torch.device("cuda"))
strategy.tensor_init_context = Mock(wraps=strategy.tensor_init_context)
fabric._strategy = strategy
with fabric.init_tensor():
pass
strategy.tensor_init_context.assert_called_once()
strategy.tensor_init_context.reset_mock()
# Pretend we are using PyTorch < 2.0
monkeypatch.setattr(lightning.fabric.fabric, "_TORCH_GREATER_EQUAL_2_0", False)
with pytest.warns(PossibleUserWarning, match="can't place tensors on the device directly"): # noqa: SIM117
with fabric.init_tensor():
pass
strategy.tensor_init_context.assert_called_once()
def test_callbacks_input():
"""Test the various ways in which callbacks can be registered with Fabric."""
callback0 = Mock()
callback1 = Mock()
# single callback
fabric = Fabric(callbacks=callback0)
assert fabric._callbacks == [callback0]
# multiple callbacks
fabric = Fabric(callbacks=[callback0, callback1])
assert fabric._callbacks == [callback0, callback1]
def test_call():
"""Test that `fabric.call` triggers the callback implementations."""
callback0 = Mock()
callback1 = Mock()
fabric = Fabric(callbacks=[callback0, callback1])
# No arguments
fabric.call("on_train_end")
callback0.on_train_end.assert_called_once()
callback1.on_train_end.assert_called_once()
# Optional arguments
fabric.call("on_train_end", "positional", keyword="keyword")
callback0.on_train_end.assert_called_with("positional", keyword="keyword")
callback1.on_train_end.assert_called_with("positional", keyword="keyword")
# Some callbacks don't implement the requested hook
callback0 = Mock()
callback1 = Mock(spec_set={}) # `on_train_end` not defined for this callback
fabric = Fabric(callbacks=[callback0, callback1])
fabric.call("on_train_end")
callback0.on_train_end.assert_called_once()
assert not callback1.mock_calls # no methods were called on callback1
# Skip callback attributes that are not callable
callback = Mock(not_a_method=1)
fabric = Fabric(callbacks=[callback])
with pytest.warns(UserWarning, match="Skipping the callback `Mock.not_a_method`"):
fabric.call("not_a_method")
assert not callback1.mock_calls
def test_special_callbacks():
"""Tests special callbacks that have hooks for internal Fabric events."""
class SpecialCallback:
def on_after_optimizer_step(self, strategy, optimizer):
pass
def on_after_setup(self, fabric, module):
pass
callback = Mock(wraps=SpecialCallback())
fabric = Fabric(accelerator="cpu", callbacks=[callback])
model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
fabric_model, fabric_optimizer = fabric.setup(model, optimizer)
callback.on_after_setup.assert_called_once_with(fabric=fabric, module=fabric_model)
model(torch.randn(2, 2)).sum().backward()
fabric_optimizer.step()
callback.on_after_optimizer_step.assert_called_once_with(strategy=fabric._strategy, optimizer=optimizer)
def test_loggers_input():
"""Test the various ways in which loggers can be registered with Fabric."""
logger0 = Mock()
logger1 = Mock()
# no logger
fabric = Fabric(loggers=None)
assert fabric._loggers == []
fabric = Fabric(loggers=[])
assert fabric._loggers == []
# single logger
fabric = Fabric(loggers=logger0)
assert fabric._loggers == [logger0]
# multiple loggers
fabric = Fabric(loggers=[logger0, logger1])
assert fabric._loggers == [logger0, logger1]
def test_log():
"""Test that `fabric.log` sends the metrics to each logger."""
logger0 = Mock()
logger1 = Mock()
fabric = Fabric(loggers=[logger0, logger1])
fabric.log("test", 1)
logger0.log_metrics.assert_called_with(metrics={"test": 1}, step=None)
logger1.log_metrics.assert_called_with(metrics={"test": 1}, step=None)
fabric.log("test", 2, step=15)
logger0.log_metrics.assert_called_with(metrics={"test": 2}, step=15)
logger1.log_metrics.assert_called_with(metrics={"test": 2}, step=15)
def test_log_dict():
"""Test that `fabric.log_dict` sends the metrics dict to each logger."""
logger0 = Mock()
logger1 = Mock()
fabric = Fabric(loggers=[logger0, logger1])
fabric.log_dict({"foo": 1, "bar": 2}, step=None)
logger0.log_metrics.assert_called_with(metrics={"foo": 1, "bar": 2}, step=None)
logger1.log_metrics.assert_called_with(metrics={"foo": 1, "bar": 2}, step=None)
fabric.log_dict({"foo": 3, "bar": 4}, step=15)
logger0.log_metrics.assert_called_with(metrics={"foo": 3, "bar": 4}, step=15)
logger1.log_metrics.assert_called_with(metrics={"foo": 3, "bar": 4}, step=15)
def test_log_dict_input_parsing():
"""Test validation of input data types and preprocessing."""
logger = Mock()
fabric = Fabric(loggers=[logger])
# Tensor scalar, 0 dims
fabric.log("log", torch.tensor(1))
logger.log_metrics.assert_called_with(metrics={"log": 1}, step=None)
fabric.log_dict({"log_dict": torch.tensor(1)})
logger.log_metrics.assert_called_with(metrics={"log_dict": 1}, step=None)
    # Tensor scalar, 1 dim
fabric.log("log", torch.tensor([2]))
logger.log_metrics.assert_called_with(metrics={"log": 2}, step=None)
fabric.log_dict({"log_dict": torch.tensor([2])})
logger.log_metrics.assert_called_with(metrics={"log_dict": 2}, step=None)
# Tensor, multiple dims
with pytest.raises(ValueError, match="it cannot be converted to a scalar."):
fabric.log("log", torch.tensor([3, 4]))
with pytest.raises(ValueError, match="it cannot be converted to a scalar."):
fabric.log_dict({"log_dict": torch.tensor([3, 4])})
@pytest.mark.parametrize("setup", [True, False])
def test_save_wrapped_objects(setup, tmp_path):
"""Test that when modules and optimizers are in the state, they get unwrapped properly."""
fabric = Fabric(devices=1)
save_checkpoint_mock = Mock()
fabric.strategy.save_checkpoint = save_checkpoint_mock
unwrapped_model = BoringModel()
unwrapped_optimizer = torch.optim.Adam(unwrapped_model.parameters())
if setup:
model, optimizer = fabric.setup(unwrapped_model, unwrapped_optimizer)
assert isinstance(model, _FabricModule)
assert isinstance(optimizer, _FabricOptimizer)
else:
model, optimizer = unwrapped_model, unwrapped_optimizer
anything = {"cocofruit": 1}
state = {"model": model, "optimizer": optimizer, "anything": anything}
expected = {"model": unwrapped_model, "optimizer": unwrapped_optimizer, "anything": anything}
fabric.save(tmp_path, state)
save_checkpoint_mock.assert_called_with(state=expected, path=tmp_path, filter=None)
def test_save_filter(tmp_path):
fabric = Fabric(devices=1)
checkpoint_io_mock = Mock()
fabric.strategy.checkpoint_io = checkpoint_io_mock
model = BoringModel()
optimizer = torch.optim.Adam(model.parameters())
anything = {"cocofruit": 1}
state = {"model": model, "optimizer": optimizer, "anything": anything, "foo": 1}
save_path = tmp_path / "foo.pth"
# filter all dicts
filter = {k: lambda k, v: False for k in state}
fabric.save(save_path, state, filter=filter)
checkpoint_io_mock.save_checkpoint.assert_called_with(checkpoint={"foo": 1}, path=save_path, storage_options=None)
# bad filters
with pytest.raises(TypeError, match="should be a dict"):
fabric.save(save_path, state, filter="foo")
with pytest.raises(TypeError, match="callable, given 'foo"):
fabric.save(save_path, state, filter={"model": "foo"})
with pytest.raises(ValueError, match="keys {'asd'} are not present in the state keys"):
fabric.save(save_path, state, filter={"asd": lambda k, v: True})
# subset
checkpoint_io_mock.reset_mock()
filter = {
"model": lambda k, v: "weight" in k,
"anything": lambda k, v: isinstance(v, int),
"optimizer": lambda k, v: "param_groups" in k,
}
fabric.save(save_path, state, filter=filter)
checkpoint_io_mock.save_checkpoint.assert_called_with(
checkpoint={"model": {"layer.weight": ANY}, "optimizer": {"param_groups": ANY}, "anything": anything, "foo": 1},
path=save_path,
storage_options=None,
)
@pytest.mark.parametrize("setup", [True, False])
def test_load_wrapped_objects(setup, tmp_path):
"""Test that loading happens in-place for model, optimizer, and other user data."""
fabric = Fabric(accelerator="cpu")
expected_remainder = {"extra": "data"}
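    # stand-in for the strategy's `load_checkpoint`: it receives the unwrapped objects,
    # updates the requested state in place, and returns whatever was not consumed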
def mocked_load_checkpoint(path, state, strict):
assert not isinstance(state["model"], _FabricModule)
assert not isinstance(state["optimizer"], _FabricOptimizer)
state.update({"int": 5, "dict": {"x": 1}})
return expected_remainder
fabric.strategy.load_checkpoint = mocked_load_checkpoint
unwrapped_model = BoringModel()
unwrapped_optimizer = torch.optim.Adam(unwrapped_model.parameters())
if setup:
model, optimizer = fabric.setup(unwrapped_model, unwrapped_optimizer)
assert isinstance(model, _FabricModule)
assert isinstance(optimizer, _FabricOptimizer)
else:
model, optimizer = unwrapped_model, unwrapped_optimizer
state = {"model": model, "optimizer": optimizer, "int": 0, "dict": {"x": 0}}
expected = {"model": model, "optimizer": optimizer, "int": 5, "dict": {"x": 1}}
remainder = fabric.load(tmp_path, state)
assert state == expected
assert remainder == expected_remainder
def test_load_raw():
"""Test that `Fabric.load_raw()` unwraps the object to load and calls into the strategy."""
fabric = Fabric(accelerator="cpu")
fabric.strategy.load_checkpoint = Mock()
model = torch.nn.Linear(2, 2)
optimizer = torch.optim.Adam(model.parameters())
wrapped_model, wrapped_optimizer = fabric.setup(model, optimizer)
fabric.load_raw(path="path0", obj=model)
fabric.strategy.load_checkpoint.assert_called_with(path="path0", state=model, strict=True)
fabric.load_raw(path="path1", obj=wrapped_model, strict=False)
fabric.strategy.load_checkpoint.assert_called_with(path="path1", state=model, strict=False)
fabric.load_raw(path="path2", obj=wrapped_optimizer)
fabric.strategy.load_checkpoint.assert_called_with(path="path2", state=optimizer, strict=True)
def test_barrier():
"""Test that `Fabric.barrier()` calls into the strategy."""
fabric = Fabric()
fabric._strategy = Mock()
fabric._launched = True
fabric.barrier("test")
fabric._strategy.barrier.assert_called_once_with(name="test")
def test_broadcast():
"""Test that `Fabric.broadcast()` calls into the strategy."""
fabric = Fabric()
fabric._strategy = Mock()
fabric._launched = True
fabric.broadcast(torch.tensor(1), src=2)
fabric._strategy.broadcast.assert_called_once_with(torch.tensor(1), src=2)
def test_all_gather():
"""Test that `Fabric.all_gather()` applies itself to collections and calls into the strategy."""
fabric = Fabric()
fabric._strategy = Mock(root_device=torch.device("cpu"))
fabric._launched = True
defaults = {"group": None, "sync_grads": False}
# single tensor
fabric.all_gather(torch.tensor(1))
fabric._strategy.all_gather.assert_called_once_with(torch.tensor(1), **defaults)
fabric._strategy.reset_mock()
# list
fabric.all_gather([torch.tensor(2), torch.tensor(3), "string"])
fabric._strategy.all_gather.assert_has_calls([call(torch.tensor(2), **defaults), call(torch.tensor(3), **defaults)])
fabric._strategy.reset_mock()
# dict
fabric.all_gather({"a": torch.tensor(4), "b": [torch.tensor(5)], "c": "string"})
fabric._strategy.all_gather.assert_has_calls([call(torch.tensor(4), **defaults), call(torch.tensor(5), **defaults)])
def test_all_reduce():
"""Test that `Fabric.all_reduce()` applies itself to collections and calls into the strategy."""
fabric = Fabric()
fabric._strategy = Mock(root_device=torch.device("cpu"))
fabric._launched = True
defaults = {"group": None, "reduce_op": "mean"}
# single tensor
fabric.all_reduce(torch.tensor(1))
fabric._strategy.all_reduce.assert_called_once_with(torch.tensor(1), **defaults)
fabric._strategy.reset_mock()
# list
fabric.all_reduce([torch.tensor(2), torch.tensor(3), "string"])
fabric._strategy.all_reduce.assert_has_calls([call(torch.tensor(2), **defaults), call(torch.tensor(3), **defaults)])
fabric._strategy.reset_mock()
# dict
fabric.all_reduce({"a": torch.tensor(4), "b": [torch.tensor(5)], "c": "string"})
fabric._strategy.all_reduce.assert_has_calls([call(torch.tensor(4), **defaults), call(torch.tensor(5), **defaults)])
def test_rank_zero_first():
"""Test that rank 0 completes first before all other processes can execute under `.rank_zero_first()`."""
def record_calls_for_rank(rank):
call_order = []
fabric = Fabric()
fabric._strategy = Mock(global_rank=rank)
fabric.barrier = Mock(side_effect=lambda *_: call_order.append("barrier"))
target = Mock(run=Mock(side_effect=lambda *_: call_order.append("run")))
with fabric.rank_zero_first():
target.run()
return call_order
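    # rank 0 runs before reaching any barrier; all other ranks wait at a barrier before running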
assert record_calls_for_rank(0) == ["run", "barrier", "barrier"]
assert record_calls_for_rank(1) == ["barrier", "run", "barrier"]
@pytest.mark.parametrize(("clip_val", "max_norm"), [(1e-3, None), (None, 1)])
def test_grad_clipping(clip_val, max_norm):
fabric = Fabric(devices=1)
fabric.strategy.clip_gradients_norm = Mock()
fabric.strategy.clip_gradients_value = Mock()
torch_model = nn.Linear(1, 1)
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1e-3)
model, optimizer = fabric.setup(torch_model, torch_optimizer)
loss = model(torch.rand(1, 1).to(fabric.device))
fabric.backward(loss)
fabric.strategy.clip_gradients_value.assert_not_called()
fabric.strategy.clip_gradients_norm.assert_not_called()
fabric.clip_gradients(model, optimizer, max_norm=max_norm, clip_val=clip_val)
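    # exactly one of `clip_val` / `max_norm` is set per parametrization, so only the
    # corresponding strategy hook should have been called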
if clip_val is not None:
fabric.strategy.clip_gradients_value.assert_called_once_with(torch_model, torch_optimizer, clip_val=clip_val)
fabric.strategy.clip_gradients_norm.assert_not_called()
else:
fabric.strategy.clip_gradients_value.assert_not_called()
fabric.strategy.clip_gradients_norm.assert_called_once_with(
torch_model, torch_optimizer, max_norm=max_norm, norm_type=2.0, error_if_nonfinite=True
)
def test_verify_launch_called():
"""Test that the user gets an error message if they forgot to call `.launch()`."""
fabric = Fabric(accelerator="cpu")
assert not fabric._launched
fabric._strategy = Mock(spec=SingleDeviceStrategy)
fabric._validate_launched()
fabric._strategy = Mock(spec=DataParallelStrategy)
fabric._validate_launched()
fabric._strategy = Mock(spec=DDPStrategy)
with pytest.raises(RuntimeError, match=r"you must call `.launch\(\)`"):
fabric._validate_launched()
# Methods
method_names = ("setup", "setup_module", "setup_dataloaders", "broadcast", "barrier", "all_reduce", "all_gather")
for method_name in method_names:
method = getattr(fabric, method_name)
with pytest.raises(RuntimeError, match=r"you must call `.launch\(\)`"):
method(Mock())
# Context managers
ctx_manager_names = ("init_module",)
for ctx_manager_name in ctx_manager_names:
ctx_manager = getattr(fabric, ctx_manager_name)
with pytest.raises(RuntimeError, match=r"you must call `.launch\(\)`"), ctx_manager():
pass # the error is raised in the context manager and caught by `pytest.raises`
fabric.launch()
assert fabric._launched
fabric._validate_launched()
@pytest.mark.skipif(sys.platform == "darwin" and not _TORCH_GREATER_EQUAL_2_1, reason="Fix for MacOS in PyTorch 2.1")
@RunIf(dynamo=True)
@pytest.mark.parametrize(
"kwargs",
[
{},
pytest.param({"precision": "16-true"}, marks=pytest.mark.xfail(raises=RuntimeError, match="Unsupported")),
pytest.param({"precision": "64-true"}, marks=pytest.mark.xfail(raises=RuntimeError, match="Unsupported")),
],
)
def test_fabric_with_torchdynamo_fullgraph(kwargs):
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.l = torch.nn.Linear(10, 10)
def forward(self, x):
# forward gets compiled
assert torch._dynamo.is_compiling()
return self.l(x)
def fn(model, x):
assert torch._dynamo.is_compiling()
a = x * 10
return model(a)
fabric = Fabric(devices=1, accelerator="cpu", **kwargs)
model = MyModel()
fmodel = fabric.setup(model)
# we are compiling a function that calls model.forward() inside
cfn = torch.compile(fn, fullgraph=True)
x = torch.randn(10, 10, device=fabric.device)
# pass the fabric wrapped model to the compiled function, so that it gets compiled too
out = cfn(fmodel, x)
assert isinstance(out, torch.Tensor)