From 9ff7d7120b87def139ca67671109366921f4c683 Mon Sep 17 00:00:00 2001
From: Adrian Wälchli
Date: Mon, 12 Jun 2023 12:32:32 +0200
Subject: [PATCH] Add `rank_zero_first` utility (#17784)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../advanced/distributed_communication.rst       | 19 +++++++++++++++-
 .../fabric/image_classifier/train_fabric.py      | 12 +++++-----
 examples/fabric/kfold_cv/train_fabric.py         |  9 +++-----
 src/lightning/fabric/fabric.py                   | 22 +++++++++++++++++++
 tests/tests_fabric/test_fabric.py                | 20 +++++++++++++++++
 5 files changed, 69 insertions(+), 13 deletions(-)

diff --git a/docs/source-fabric/advanced/distributed_communication.rst b/docs/source-fabric/advanced/distributed_communication.rst
index 5635d19051..83aac5cb3f 100644
--- a/docs/source-fabric/advanced/distributed_communication.rst
+++ b/docs/source-fabric/advanced/distributed_communication.rst
@@ -83,6 +83,11 @@ Avoid this from happening by guarding your logic with a rank check:
 
     if fabric.local_rank == 0:
         download_dataset()
 
+Another type of race condition occurs when one or more processes try to access a resource before it is available.
+For example, when rank 0 downloads a dataset, all other processes should *wait* for the download to complete before they start reading the contents.
+This can be achieved with a **barrier**.
+
+
 
 ----
 
 
@@ -127,7 +132,19 @@ Since downloading should be done on rank 0 only to :ref:`avoid race conditions <
     fabric.barrier()
 
     # After everyone reached the barrier, they can access the downloaded files:
-    load_dataset()
+    dataset = load_dataset()
+
+
+Specifically for the use case of downloading and reading data, there is a convenience context manager that combines the rank check and the barrier:
+
+.. code-block:: python
+
+    with fabric.rank_zero_first():
+        if not dataset_exists():
+            download_dataset("http://...")
+        dataset = load_dataset()
+
+With :meth:`~lightning.fabric.fabric.Fabric.rank_zero_first`, process 0 is guaranteed to execute the code block before all other processes can enter it.
 
 ----
 
 
diff --git a/examples/fabric/image_classifier/train_fabric.py b/examples/fabric/image_classifier/train_fabric.py
index e88a8beafb..5f4d9313c6 100644
--- a/examples/fabric/image_classifier/train_fabric.py
+++ b/examples/fabric/image_classifier/train_fabric.py
@@ -78,12 +78,12 @@ def run(hparams):
     seed_everything(hparams.seed)  # instead of torch.manual_seed(...)
 
     transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
-    # This is meant to ensure the data are download only by 1 process.
-    if fabric.is_global_zero:
-        MNIST(DATASETS_PATH, download=True)
-    fabric.barrier()
-    train_dataset = MNIST(DATASETS_PATH, train=True, transform=transform)
-    test_dataset = MNIST(DATASETS_PATH, train=False, transform=transform)
+
+    # Let rank 0 download the data first; afterwards, every process loads MNIST
+    with fabric.rank_zero_first():
+        train_dataset = MNIST(DATASETS_PATH, download=fabric.is_global_zero, train=True, transform=transform)
+        test_dataset = MNIST(DATASETS_PATH, download=fabric.is_global_zero, train=False, transform=transform)
+
     train_loader = torch.utils.data.DataLoader(
         train_dataset,
         batch_size=hparams.batch_size,
diff --git a/examples/fabric/kfold_cv/train_fabric.py b/examples/fabric/kfold_cv/train_fabric.py
index 79da9f378e..ffaa11fc41 100644
--- a/examples/fabric/kfold_cv/train_fabric.py
+++ b/examples/fabric/kfold_cv/train_fabric.py
@@ -115,13 +115,10 @@ def run(hparams):
     seed_everything(hparams.seed)  # instead of torch.manual_seed(...)
 
     transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
-    # This is meant to ensure the data are download only by 1 process.
-    if fabric.is_global_zero:
-        MNIST(DATASETS_PATH, download=True)
-    fabric.barrier()
-
-    # initialize dataset
-    dataset = MNIST(DATASETS_PATH, train=True, transform=transform)
+    # Let rank 0 download the data first; afterwards, every process loads MNIST
+    with fabric.rank_zero_first():
+        dataset = MNIST(DATASETS_PATH, train=True, transform=transform)
 
     # Loop over different folds (shuffle = False by default so reproducible)
     folds = hparams.folds
diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index 8798c723ee..bf635da9f7 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -544,6 +544,28 @@ class Fabric:
         data = convert_to_tensors(data, device=self.device)
         return apply_to_collection(data, Tensor, self._strategy.all_reduce, group=group, reduce_op=reduce_op)
 
+    @contextmanager
+    def rank_zero_first(self, local: bool = False) -> Generator:
+        """The code block under this context manager gets executed first on the main process (rank 0). Only when
+        it has completed do the other processes get to run the code in parallel.
+
+        Args:
+            local: Set this to ``True`` if the **local** rank should be the one going first. Useful if you are
+                downloading data and the filesystem isn't shared between the nodes.
+
+        Example::
+
+            with fabric.rank_zero_first():
+                dataset = MNIST("datasets/", download=True)
+        """
+        rank = self.local_rank if local else self.global_rank
+        if rank > 0:
+            self.barrier()
+        yield
+        if rank == 0:
+            self.barrier()
+        self.barrier()
+
     @contextmanager
     def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Generator:
         """Skip gradient synchronization during backward to avoid redundant communication overhead.
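
The ``local`` flag is worth a usage sketch: when the nodes of a multi-node job do not share a filesystem, the *local* rank 0 of every node has to perform the download. The following is a minimal sketch, not part of the patch; it assumes the script is started externally on every node, and ``dataset_exists``/``download_dataset`` are the same placeholder helpers used in the docs hunk above, not real APIs:

.. code-block:: python

    from pathlib import Path

    from lightning.fabric import Fabric

    DATA_DIR = Path("datasets")  # hypothetical per-node local disk


    def dataset_exists() -> bool:
        # Placeholder check, mirroring the docs example above
        return (DATA_DIR / "mnist").exists()


    def download_dataset(url: str) -> None:
        # Placeholder download helper, mirroring the docs example above
        ...


    fabric = Fabric()  # assume one process per GPU, launched on each node

    # Without a shared filesystem, each node needs its own copy of the data:
    # local rank 0 of every node downloads while its node-local peers wait.
    with fabric.rank_zero_first(local=True):
        if not dataset_exists():
            download_dataset("http://...")
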
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index 99bf3a22c0..51f0c903fd 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -1038,6 +1038,26 @@ def test_all_reduce():
     fabric._strategy.all_reduce.assert_has_calls([call(torch.tensor(4), **defaults), call(torch.tensor(5), **defaults)])
 
 
+def test_rank_zero_first():
+    """Test that rank 0 completes the code block before all other processes can execute it under `.rank_zero_first()`."""
+
+    def record_calls_for_rank(rank):
+        call_order = []
+
+        fabric = Fabric()
+        fabric._strategy = Mock(global_rank=rank)
+        fabric.barrier = Mock(side_effect=lambda *_: call_order.append("barrier"))
+        target = Mock(run=Mock(side_effect=lambda *_: call_order.append("run")))
+
+        with fabric.rank_zero_first():
+            target.run()
+
+        return call_order
+
+    assert record_calls_for_rank(0) == ["run", "barrier", "barrier"]
+    assert record_calls_for_rank(1) == ["barrier", "run", "barrier"]
+
+
 @pytest.mark.parametrize(("clip_val", "max_norm"), [(1e-3, None), (None, 1)])
 def test_grad_clipping(clip_val, max_norm):
     fabric = Fabric(devices=1)
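
To make the barrier choreography that the test asserts easy to trace, here is a standalone sketch, not Fabric code, that re-implements the same rank/yield/barrier control flow as ``rank_zero_first`` and records the identical call order:

.. code-block:: python

    from contextlib import contextmanager


    def trace(rank: int) -> list:
        # Record "barrier" and "run" events in the order they happen for `rank`
        calls = []

        def barrier() -> None:
            calls.append("barrier")

        @contextmanager
        def rank_zero_first():
            if rank > 0:
                barrier()  # non-zero ranks wait until rank 0 has left the block
            yield
            if rank == 0:
                barrier()  # rank 0 releases the waiting ranks
            barrier()  # everyone synchronizes once more before moving on

        with rank_zero_first():
            calls.append("run")
        return calls


    assert trace(0) == ["run", "barrier", "barrier"]
    assert trace(1) == ["barrier", "run", "barrier"]

Every rank therefore calls ``barrier()`` exactly twice, which keeps the number of collective calls matched across all processes.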