Add `rank_zero_first` utility (#17784)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent e4a3be7d75
commit 9ff7d7120b
@@ -83,6 +83,11 @@ Avoid this from happening by guarding your logic with a rank check:
     if fabric.local_rank == 0:
         download_dataset()
 
+Another type of race condition is when one or multiple processes try to access a resource before it is available.
+For example, when rank 0 downloads a dataset, all other processes should *wait* for the download to complete before they start reading the contents.
+This can be achieved with a **barrier**.
+
+
 
 ----
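Spelled out, the pattern this added paragraph describes looks roughly like the sketch below (illustration only; `download_dataset` and `load_dataset` are placeholder helpers, as in the surrounding docs):

```python
# Rank 0 fetches the data; everyone else waits at the barrier until it is done.
if fabric.global_rank == 0:
    download_dataset("http://...")
fabric.barrier()  # no process continues until every process has reached this line
dataset = load_dataset()
```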
@@ -127,7 +132,19 @@ Since downloading should be done on rank 0 only to :ref:`avoid race conditions <
     fabric.barrier()
 
     # After everyone reached the barrier, they can access the downloaded files:
-    load_dataset()
+    dataset = load_dataset()
+
+
+Specifically for the use case of downloading and reading data, there is a convenience context manager that combines both the rank-check and the barrier:
+
+.. code-block:: python
+
+    with fabric.rank_zero_first():
+        if not dataset_exists():
+            download_dataset("http://...")
+        dataset = load_dataset()
+
+With :meth:`~lightning.fabric.fabric.Fabric.rank_zero_first`, it is guaranteed that process 0 executes the code block first before all others can enter it.
 
 
 ----
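Not shown in this docs hunk: the new context manager also accepts a `local` flag (documented in its docstring further down in this diff) for multi-node setups where the filesystem is not shared between nodes. A hedged usage sketch:

```python
# One download per node: the process with local_rank == 0 on each node runs the
# block first; the other processes on that node wait, then load the existing files.
with fabric.rank_zero_first(local=True):
    dataset = MNIST("datasets/", download=True)
```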
@@ -78,12 +78,12 @@ def run(hparams):
     seed_everything(hparams.seed)  # instead of torch.manual_seed(...)
 
     transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
-    # This is meant to ensure the data are download only by 1 process.
-    if fabric.is_global_zero:
-        MNIST(DATASETS_PATH, download=True)
-    fabric.barrier()
-    train_dataset = MNIST(DATASETS_PATH, train=True, transform=transform)
-    test_dataset = MNIST(DATASETS_PATH, train=False, transform=transform)
+
+    # Let rank 0 download the data first, then everyone will load MNIST
+    with fabric.rank_zero_first():
+        train_dataset = MNIST(DATASETS_PATH, download=fabric.is_global_zero, train=True, transform=transform)
+        test_dataset = MNIST(DATASETS_PATH, download=fabric.is_global_zero, train=False, transform=transform)
 
     train_loader = torch.utils.data.DataLoader(
         train_dataset,
         batch_size=hparams.batch_size,
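A note on the design choice above: only the global rank 0 process passes `download=True` (via `fabric.is_global_zero`); all other processes enter the block only after rank 0 has left it, so they simply read the files that rank 0 already put on disk.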
@@ -115,13 +115,10 @@ def run(hparams):
     seed_everything(hparams.seed)  # instead of torch.manual_seed(...)
 
     transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
-    # This is meant to ensure the data are download only by 1 process.
-    if fabric.is_global_zero:
-        MNIST(DATASETS_PATH, download=True)
-    fabric.barrier()
-
-    # initialize dataset
-    dataset = MNIST(DATASETS_PATH, train=True, transform=transform)
+    # Let rank 0 download the data first, then everyone will load MNIST
+    with fabric.rank_zero_first():
+        dataset = MNIST(DATASETS_PATH, train=True, transform=transform)
 
     # Loop over different folds (shuffle = False by default so reproducible)
     folds = hparams.folds
@@ -544,6 +544,28 @@ class Fabric:
         data = convert_to_tensors(data, device=self.device)
         return apply_to_collection(data, Tensor, self._strategy.all_reduce, group=group, reduce_op=reduce_op)
 
+    @contextmanager
+    def rank_zero_first(self, local: bool = False) -> Generator:
+        """The code block under this context manager gets executed first on the main process (rank 0) and only when
+        completed, the other processes get to run the code in parallel.
+
+        Args:
+            local: Set this to ``True`` if the **local** rank should be the one going first. Useful if you are
+                downloading data and the filesystem isn't shared between the nodes.
+
+        Example::
+
+            with fabric.rank_zero_first():
+                dataset = MNIST("datasets/", download=True)
+        """
+        rank = self.local_rank if local else self.global_rank
+        if rank > 0:
+            self.barrier()
+        yield
+        if rank == 0:
+            self.barrier()
+        self.barrier()
+
     @contextmanager
     def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Generator:
         """Skip gradient synchronization during backward to avoid redundant communication overhead.
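To make the barrier choreography easier to follow, here is a minimal standalone sketch of the same ordering logic written against raw `torch.distributed` rather than Fabric (illustration only; `rank_zero_goes_first` is a made-up name, and an initialized process group is assumed):

```python
from contextlib import contextmanager
from typing import Generator

import torch.distributed as dist  # assumes dist.init_process_group(...) was called


@contextmanager
def rank_zero_goes_first(rank: int) -> Generator:
    # Same ordering as Fabric.rank_zero_first: non-zero ranks block before the
    # body, rank 0 runs the body immediately and then releases everyone.
    if rank > 0:
        dist.barrier()  # wait until rank 0 has finished the block
    yield
    if rank == 0:
        dist.barrier()  # rank 0 is done; unblock the waiting ranks
    dist.barrier()  # final sync so all ranks leave the context together
```

Each rank ends up calling `barrier()` exactly twice, which is what the unit test further down asserts.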
@@ -1038,6 +1038,26 @@ def test_all_reduce():
     fabric._strategy.all_reduce.assert_has_calls([call(torch.tensor(4), **defaults), call(torch.tensor(5), **defaults)])
 
 
+def test_rank_zero_first():
+    """Test that rank 0 completes first before all other processes can execute under `.rank_zero_first()`."""
+
+    def record_calls_for_rank(rank):
+        call_order = []
+
+        fabric = Fabric()
+        fabric._strategy = Mock(global_rank=rank)
+        fabric.barrier = Mock(side_effect=lambda *_: call_order.append("barrier"))
+        target = Mock(run=Mock(side_effect=lambda *_: call_order.append("run")))
+
+        with fabric.rank_zero_first():
+            target.run()
+
+        return call_order
+
+    assert record_calls_for_rank(0) == ["run", "barrier", "barrier"]
+    assert record_calls_for_rank(1) == ["barrier", "run", "barrier"]
+
+
 @pytest.mark.parametrize(("clip_val", "max_norm"), [(1e-3, None), (None, 1)])
 def test_grad_clipping(clip_val, max_norm):
     fabric = Fabric(devices=1)
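The assertions encode the ordering guarantee directly: rank 0 runs the body first and then hits two barriers, while every other rank hits a barrier before running the body and a final barrier afterwards, matching the implementation above. The test mocks `barrier()` and the strategy's `global_rank`, so it checks the call ordering without spawning real processes.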