Add `rank_zero_first` utility (#17784)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Adrian Wälchli, 2023-06-12 12:32:32 +02:00 (committed by GitHub)
Commit: 9ff7d7120b
Parent: e4a3be7d75
5 changed files with 69 additions and 13 deletions


@@ -83,6 +83,11 @@ Avoid this from happening by guarding your logic with a rank check:
     if fabric.local_rank == 0:
         download_dataset()
+Another type of race condition is when one or multiple processes try to access a resource before it is available.
+For example, when rank 0 downloads a dataset, all other processes should *wait* for the download to complete before they start reading the contents.
+This can be achieved with a **barrier**.
 ----
@@ -127,7 +132,19 @@ Since downloading should be done on rank 0 only to :ref:`avoid race conditions <
     fabric.barrier()
     # After everyone reached the barrier, they can access the downloaded files:
-    load_dataset()
+    dataset = load_dataset()
+
+Specifically for the use case of downloading and reading data, there is a convenience context manager that combines both the rank-check and the barrier:
+
+.. code-block:: python
+
+    with fabric.rank_zero_first():
+        if not dataset_exists():
+            download_dataset("http://...")
+        dataset = load_dataset()
+
+With :meth:`~lightning.fabric.fabric.Fabric.rank_zero_first`, it is guaranteed that process 0 executes the code block first before all others can enter it.
 ----
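The implementation added later in this commit also accepts a `local` argument that makes the *local* rank 0 on each node go first, which matters on multi-node clusters without a shared filesystem where every node needs its own copy of the data. A minimal sketch of that usage, not shown in the docs diff, reusing the same placeholder helpers as the example above:

    # Assumption: several nodes whose filesystems are not shared.
    # Each node's local rank 0 downloads; the other local ranks on that node wait, then read.
    with fabric.rank_zero_first(local=True):
        if not dataset_exists():            # placeholder helper from the docs example
            download_dataset("http://...")  # placeholder helper from the docs example
        dataset = load_dataset()            # placeholder helper from the docs example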


@@ -78,12 +78,12 @@ def run(hparams):
     seed_everything(hparams.seed)  # instead of torch.manual_seed(...)
     transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
-    # This is meant to ensure the data are download only by 1 process.
-    if fabric.is_global_zero:
-        MNIST(DATASETS_PATH, download=True)
-    fabric.barrier()
-    train_dataset = MNIST(DATASETS_PATH, train=True, transform=transform)
-    test_dataset = MNIST(DATASETS_PATH, train=False, transform=transform)
+    # Let rank 0 download the data first, then everyone will load MNIST
+    with fabric.rank_zero_first():
+        train_dataset = MNIST(DATASETS_PATH, download=fabric.is_global_zero, train=True, transform=transform)
+        test_dataset = MNIST(DATASETS_PATH, download=fabric.is_global_zero, train=False, transform=transform)
     train_loader = torch.utils.data.DataLoader(
         train_dataset,
         batch_size=hparams.batch_size,


@@ -115,13 +115,10 @@ def run(hparams):
     seed_everything(hparams.seed)  # instead of torch.manual_seed(...)
     transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
-    # This is meant to ensure the data are download only by 1 process.
-    if fabric.is_global_zero:
-        MNIST(DATASETS_PATH, download=True)
-    fabric.barrier()
-    # initialize dataset
-    dataset = MNIST(DATASETS_PATH, train=True, transform=transform)
+    # Let rank 0 download the data first, then everyone will load MNIST
+    with fabric.rank_zero_first():
+        dataset = MNIST(DATASETS_PATH, train=True, transform=transform)
     # Loop over different folds (shuffle = False by default so reproducible)
     folds = hparams.folds


@@ -544,6 +544,28 @@ class Fabric:
         data = convert_to_tensors(data, device=self.device)
         return apply_to_collection(data, Tensor, self._strategy.all_reduce, group=group, reduce_op=reduce_op)
+    @contextmanager
+    def rank_zero_first(self, local: bool = False) -> Generator:
+        """The code block under this context manager gets executed first on the main process (rank 0) and only when
+        completed, the other processes get to run the code in parallel.
+
+        Args:
+            local: Set this to ``True`` if the **local** rank should be the one going first. Useful if you are
+                downloading data and the filesystem isn't shared between the nodes.
+
+        Example::
+
+            with fabric.rank_zero_first():
+                dataset = MNIST("datasets/", download=True)
+
+        """
+        rank = self.local_rank if local else self.global_rank
+        if rank > 0:
+            self.barrier()
+        yield
+        if rank == 0:
+            self.barrier()
+        self.barrier()
     @contextmanager
     def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Generator:
         """Skip gradient synchronization during backward to avoid redundant communication overhead.


@@ -1038,6 +1038,26 @@ def test_all_reduce():
     fabric._strategy.all_reduce.assert_has_calls([call(torch.tensor(4), **defaults), call(torch.tensor(5), **defaults)])
+def test_rank_zero_first():
+    """Test that rank 0 completes first before all other processes can execute under `.rank_zero_first()`."""
+
+    def record_calls_for_rank(rank):
+        call_order = []
+
+        fabric = Fabric()
+        fabric._strategy = Mock(global_rank=rank)
+        fabric.barrier = Mock(side_effect=lambda *_: call_order.append("barrier"))
+        target = Mock(run=Mock(side_effect=lambda *_: call_order.append("run")))
+
+        with fabric.rank_zero_first():
+            target.run()
+
+        return call_order
+
+    assert record_calls_for_rank(0) == ["run", "barrier", "barrier"]
+    assert record_calls_for_rank(1) == ["barrier", "run", "barrier"]
 @pytest.mark.parametrize(("clip_val", "max_norm"), [(1e-3, None), (None, 1)])
 def test_grad_clipping(clip_val, max_norm):
     fabric = Fabric(devices=1)
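Beyond the mocked unit test, the new context manager can be exercised end to end with a small launch script. The following is a hypothetical example, not part of the commit; it assumes torchvision is installed and uses two CPU processes so that the global-zero process downloads MNIST before the other process reads it.

    # hypothetical_launch.py -- illustration only, not part of the commit
    from lightning.fabric import Fabric
    from torchvision.datasets import MNIST


    def main(fabric):
        with fabric.rank_zero_first():
            # Only the global-zero process passes download=True; the other process
            # enters the block afterwards and simply reads the files from disk.
            dataset = MNIST("datasets/", download=fabric.is_global_zero, train=True)
        print(f"rank {fabric.global_rank} sees {len(dataset)} samples")


    if __name__ == "__main__":
        fabric = Fabric(accelerator="cpu", devices=2)
        fabric.launch(main)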