diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index 78728e9c20..415a8f6fc3 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -53,6 +53,7 @@ from lightning.fabric.utilities.imports import ( _TORCH_GREATER_EQUAL_2_0, ) from lightning.fabric.utilities.init import _EmptyInit +from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors from lightning.fabric.utilities.optimizer import _optimizers_to_device from lightning.fabric.utilities.seed import reset_seed from lightning.fabric.utilities.types import _PATH, ProcessGroup, ReduceOp @@ -572,9 +573,13 @@ class FSDPStrategy(ParallelStrategy): return metadata if _is_full_checkpoint(path): - # TODO: Support lazy-loading here (see Fabric) - checkpoint = torch.load(path, map_location="cpu") - _load_raw_module_state(checkpoint["state_dict"], world_size=self.world_size, module=self.model) + checkpoint = _lazy_load(path) if _TORCH_GREATER_EQUAL_2_0 else torch.load(path, map_location="cpu") + _load_raw_module_state(checkpoint.pop("state_dict"), module=self.model, world_size=self.world_size) + + if _TORCH_GREATER_EQUAL_2_0: + # Materialize lazy tensors if there are any left in the checkpoint + # The `torch.Optimizer.load_state_dict` method can't load lazy tensors because of deepcopy pickle issues + checkpoint = _materialize_tensors(checkpoint) from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import OptimStateKeyType diff --git a/tests/tests_pytorch/strategies/test_custom_plugin.py b/tests/tests_pytorch/strategies/test_custom_strategy.py similarity index 100% rename from tests/tests_pytorch/strategies/test_custom_plugin.py rename to tests/tests_pytorch/strategies/test_custom_strategy.py diff --git a/tests/tests_pytorch/strategies/test_ddp.py b/tests/tests_pytorch/strategies/test_ddp.py index 1a06a622d4..42ed143cd3 100644 --- a/tests/tests_pytorch/strategies/test_ddp.py +++ b/tests/tests_pytorch/strategies/test_ddp.py @@ -12,117 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +from datetime import timedelta from unittest import mock import pytest import torch -from torch.nn.parallel.distributed import DistributedDataParallel +from torch.nn.parallel import DistributedDataParallel -import lightning.pytorch as pl from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0 -from lightning.pytorch import seed_everything, Trainer -from lightning.pytorch.callbacks import Callback +from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.plugins import DoublePrecisionPlugin, HalfPrecisionPlugin, PrecisionPlugin from lightning.pytorch.strategies import DDPStrategy -from tests_pytorch.helpers.datamodules import ClassifDataModule +from lightning.pytorch.trainer.states import TrainerFn from tests_pytorch.helpers.runif import RunIf -from tests_pytorch.helpers.simple_models import ClassificationModel - - -@RunIf(min_cuda_gpus=2, standalone=True, sklearn=True) -def test_multi_gpu_model_ddp_fit_only(tmpdir): - dm = ClassifDataModule() - model = ClassificationModel() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, accelerator="gpu", devices=2, strategy="ddp") - trainer.fit(model, datamodule=dm) - - -@RunIf(min_cuda_gpus=2, standalone=True, sklearn=True) -def test_multi_gpu_model_ddp_test_only(tmpdir): - dm = ClassifDataModule() - model = ClassificationModel() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, accelerator="gpu", devices=2, strategy="ddp") - trainer.test(model, datamodule=dm) - - -@RunIf(min_cuda_gpus=2, standalone=True, sklearn=True) -def test_multi_gpu_model_ddp_fit_test(tmpdir): - seed_everything(4321) - dm = ClassifDataModule() - model = ClassificationModel() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, accelerator="gpu", devices=2, strategy="ddp") - trainer.fit(model, datamodule=dm) - result = trainer.test(model, datamodule=dm) - - for out in result: - assert out["test_acc"] > 0.7 - - -@RunIf(skip_windows=True) -@mock.patch("torch.cuda.set_device") -@mock.patch("lightning.pytorch.accelerators.cuda._check_cuda_matmul_precision") -@mock.patch("lightning.pytorch.accelerators.cuda._clear_cuda_memory") -def test_ddp_torch_dist_is_available_in_setup(_, __, ___, cuda_count_1, mps_count_0, tmpdir): - """Test to ensure torch distributed is available within the setup hook using ddp.""" - - class TestModel(BoringModel): - def setup(self, stage: str) -> None: - assert torch.distributed.is_initialized() - raise SystemExit() - - model = TestModel() - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=True, - strategy=DDPStrategy(process_group_backend="gloo"), - accelerator="gpu", - devices=1, - ) - with pytest.raises(SystemExit): - trainer.fit(model) - - -@RunIf(min_cuda_gpus=2, standalone=True) -@pytest.mark.parametrize("precision", ["16-mixed", "32-true"]) -def test_ddp_wrapper(tmpdir, precision): - """Test parameters to ignore are carried over for DDP.""" - - class WeirdModule(torch.nn.Module): - def _save_to_state_dict(self, destination, prefix, keep_vars): - return {"something": "something"} - - class CustomModel(BoringModel): - def __init__(self): - super().__init__() - self.weird_module = WeirdModule() - - # should be skip. 
- self._ddp_params_and_buffers_to_ignore = ["something"] - - class CustomCallback(Callback): - def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - assert isinstance(trainer.strategy.model, DistributedDataParallel) - expected = ["something"] - assert ( - trainer.strategy.model.parameters_to_ignore == set(expected) if _TORCH_GREATER_EQUAL_2_0 else expected - ) - assert trainer.strategy.model.module._ddp_params_and_buffers_to_ignore == expected - - model = CustomModel() - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=True, - precision=precision, - strategy="ddp", - accelerator="gpu", - devices=2, - callbacks=CustomCallback(), - enable_progress_bar=False, - enable_model_summary=False, - ) - trainer.fit(model) @pytest.mark.parametrize( @@ -153,6 +57,27 @@ def test_ddp_process_group_backend(process_group_backend, device_str, expected_p @pytest.mark.parametrize( ("strategy_name", "expected_ddp_kwargs"), [ + ("ddp_spawn", {}), + pytest.param("ddp_fork", {}, marks=RunIf(skip_windows=True)), + pytest.param("ddp_notebook", {}, marks=RunIf(skip_windows=True)), + ("ddp_spawn_find_unused_parameters_false", {"find_unused_parameters": False}), + ("ddp_spawn_find_unused_parameters_true", {"find_unused_parameters": True}), + pytest.param( + "ddp_fork_find_unused_parameters_false", {"find_unused_parameters": False}, marks=RunIf(skip_windows=True) + ), + pytest.param( + "ddp_fork_find_unused_parameters_true", {"find_unused_parameters": True}, marks=RunIf(skip_windows=True) + ), + pytest.param( + "ddp_notebook_find_unused_parameters_false", + {"find_unused_parameters": False}, + marks=RunIf(skip_windows=True), + ), + pytest.param( + "ddp_notebook_find_unused_parameters_true", + {"find_unused_parameters": True}, + marks=RunIf(skip_windows=True), + ), ("ddp", {}), ("ddp_find_unused_parameters_false", {"find_unused_parameters": False}), ("ddp_find_unused_parameters_true", {"find_unused_parameters": True}), @@ -187,3 +112,78 @@ def test_tensor_init_context(precision_plugin, expected_dtype): module = torch.nn.Linear(2, 2) assert module.weight.device == module.bias.device == expected_device assert module.weight.dtype == module.bias.dtype == expected_dtype + + +@mock.patch("torch.distributed.init_process_group") +def test_set_timeout(mock_init_process_group): + """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function.""" + test_timedelta = timedelta(seconds=30) + model = BoringModel() + ddp_strategy = DDPStrategy(timeout=test_timedelta) + trainer = Trainer( + max_epochs=1, + accelerator="cpu", + strategy=ddp_strategy, + ) + # test wrap the model if fitting + trainer.strategy.connect(model) + trainer.lightning_module.trainer = trainer + trainer.strategy.setup_environment() + + process_group_backend = trainer.strategy._get_process_group_backend() + global_rank = trainer.strategy.cluster_environment.global_rank() + world_size = trainer.strategy.cluster_environment.world_size() + mock_init_process_group.assert_called_with( + process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta + ) + + +@RunIf(skip_windows=True) +def test_ddp_configure_ddp(mps_count_0): + """Tests with ddp strategy.""" + model = BoringModel() + ddp_strategy = DDPStrategy() + trainer = Trainer( + max_epochs=1, + strategy=ddp_strategy, + ) + # test wrap the model if fitting + trainer.state.fn = TrainerFn.FITTING + trainer.strategy.connect(model) + trainer.lightning_module.trainer = trainer + trainer.strategy.setup_environment() + 
assert isinstance(trainer.model, LightningModule) + trainer.strategy.setup(trainer) + # in DDPStrategy configure_ddp(), model wrapped by DistributedDataParallel + assert isinstance(trainer.model, DistributedDataParallel) + + ddp_strategy = DDPStrategy() + trainer = Trainer( + max_epochs=1, + strategy=ddp_strategy, + ) + # test do not wrap the model if TrainerFn is not fitting + trainer.state.fn = TrainerFn.VALIDATING + trainer.strategy.connect(model) + trainer.lightning_module.trainer = trainer + trainer.strategy.setup_environment() + trainer.strategy.setup(trainer) + # in DDPStrategy configure_ddp(), model are still LightningModule + assert isinstance(trainer.model, LightningModule) + + +@RunIf(min_cuda_gpus=1) +@pytest.mark.parametrize("trainer_fn", [TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING]) +def test_ddp_dont_configure_sync_batchnorm(trainer_fn): + model = BoringModel() + model.layer = torch.nn.BatchNorm1d(10) + ddp_strategy = DDPStrategy() + trainer = Trainer(accelerator="gpu", devices=1, strategy=ddp_strategy, sync_batchnorm=True) + trainer.state.fn = trainer_fn + trainer.strategy.connect(model) + trainer.lightning_module.trainer = trainer + trainer.strategy.setup_environment() + assert isinstance(trainer.model, LightningModule) + trainer.strategy.setup(trainer) + # because TrainerFn is not FITTING, model is not configured with sync batchnorm + assert not isinstance(trainer.strategy.model.layer, torch.nn.modules.batchnorm.SyncBatchNorm) diff --git a/tests/tests_pytorch/strategies/test_ddp_integration.py b/tests/tests_pytorch/strategies/test_ddp_integration.py new file mode 100644 index 0000000000..985b9500b5 --- /dev/null +++ b/tests/tests_pytorch/strategies/test_ddp_integration.py @@ -0,0 +1,446 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from unittest import mock +from unittest.mock import Mock + +import pytest +import torch +from torch.distributed.optim import ZeroRedundancyOptimizer +from torch.multiprocessing import ProcessRaisedException +from torch.nn.parallel.distributed import DistributedDataParallel + +import lightning.pytorch as pl +import tests_pytorch.helpers.pipelines as tpipes +from lightning.fabric.plugins.environments import ClusterEnvironment, LightningEnvironment +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0 +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import Callback, EarlyStopping +from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel +from lightning.pytorch.strategies import DDPStrategy +from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher +from lightning.pytorch.strategies.launchers.multiprocessing import _MultiProcessingLauncher +from lightning.pytorch.trainer import seed_everything +from tests_pytorch.helpers.datamodules import ClassifDataModule +from tests_pytorch.helpers.runif import RunIf +from tests_pytorch.helpers.simple_models import ClassificationModel + + +@RunIf(min_cuda_gpus=2, standalone=True, sklearn=True) +def test_multi_gpu_model_ddp_fit_only(tmpdir): + dm = ClassifDataModule() + model = ClassificationModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, accelerator="gpu", devices=2, strategy="ddp") + trainer.fit(model, datamodule=dm) + + +@RunIf(min_cuda_gpus=2, standalone=True, sklearn=True) +def test_multi_gpu_model_ddp_test_only(tmpdir): + dm = ClassifDataModule() + model = ClassificationModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, accelerator="gpu", devices=2, strategy="ddp") + trainer.test(model, datamodule=dm) + + +@RunIf(min_cuda_gpus=2, standalone=True, sklearn=True) +def test_multi_gpu_model_ddp_fit_test(tmpdir): + seed_everything(4321) + dm = ClassifDataModule() + model = ClassificationModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, accelerator="gpu", devices=2, strategy="ddp") + trainer.fit(model, datamodule=dm) + result = trainer.test(model, datamodule=dm) + + for out in result: + assert out["test_acc"] > 0.7 + + +@RunIf(skip_windows=True) +@mock.patch("torch.cuda.set_device") +@mock.patch("lightning.pytorch.accelerators.cuda._check_cuda_matmul_precision") +@mock.patch("lightning.pytorch.accelerators.cuda._clear_cuda_memory") +def test_ddp_torch_dist_is_available_in_setup(_, __, ___, cuda_count_1, mps_count_0, tmpdir): + """Test to ensure torch distributed is available within the setup hook using ddp.""" + + class TestModel(BoringModel): + def setup(self, stage: str) -> None: + assert torch.distributed.is_initialized() + raise SystemExit() + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + strategy=DDPStrategy(process_group_backend="gloo"), + accelerator="gpu", + devices=1, + ) + with pytest.raises(SystemExit): + trainer.fit(model) + + +@RunIf(min_cuda_gpus=2, standalone=True) +@pytest.mark.parametrize("precision", ["16-mixed", "32-true"]) +def test_ddp_wrapper(tmpdir, precision): + """Test parameters to ignore are carried over for DDP.""" + + class WeirdModule(torch.nn.Module): + def _save_to_state_dict(self, destination, prefix, keep_vars): + return {"something": "something"} + + class CustomModel(BoringModel): + def __init__(self): + super().__init__() + self.weird_module = WeirdModule() + + # should be skipped + self._ddp_params_and_buffers_to_ignore = 
["something"] + + class CustomCallback(Callback): + def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + assert isinstance(trainer.strategy.model, DistributedDataParallel) + expected = ["something"] + assert ( + trainer.strategy.model.parameters_to_ignore == set(expected) if _TORCH_GREATER_EQUAL_2_0 else expected + ) + assert trainer.strategy.model.module._ddp_params_and_buffers_to_ignore == expected + + model = CustomModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + precision=precision, + strategy="ddp", + accelerator="gpu", + devices=2, + callbacks=CustomCallback(), + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.fit(model) + + +@RunIf(min_cuda_gpus=2, sklearn=True) +def test_multi_gpu_early_stop_ddp_spawn(tmpdir): + seed_everything(42) + + trainer_options = { + "default_root_dir": tmpdir, + "callbacks": [EarlyStopping(monitor="train_acc")], + "max_epochs": 50, + "limit_train_batches": 10, + "limit_val_batches": 10, + "accelerator": "gpu", + "devices": [0, 1], + "strategy": "ddp_spawn", + } + + dm = ClassifDataModule() + model = ClassificationModel() + tpipes.run_model_test(trainer_options, model, dm) + + +@RunIf(min_cuda_gpus=2) +def test_multi_gpu_model_ddp_spawn(tmpdir): + seed_everything(42) + + trainer_options = { + "default_root_dir": tmpdir, + "max_epochs": 1, + "limit_train_batches": 10, + "limit_val_batches": 10, + "accelerator": "gpu", + "devices": [0, 1], + "strategy": "ddp_spawn", + "enable_progress_bar": False, + } + + model = BoringModel() + + tpipes.run_model_test(trainer_options, model) + + +@RunIf(min_cuda_gpus=2) +def test_ddp_all_dataloaders_passed_to_fit(tmpdir): + """Make sure DDP works with dataloaders passed to fit()""" + model = BoringModel() + + trainer = Trainer( + default_root_dir=tmpdir, + enable_progress_bar=False, + max_epochs=1, + limit_train_batches=0.2, + limit_val_batches=0.2, + accelerator="gpu", + devices=[0, 1], + strategy="ddp_spawn", + ) + trainer.fit(model, train_dataloaders=model.train_dataloader(), val_dataloaders=model.val_dataloader()) + assert trainer.state.finished, "DDP doesn't work with dataloaders passed to fit()." 
+ + +class UnusedParametersModel(BoringModel): + def __init__(self): + super().__init__() + self.intermediate_layer = torch.nn.Linear(32, 32) + + def training_step(self, batch, batch_idx): + with torch.no_grad(): + batch = self.intermediate_layer(batch) + return super().training_step(batch, batch_idx) + + +def test_find_unused_parameters_exception(): + """Test that the DDP strategy can change PyTorch's error message so that it's more useful for Lightning users.""" + trainer = Trainer(accelerator="cpu", devices=1, strategy="ddp_spawn", max_steps=2) + with pytest.raises( + ProcessRaisedException, match="It looks like your LightningModule has parameters that were not used in" + ): + trainer.fit(UnusedParametersModel()) + + trainer = Trainer(accelerator="cpu", devices=1, strategy="ddp", max_steps=2) + with pytest.raises(RuntimeError, match="It looks like your LightningModule has parameters that were not used in"): + trainer.fit(UnusedParametersModel()) + + +class BoringCallbackDDPSpawnModel(BoringModel): + def __init__(self, name: str, val: float): + super().__init__() + self.name = name + self.val = val + + def validation_step(self, batch, batch_idx): + self.log(self.name, self.val) + return super().validation_step(batch, batch_idx) + + +class CustomMultiProcessingLauncher(_MultiProcessingLauncher): + def get_extra_results(self, trainer): + extra = super().get_extra_results(trainer) + extra["test_val"] = "test_val" + return extra + + def update_main_process_results(self, trainer, extra) -> None: + trainer.strategy.test_val = extra.pop("test_val") + return super().update_main_process_results(trainer, extra) + + +class TestDDPSpawnStrategy(DDPStrategy): + def _configure_launcher(self): + self._launcher = CustomMultiProcessingLauncher(self) + + +@RunIf(skip_windows=True) +def test_ddp_spawn_add_get_queue(tmpdir): + """Tests get_extra_results/update_main_process_results with DDPSpawnStrategy.""" + ddp_spawn_strategy = TestDDPSpawnStrategy() + trainer = Trainer( + default_root_dir=tmpdir, fast_dev_run=True, accelerator="cpu", devices=2, strategy=ddp_spawn_strategy + ) + + val: float = 1.0 + val_name: str = "val_acc" + model = BoringCallbackDDPSpawnModel(val_name, val) + dm = BoringDataModule() + trainer.fit(model, datamodule=dm) + assert trainer.callback_metrics[val_name] == torch.tensor(val) + assert ddp_spawn_strategy.test_val == "test_val" + + +class BoringModelDDPCPU(BoringModel): + def on_train_start(self) -> None: + # make sure that the model is on CPU when training + assert self.device == torch.device("cpu") + + +@RunIf(skip_windows=True) +def test_ddp_cpu(): + """Tests if device is set correctly when training for DDPStrategy.""" + trainer = Trainer(devices=2, strategy="ddp_spawn", accelerator="cpu", fast_dev_run=True) + # assert strategy attributes for device setting + assert isinstance(trainer.strategy, DDPStrategy) + assert trainer.strategy.root_device == torch.device("cpu") + model = BoringModelDDPCPU() + trainer.fit(model) + + +class BoringZeroRedundancyOptimizerModel(BoringModel): + def configure_optimizers(self): + return ZeroRedundancyOptimizer(self.layer.parameters(), optimizer_class=torch.optim.Adam, lr=0.1) + + +@RunIf(min_cuda_gpus=2, skip_windows=True) +@pytest.mark.parametrize("strategy", [pytest.param("ddp", marks=RunIf(standalone=True)), "ddp_spawn"]) +def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(tmpdir, strategy): + """Test to ensure that checkpoint is saved correctly when using zero redundancy optimizer.""" + model = BoringZeroRedundancyOptimizerModel() + 
trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_steps=1) + + trainer.fit(model) + + checkpoint_path = os.path.join(tmpdir, "model.pt") + # need to broadcast because tmpdir is different on each process + checkpoint_path = trainer.strategy.broadcast(checkpoint_path) + trainer.save_checkpoint(checkpoint_path) + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + + # Assert model parameters are identical after loading + for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()): + assert torch.equal(trained_param.to("cpu"), loaded_param) + + +def test_configure_launcher_create_processes_externally(): + class MyClusterEnvironment(ClusterEnvironment): + @property + def creates_processes_externally(self): + return True + + @property + def main_address(self): + return "" + + @property + def main_port(self): + return 8080 + + @staticmethod + def detect(): + return True + + def world_size(self): + return 1 + + def set_world_size(self): + pass + + def global_rank(self): + return 0 + + def set_global_rank(self): + pass + + def local_rank(self): + return 0 + + def node_rank(self): + return 0 + + ddp_strategy = DDPStrategy(cluster_environment=MyClusterEnvironment()) + assert ddp_strategy.launcher is None + ddp_strategy._configure_launcher() + assert isinstance(ddp_strategy.launcher, _SubprocessScriptLauncher) + + ddp_strategy.launcher._call_children_scripts = Mock() + launch_fn = Mock() + ddp_strategy.launcher.launch(launch_fn) + ddp_strategy.launcher._call_children_scripts.assert_not_called() + launch_fn.assert_called_once() + + +class CheckOptimizerDeviceModel(BoringModel): + def configure_optimizers(self): + assert all(param.device.type == "cuda" for param in self.parameters()) + super().configure_optimizers() + + +@RunIf(min_cuda_gpus=1) +@pytest.mark.parametrize("strategy", ["ddp", "ddp_spawn"]) +def test_model_parameters_on_device_for_optimizer(strategy): + """Test that the strategy has moved the parameters to the device by the time the optimizer gets created.""" + model = CheckOptimizerDeviceModel() + trainer = Trainer( + default_root_dir=os.getcwd(), + fast_dev_run=1, + accelerator="gpu", + devices=1, + strategy=strategy, + ) + trainer.fit(model) + + +class BoringModelGPU(BoringModel): + def on_train_start(self) -> None: + # make sure that the model is on GPU when training + assert self.device == torch.device(f"cuda:{self.trainer.strategy.local_rank}") + self.start_cuda_memory = torch.cuda.memory_allocated() + + +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) +def test_ddp_with_2_gpus(): + """Tests if device is set correctly when training and after teardown for DDPStrategy.""" + trainer = Trainer( + accelerator="gpu", + devices=2, + strategy="ddp", + fast_dev_run=True, + enable_progress_bar=False, + enable_model_summary=False, + ) + # assert strategy attributes for device setting + assert isinstance(trainer.strategy, DDPStrategy) + local_rank = trainer.strategy.local_rank + assert trainer.strategy.root_device == torch.device(f"cuda:{local_rank}") + + model = BoringModelGPU() + + trainer.fit(model) + + # assert after training, model is moved to CPU and memory is deallocated + assert model.device == torch.device("cpu") + cuda_memory = torch.cuda.memory_allocated() + assert cuda_memory < model.start_cuda_memory + + +@RunIf(min_cuda_gpus=4, standalone=True) +@mock.patch("torch.distributed.barrier") +def test_ddp_barrier_non_consecutive_device_ids(barrier_mock, tmpdir): + """Test correct usage of barriers when device ids do 
not start at 0 or are not consecutive.""" + model = BoringModel() + gpus = [1, 3] + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + accelerator="gpu", + devices=gpus, + strategy="ddp", + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.fit(model) + barrier_mock.assert_any_call(device_ids=[gpus[trainer.local_rank]]) + + +@mock.patch.dict(os.environ, {"LOCAL_RANK": "1"}) +def test_incorrect_ddp_script_spawning(tmpdir): + """Test an error message when user accidentally instructs Lightning to spawn children processes on rank > 0.""" + + class WronglyImplementedEnvironment(LightningEnvironment): + @property + def creates_processes_externally(self): + # returning false no matter what means Lightning would spawn also on ranks > 0 new processes + return False + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + strategy="ddp", + accelerator="cpu", + devices=2, + plugins=[WronglyImplementedEnvironment()], + ) + with pytest.raises( + RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`." + ): + trainer.fit(model) diff --git a/tests/tests_pytorch/strategies/test_ddp_strategy_with_comm_hook.py b/tests/tests_pytorch/strategies/test_ddp_integration_comm_hook.py similarity index 100% rename from tests/tests_pytorch/strategies/test_ddp_strategy_with_comm_hook.py rename to tests/tests_pytorch/strategies/test_ddp_integration_comm_hook.py diff --git a/tests/tests_pytorch/strategies/test_ddp_spawn.py b/tests/tests_pytorch/strategies/test_ddp_spawn.py deleted file mode 100644 index 74f562f3c1..0000000000 --- a/tests/tests_pytorch/strategies/test_ddp_spawn.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright The Lightning AI team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import pytest -from torch.multiprocessing import ProcessRaisedException - -import tests_pytorch.helpers.pipelines as tpipes -from lightning.pytorch.callbacks import EarlyStopping -from lightning.pytorch.demos.boring_classes import BoringModel -from lightning.pytorch.trainer import seed_everything, Trainer -from tests_pytorch.helpers.datamodules import ClassifDataModule -from tests_pytorch.helpers.runif import RunIf -from tests_pytorch.helpers.simple_models import ClassificationModel -from tests_pytorch.strategies.test_ddp_strategy import UnusedParametersModel - - -@RunIf(min_cuda_gpus=2, sklearn=True) -def test_multi_gpu_early_stop_ddp_spawn(tmpdir): - seed_everything(42) - - trainer_options = { - "default_root_dir": tmpdir, - "callbacks": [EarlyStopping(monitor="train_acc")], - "max_epochs": 50, - "limit_train_batches": 10, - "limit_val_batches": 10, - "accelerator": "gpu", - "devices": [0, 1], - "strategy": "ddp_spawn", - } - - dm = ClassifDataModule() - model = ClassificationModel() - tpipes.run_model_test(trainer_options, model, dm) - - -@RunIf(min_cuda_gpus=2) -def test_multi_gpu_model_ddp_spawn(tmpdir): - seed_everything(42) - - trainer_options = { - "default_root_dir": tmpdir, - "max_epochs": 1, - "limit_train_batches": 10, - "limit_val_batches": 10, - "accelerator": "gpu", - "devices": [0, 1], - "strategy": "ddp_spawn", - "enable_progress_bar": False, - } - - model = BoringModel() - - tpipes.run_model_test(trainer_options, model) - - -@RunIf(min_cuda_gpus=2) -def test_ddp_all_dataloaders_passed_to_fit(tmpdir): - """Make sure DDP works with dataloaders passed to fit()""" - model = BoringModel() - - trainer = Trainer( - default_root_dir=tmpdir, - enable_progress_bar=False, - max_epochs=1, - limit_train_batches=0.2, - limit_val_batches=0.2, - accelerator="gpu", - devices=[0, 1], - strategy="ddp_spawn", - ) - trainer.fit(model, train_dataloaders=model.train_dataloader(), val_dataloaders=model.val_dataloader()) - assert trainer.state.finished, "DDP doesn't work with dataloaders passed to fit()." - - -def test_ddp_spawn_find_unused_parameters_exception(): - """Test that the DDP strategy can change PyTorch's error message so that it's more useful for Lightning users.""" - trainer = Trainer(accelerator="cpu", devices=1, strategy="ddp_spawn", max_steps=2) - with pytest.raises( - ProcessRaisedException, match="It looks like your LightningModule has parameters that were not used in" - ): - trainer.fit(UnusedParametersModel()) diff --git a/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py b/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py deleted file mode 100644 index cf8a588638..0000000000 --- a/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright The Lightning AI team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from datetime import timedelta -from unittest import mock - -import pytest -import torch -from torch.nn.parallel.distributed import DistributedDataParallel - -from lightning.pytorch import LightningModule, Trainer -from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel -from lightning.pytorch.strategies import DDPStrategy -from lightning.pytorch.strategies.launchers.multiprocessing import _MultiProcessingLauncher -from lightning.pytorch.trainer.states import TrainerFn -from tests_pytorch.helpers.runif import RunIf - - -class BoringModelDDPCPU(BoringModel): - def on_train_start(self) -> None: - # make sure that the model is on CPU when training - assert self.device == torch.device("cpu") - - -class BoringCallbackDDPSpawnModel(BoringModel): - def __init__(self, name: str, val: float): - super().__init__() - self.name = name - self.val = val - - def validation_step(self, batch, batch_idx): - self.log(self.name, self.val) - return super().validation_step(batch, batch_idx) - - -@RunIf(skip_windows=True) -def test_ddp_cpu(): - """Tests if device is set correctly when training for DDPStrategy.""" - trainer = Trainer(devices=2, strategy="ddp_spawn", accelerator="cpu", fast_dev_run=True) - # assert strategy attributes for device setting - assert isinstance(trainer.strategy, DDPStrategy) - assert trainer.strategy.root_device == torch.device("cpu") - model = BoringModelDDPCPU() - trainer.fit(model) - - -class CustomMultiProcessingLauncher(_MultiProcessingLauncher): - def get_extra_results(self, trainer): - extra = super().get_extra_results(trainer) - extra["test_val"] = "test_val" - return extra - - def update_main_process_results(self, trainer, extra) -> None: - trainer.strategy.test_val = extra.pop("test_val") - return super().update_main_process_results(trainer, extra) - - -class TestDDPSpawnStrategy(DDPStrategy): - def _configure_launcher(self): - self._launcher = CustomMultiProcessingLauncher(self) - - -@RunIf(skip_windows=True) -def test_ddp_spawn_add_get_queue(tmpdir): - """Tests get_extra_results/update_main_process_results with DDPSpawnStrategy.""" - ddp_spawn_strategy = TestDDPSpawnStrategy() - trainer = Trainer( - default_root_dir=tmpdir, fast_dev_run=True, accelerator="cpu", devices=2, strategy=ddp_spawn_strategy - ) - - val: float = 1.0 - val_name: str = "val_acc" - model = BoringCallbackDDPSpawnModel(val_name, val) - dm = BoringDataModule() - trainer.fit(model, datamodule=dm) - assert trainer.callback_metrics[val_name] == torch.tensor(val) - assert ddp_spawn_strategy.test_val == "test_val" - - -class BoringModelDDP(BoringModel): - def on_train_start(self) -> None: - """Check if trainer module is wrapped as DistributedDataParallel during training stage.""" - assert isinstance(self.trainer.model, DistributedDataParallel) - - def on_validation_start(self) -> None: - """Check if trainer module remains as LightningModule during test stage.""" - if self.trainer.state.fn == TrainerFn.FITTING: - assert isinstance(self.trainer.model, DistributedDataParallel) - else: - assert isinstance(self.trainer.model, LightningModule) - - def on_test_start(self) -> None: - """Check if trainer module remains as LightningModule during test stage.""" - assert isinstance(self.trainer.model, LightningModule) - - def on_predict_start(self) -> None: - """Check if trainer module remains as LightningModule during prediction stage.""" - assert isinstance(self.trainer.model, LightningModule) - - -@RunIf(skip_windows=True) -def test_ddp_spawn_configure_ddp(tmpdir): - """Tests with ddp spawn 
strategy.""" - trainer = Trainer(default_root_dir=tmpdir, accelerator="cpu", devices=2, strategy="ddp_spawn", fast_dev_run=True) - - model = BoringModelDDP() - - trainer.fit(model) - trainer.validate(model, dataloaders=model.val_dataloader()) - trainer.test(model, dataloaders=model.test_dataloader()) - trainer.predict(model, dataloaders=model.predict_dataloader()) - - -@mock.patch("torch.distributed.init_process_group") -def test_ddp_spawn_strategy_set_timeout(mock_init_process_group): - """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function.""" - test_timedelta = timedelta(seconds=30) - model = BoringModel() - ddp_spawn_strategy = DDPStrategy(start_method="spawn", timeout=test_timedelta) - trainer = Trainer( - max_epochs=1, - accelerator="cpu", - strategy=ddp_spawn_strategy, - ) - # test wrap the model if fitting - trainer.state.fn = TrainerFn.FITTING - trainer.strategy.connect(model) - trainer.lightning_module.trainer = trainer - trainer.strategy.setup_environment() - - process_group_backend = trainer.strategy._get_process_group_backend() - global_rank = trainer.strategy.cluster_environment.global_rank() - world_size = trainer.strategy.cluster_environment.world_size() - mock_init_process_group.assert_called_with( - process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta - ) - - -@pytest.mark.parametrize( - ("strategy_name", "expected_ddp_kwargs"), - [ - ("ddp_spawn", {}), - pytest.param("ddp_fork", {}, marks=RunIf(skip_windows=True)), - pytest.param("ddp_notebook", {}, marks=RunIf(skip_windows=True)), - ("ddp_spawn_find_unused_parameters_false", {"find_unused_parameters": False}), - ("ddp_spawn_find_unused_parameters_true", {"find_unused_parameters": True}), - pytest.param( - "ddp_fork_find_unused_parameters_false", {"find_unused_parameters": False}, marks=RunIf(skip_windows=True) - ), - pytest.param( - "ddp_fork_find_unused_parameters_true", {"find_unused_parameters": True}, marks=RunIf(skip_windows=True) - ), - pytest.param( - "ddp_notebook_find_unused_parameters_false", - {"find_unused_parameters": False}, - marks=RunIf(skip_windows=True), - ), - pytest.param( - "ddp_notebook_find_unused_parameters_true", - {"find_unused_parameters": True}, - marks=RunIf(skip_windows=True), - ), - ], -) -def test_ddp_kwargs_from_registry(strategy_name, expected_ddp_kwargs, mps_count_0): - trainer = Trainer(strategy=strategy_name) - assert trainer.strategy._ddp_kwargs == expected_ddp_kwargs diff --git a/tests/tests_pytorch/strategies/test_ddp_strategy.py b/tests/tests_pytorch/strategies/test_ddp_strategy.py deleted file mode 100644 index aeabb378d6..0000000000 --- a/tests/tests_pytorch/strategies/test_ddp_strategy.py +++ /dev/null @@ -1,303 +0,0 @@ -# Copyright The Lightning AI team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import os -from datetime import timedelta -from unittest import mock -from unittest.mock import Mock - -import pytest -import torch -from torch.distributed.optim import ZeroRedundancyOptimizer -from torch.nn.parallel import DistributedDataParallel - -from lightning.fabric.plugins.environments import ClusterEnvironment, LightningEnvironment -from lightning.pytorch import LightningModule, Trainer -from lightning.pytorch.demos.boring_classes import BoringModel -from lightning.pytorch.strategies import DDPStrategy -from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher -from lightning.pytorch.trainer.states import TrainerFn -from tests_pytorch.helpers.runif import RunIf - - -class BoringModelGPU(BoringModel): - def on_train_start(self) -> None: - # make sure that the model is on GPU when training - assert self.device == torch.device(f"cuda:{self.trainer.strategy.local_rank}") - self.start_cuda_memory = torch.cuda.memory_allocated() - - -@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) -def test_ddp_with_2_gpus(): - """Tests if device is set correctly when training and after teardown for DDPStrategy.""" - trainer = Trainer( - accelerator="gpu", - devices=2, - strategy="ddp", - fast_dev_run=True, - enable_progress_bar=False, - enable_model_summary=False, - ) - # assert strategy attributes for device setting - assert isinstance(trainer.strategy, DDPStrategy) - local_rank = trainer.strategy.local_rank - assert trainer.strategy.root_device == torch.device(f"cuda:{local_rank}") - - model = BoringModelGPU() - - trainer.fit(model) - - # assert after training, model is moved to CPU and memory is deallocated - assert model.device == torch.device("cpu") - cuda_memory = torch.cuda.memory_allocated() - assert cuda_memory < model.start_cuda_memory - - -class BarrierModel(BoringModel): - def setup(self, stage=None): - assert not isinstance(self.trainer.strategy.model, DistributedDataParallel) - self.trainer.strategy.barrier("barrier before model is wrapped") - - def on_train_start(self): - assert isinstance(self.trainer.strategy.model, DistributedDataParallel) - self.trainer.strategy.barrier("barrier after model is wrapped") - - -@RunIf(min_cuda_gpus=4, standalone=True) -@mock.patch("torch.distributed.barrier") -def test_ddp_barrier_non_consecutive_device_ids(barrier_mock, tmpdir): - """Test correct usage of barriers when device ids do not start at 0 or are not consecutive.""" - model = BoringModel() - gpus = [1, 3] - trainer = Trainer( - default_root_dir=tmpdir, - max_steps=1, - accelerator="gpu", - devices=gpus, - strategy="ddp", - enable_progress_bar=False, - enable_model_summary=False, - ) - trainer.fit(model) - barrier_mock.assert_any_call(device_ids=[gpus[trainer.local_rank]]) - - -@mock.patch.dict(os.environ, {"LOCAL_RANK": "1"}) -def test_incorrect_ddp_script_spawning(tmpdir): - """Test an error message when user accidentally instructs Lightning to spawn children processes on rank > 0.""" - - class WronglyImplementedEnvironment(LightningEnvironment): - @property - def creates_processes_externally(self): - # returning false no matter what means Lightning would spawn also on ranks > 0 new processes - return False - - model = BoringModel() - trainer = Trainer( - default_root_dir=tmpdir, - strategy="ddp", - accelerator="cpu", - devices=2, - plugins=[WronglyImplementedEnvironment()], - ) - with pytest.raises( - RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`." 
- ): - trainer.fit(model) - - -@RunIf(skip_windows=True) -def test_ddp_configure_ddp(mps_count_0): - """Tests with ddp strategy.""" - model = BoringModel() - ddp_strategy = DDPStrategy() - trainer = Trainer( - max_epochs=1, - strategy=ddp_strategy, - ) - # test wrap the model if fitting - trainer.state.fn = TrainerFn.FITTING - trainer.strategy.connect(model) - trainer.lightning_module.trainer = trainer - trainer.strategy.setup_environment() - assert isinstance(trainer.model, LightningModule) - trainer.strategy.setup(trainer) - # in DDPStrategy configure_ddp(), model wrapped by DistributedDataParallel - assert isinstance(trainer.model, DistributedDataParallel) - - ddp_strategy = DDPStrategy() - trainer = Trainer( - max_epochs=1, - strategy=ddp_strategy, - ) - # test do not wrap the model if TrainerFn is not fitting - trainer.state.fn = TrainerFn.VALIDATING - trainer.strategy.connect(model) - trainer.lightning_module.trainer = trainer - trainer.strategy.setup_environment() - trainer.strategy.setup(trainer) - # in DDPStrategy configure_ddp(), model are still LightningModule - assert isinstance(trainer.model, LightningModule) - - -@RunIf(min_cuda_gpus=1) -@pytest.mark.parametrize("trainer_fn", [TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING]) -def test_ddp_dont_configure_sync_batchnorm(trainer_fn): - model = BoringModelGPU() - model.layer = torch.nn.BatchNorm1d(10) - ddp_strategy = DDPStrategy() - trainer = Trainer(accelerator="gpu", devices=1, strategy=ddp_strategy, sync_batchnorm=True) - trainer.state.fn = trainer_fn - trainer.strategy.connect(model) - trainer.lightning_module.trainer = trainer - trainer.strategy.setup_environment() - assert isinstance(trainer.model, LightningModule) - trainer.strategy.setup(trainer) - # because TrainerFn is not FITTING, model is not configured with sync batchnorm - assert not isinstance(trainer.strategy.model.layer, torch.nn.modules.batchnorm.SyncBatchNorm) - - -class CheckOptimizerDeviceModel(BoringModel): - def configure_optimizers(self): - assert all(param.device.type == "cuda" for param in self.parameters()) - super().configure_optimizers() - - -@RunIf(min_cuda_gpus=1) -@pytest.mark.parametrize("strategy", ["ddp", "ddp_spawn"]) -def test_model_parameters_on_device_for_optimizer(strategy): - """Test that the strategy has moved the parameters to the device by the time the optimizer gets created.""" - model = CheckOptimizerDeviceModel() - trainer = Trainer( - default_root_dir=os.getcwd(), - fast_dev_run=1, - accelerator="gpu", - devices=1, - strategy=strategy, - ) - trainer.fit(model) - - -def test_configure_launcher_create_processes_externally(): - class MyClusterEnvironment(ClusterEnvironment): - @property - def creates_processes_externally(self): - return True - - @property - def main_address(self): - return "" - - @property - def main_port(self): - return 8080 - - @staticmethod - def detect(): - return True - - def world_size(self): - return 1 - - def set_world_size(self): - pass - - def global_rank(self): - return 0 - - def set_global_rank(self): - pass - - def local_rank(self): - return 0 - - def node_rank(self): - return 0 - - ddp_strategy = DDPStrategy(cluster_environment=MyClusterEnvironment()) - assert ddp_strategy.launcher is None - ddp_strategy._configure_launcher() - assert isinstance(ddp_strategy.launcher, _SubprocessScriptLauncher) - - ddp_strategy.launcher._call_children_scripts = Mock() - launch_fn = Mock() - ddp_strategy.launcher.launch(launch_fn) - ddp_strategy.launcher._call_children_scripts.assert_not_called() - 
launch_fn.assert_called_once() - - -@mock.patch("torch.distributed.init_process_group") -def test_ddp_strategy_set_timeout(mock_init_process_group): - """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function.""" - test_timedelta = timedelta(seconds=30) - model = BoringModel() - ddp_strategy = DDPStrategy(timeout=test_timedelta) - trainer = Trainer( - max_epochs=1, - accelerator="cpu", - strategy=ddp_strategy, - ) - # test wrap the model if fitting - trainer.strategy.connect(model) - trainer.lightning_module.trainer = trainer - trainer.strategy.setup_environment() - - process_group_backend = trainer.strategy._get_process_group_backend() - global_rank = trainer.strategy.cluster_environment.global_rank() - world_size = trainer.strategy.cluster_environment.world_size() - mock_init_process_group.assert_called_with( - process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta - ) - - -class BoringZeroRedundancyOptimizerModel(BoringModel): - def configure_optimizers(self): - return ZeroRedundancyOptimizer(self.layer.parameters(), optimizer_class=torch.optim.Adam, lr=0.1) - - -@RunIf(min_cuda_gpus=2, skip_windows=True) -@pytest.mark.parametrize("strategy", [pytest.param("ddp", marks=RunIf(standalone=True)), "ddp_spawn"]) -def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(tmpdir, strategy): - """Test to ensure that checkpoint is saved correctly when using zero redundancy optimizer.""" - model = BoringZeroRedundancyOptimizerModel() - trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_steps=1) - - trainer.fit(model) - - checkpoint_path = os.path.join(tmpdir, "model.pt") - # need to broadcast because tmpdir is different on each process - checkpoint_path = trainer.strategy.broadcast(checkpoint_path) - trainer.save_checkpoint(checkpoint_path) - saved_model = BoringModel.load_from_checkpoint(checkpoint_path) - - # Assert model parameters are identical after loading - for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()): - assert torch.equal(trained_param.to("cpu"), loaded_param) - - -class UnusedParametersModel(BoringModel): - def __init__(self): - super().__init__() - self.intermediate_layer = torch.nn.Linear(32, 32) - - def training_step(self, batch, batch_idx): - with torch.no_grad(): - batch = self.intermediate_layer(batch) - return super().training_step(batch, batch_idx) - - -def test_ddp_strategy_find_unused_parameters_exception(): - """Test that the DDP strategy can change PyTorch's error message so that it's more useful for Lightning users.""" - trainer = Trainer(accelerator="cpu", devices=1, strategy="ddp", max_steps=2) - with pytest.raises(RuntimeError, match="It looks like your LightningModule has parameters that were not used in"): - trainer.fit(UnusedParametersModel()) diff --git a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py b/tests/tests_pytorch/strategies/test_deepspeed.py similarity index 100% rename from tests/tests_pytorch/strategies/test_deepspeed_strategy.py rename to tests/tests_pytorch/strategies/test_deepspeed.py diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index b96455895f..564e17d030 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -807,3 +807,29 @@ def test_save_load_sharded_state_dict(tmp_path): strategy = FSDPStrategy(auto_wrap_policy={nn.Linear}, state_dict_type="sharded") trainer = Trainer(**trainer_kwargs, 
strategy=strategy) trainer.fit(model, ckpt_path=checkpoint_path) + + +@RunIf(min_torch="1.12") +@mock.patch("lightning.pytorch.strategies.fsdp.torch.load") +@mock.patch("lightning.pytorch.strategies.fsdp._lazy_load") +@mock.patch("lightning.pytorch.strategies.fsdp._load_raw_module_state") +def test_fsdp_lazy_load_full_state_dict(_, lazy_load_mock, torch_load_mock, tmp_path): + """Test that loading a single file (full state) is lazy to reduce peak CPU memory usage.""" + model = BoringModel() + checkpoint = {"state_dict": model.state_dict()} + lazy_load_mock.return_value = checkpoint + + strategy = FSDPStrategy() + trainer = Trainer() + model.trainer = trainer + strategy._lightning_module = model + strategy.model = model + + file = tmp_path / "test.ckpt" + file.touch() + + strategy.load_checkpoint(checkpoint_path=file) + if _TORCH_GREATER_EQUAL_2_0: + lazy_load_mock.assert_called_once() + else: + torch_load_mock.assert_called_once() diff --git a/tests/tests_pytorch/strategies/test_single_device_strategy.py b/tests/tests_pytorch/strategies/test_single_device.py similarity index 100% rename from tests/tests_pytorch/strategies/test_single_device_strategy.py rename to tests/tests_pytorch/strategies/test_single_device.py
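
For reference, the full-checkpoint loading path introduced by the fsdp.py hunk at the top of this diff reduces to the pattern sketched below. This is a minimal illustration rather than the strategy's actual implementation: it relies on the private helpers that the hunk itself imports or patches in the new test (_lazy_load and _materialize_tensors from lightning.fabric.utilities.load, and _load_raw_module_state as referenced from lightning.pytorch.strategies.fsdp), and the free function load_full_checkpoint is a hypothetical name used only for this sketch.

import torch

from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
from lightning.pytorch.strategies.fsdp import _load_raw_module_state


def load_full_checkpoint(path, module, world_size):
    # On torch >= 2.0, read the checkpoint lazily so its tensors are not all
    # materialized in CPU memory up front; older torch falls back to an eager load.
    checkpoint = _lazy_load(path) if _TORCH_GREATER_EQUAL_2_0 else torch.load(path, map_location="cpu")

    # The module weights are consumed first and removed from the checkpoint dict.
    _load_raw_module_state(checkpoint.pop("state_dict"), module=module, world_size=world_size)

    # Whatever remains (e.g. optimizer states) is then materialized, since
    # `torch.optim.Optimizer.load_state_dict` deep-copies its input and cannot
    # handle lazy tensors, as noted in the hunk's comment.
    if _TORCH_GREATER_EQUAL_2_0:
        checkpoint = _materialize_tensors(checkpoint)
    return checkpoint

The new test_fsdp_lazy_load_full_state_dict test in this diff exercises exactly this branch point: it mocks _lazy_load, _load_raw_module_state, and torch.load, then asserts that the lazy path is taken on torch >= 2.0 and the eager torch.load path otherwise.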