From 2b6b594dabe668752183f5ab8a1d93d0904c4d32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 2 Nov 2023 16:06:40 +0100 Subject: [PATCH] Rename Throughput flops argument (#18924) --- src/lightning/fabric/utilities/throughput.py | 20 ++++++++++--------- .../pytorch/callbacks/throughput_monitor.py | 2 +- .../tests_fabric/utilities/test_throughput.py | 10 +++++----- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/src/lightning/fabric/utilities/throughput.py b/src/lightning/fabric/utilities/throughput.py index 71fa985b23..98ff4b13de 100644 --- a/src/lightning/fabric/utilities/throughput.py +++ b/src/lightning/fabric/utilities/throughput.py @@ -35,16 +35,17 @@ class Throughput: +------------------------+-------------------------------------------------------------------------------------+ | Key | Value | +========================+=====================================================================================+ - | batches_per_sec | Rolling average (over ``window_size`` most recent batches) of the number of batches | + | batches_per_sec | Rolling average (over ``window_size`` most recent updates) of the number of batches | | | processed per second | +--------------------------+-----------------------------------------------------------------------------------+ - | samples_per_sec | Rolling average (over ``window_size`` most recent batches) of the number of samples | + | samples_per_sec | Rolling average (over ``window_size`` most recent updates) of the number of samples | | | processed per second | +--------------------------+-----------------------------------------------------------------------------------+ - | items_per_sec | Rolling average (over ``window_size`` most recent batches) of the number of items | + | items_per_sec | Rolling average (over ``window_size`` most recent updates) of the number of items | | | processed per second | +--------------------------+-----------------------------------------------------------------------------------+ - | flops_per_sec | Estimates flops by flops_per_batch * batches_per_sec | + | flpps_per_sec | Rolling average (over ``window_size`` most recent updates) of the number of flops | + | | processed per second | +--------------------------+-----------------------------------------------------------------------------------+ | device/batches_per_sec | batches_per_sec divided by world size | +--------------------------+-----------------------------------------------------------------------------------+ @@ -116,7 +117,7 @@ class Throughput: batches: int, samples: int, lengths: Optional[int] = None, - flops_per_batch: Optional[int] = None, + flops: Optional[int] = None, ) -> None: """Update throughput metrics. @@ -127,7 +128,8 @@ class Throughput: samples: Total samples seen per device. It should monotonically increase by the batch size with each call. lengths: Total length of the samples seen. It should monotonically increase by the lengths of a batch with each call. - flops_per_batch: Flops per batch per device. You can easily compute this by using :func:`measure_flops`. + flops: Flops elapased per device since last ``update()`` call. You can easily compute this by using + :func:`measure_flops` and multiplying it by the number of batches that have been processed. The value might be different in each device if the batch size is not the same. """ @@ -145,9 +147,9 @@ class Throughput: f"If lengths are passed ({len(self._lengths)}), there needs to be the same number of samples" f" ({len(self._samples)})" ) - if flops_per_batch is not None: - # sum of flops per batch across ranks - self._flops.append(flops_per_batch * self.world_size) + if flops is not None: + # sum of flops across ranks + self._flops.append(flops * self.world_size) def compute(self) -> _THROUGHPUT_METRICS: """Compute throughput metrics.""" diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py index 041eb944d9..680d4afa65 100644 --- a/src/lightning/pytorch/callbacks/throughput_monitor.py +++ b/src/lightning/pytorch/callbacks/throughput_monitor.py @@ -149,7 +149,7 @@ class ThroughputMonitor(Callback): # this assumes that all iterations used the same batch size samples=iter_num * batch_size, lengths=None if self.length_fn is None else self._lengths[stage], - flops_per_batch=flops_per_batch, + flops=flops_per_batch, ) def _compute(self, trainer: "Trainer", iter_num: Optional[int] = None) -> None: diff --git a/tests/tests_fabric/utilities/test_throughput.py b/tests/tests_fabric/utilities/test_throughput.py index ece4dc6a1d..6bf7dbb5f7 100644 --- a/tests/tests_fabric/utilities/test_throughput.py +++ b/tests/tests_fabric/utilities/test_throughput.py @@ -90,8 +90,8 @@ def test_throughput(): # flops throughput = Throughput(available_flops=50, window_size=2) - throughput.update(time=1, batches=1, samples=2, flops_per_batch=10, lengths=10) - throughput.update(time=2, batches=2, samples=4, flops_per_batch=10, lengths=20) + throughput.update(time=1, batches=1, samples=2, flops=10, lengths=10) + throughput.update(time=2, batches=2, samples=4, flops=10, lengths=20) assert throughput.compute() == { "time": 2, "batches": 2, @@ -107,8 +107,8 @@ def test_throughput(): # flops without available throughput.available_flops = None throughput.reset() - throughput.update(time=1, batches=1, samples=2, flops_per_batch=10, lengths=10) - throughput.update(time=2, batches=2, samples=4, flops_per_batch=10, lengths=20) + throughput.update(time=1, batches=1, samples=2, flops=10, lengths=10) + throughput.update(time=2, batches=2, samples=4, flops=10, lengths=20) assert throughput.compute() == { "time": 2, "batches": 2, @@ -142,7 +142,7 @@ def mock_train_loop(monitor): batches=iter_num, samples=iter_num * micro_batch_size, lengths=total_lengths, - flops_per_batch=10, + flops=10, ) monitor.compute_and_log()