Rename Throughput flops argument (#18924)

This commit is contained in:
Carlos Mocholí 2023-11-02 16:06:40 +01:00 committed by GitHub
parent 37cbee42c6
commit 2b6b594dab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 15 deletions

View File

@ -35,16 +35,17 @@ class Throughput:
+------------------------+-------------------------------------------------------------------------------------+
| Key | Value |
+========================+=====================================================================================+
| batches_per_sec | Rolling average (over ``window_size`` most recent batches) of the number of batches |
| batches_per_sec | Rolling average (over ``window_size`` most recent updates) of the number of batches |
| | processed per second |
+--------------------------+-----------------------------------------------------------------------------------+
| samples_per_sec | Rolling average (over ``window_size`` most recent batches) of the number of samples |
| samples_per_sec | Rolling average (over ``window_size`` most recent updates) of the number of samples |
| | processed per second |
+--------------------------+-----------------------------------------------------------------------------------+
| items_per_sec | Rolling average (over ``window_size`` most recent batches) of the number of items |
| items_per_sec | Rolling average (over ``window_size`` most recent updates) of the number of items |
| | processed per second |
+--------------------------+-----------------------------------------------------------------------------------+
| flops_per_sec | Estimates flops by flops_per_batch * batches_per_sec |
| flops_per_sec | Rolling average (over ``window_size`` most recent updates) of the number of flops |
| | processed per second |
+--------------------------+-----------------------------------------------------------------------------------+
| device/batches_per_sec | batches_per_sec divided by world size |
+--------------------------+-----------------------------------------------------------------------------------+
@ -116,7 +117,7 @@ class Throughput:
batches: int,
samples: int,
lengths: Optional[int] = None,
flops_per_batch: Optional[int] = None,
flops: Optional[int] = None,
) -> None:
"""Update throughput metrics.
@ -127,7 +128,8 @@ class Throughput:
samples: Total samples seen per device. It should monotonically increase by the batch size with each call.
lengths: Total length of the samples seen. It should monotonically increase by the lengths of a batch with
each call.
flops_per_batch: Flops per batch per device. You can easily compute this by using :func:`measure_flops`.
flops: Flops elapsed per device since the last ``update()`` call. You can easily compute this by using
:func:`measure_flops` and multiplying it by the number of batches that have been processed.
The value might be different in each device if the batch size is not the same.
"""
@ -145,9 +147,9 @@ class Throughput:
f"If lengths are passed ({len(self._lengths)}), there needs to be the same number of samples"
f" ({len(self._samples)})"
)
if flops_per_batch is not None:
# sum of flops per batch across ranks
self._flops.append(flops_per_batch * self.world_size)
if flops is not None:
# sum of flops across ranks
self._flops.append(flops * self.world_size)
def compute(self) -> _THROUGHPUT_METRICS:
"""Compute throughput metrics."""

View File

@ -149,7 +149,7 @@ class ThroughputMonitor(Callback):
# this assumes that all iterations used the same batch size
samples=iter_num * batch_size,
lengths=None if self.length_fn is None else self._lengths[stage],
flops_per_batch=flops_per_batch,
flops=flops_per_batch,
)
def _compute(self, trainer: "Trainer", iter_num: Optional[int] = None) -> None:

View File

@ -90,8 +90,8 @@ def test_throughput():
# flops
throughput = Throughput(available_flops=50, window_size=2)
throughput.update(time=1, batches=1, samples=2, flops_per_batch=10, lengths=10)
throughput.update(time=2, batches=2, samples=4, flops_per_batch=10, lengths=20)
throughput.update(time=1, batches=1, samples=2, flops=10, lengths=10)
throughput.update(time=2, batches=2, samples=4, flops=10, lengths=20)
assert throughput.compute() == {
"time": 2,
"batches": 2,
@ -107,8 +107,8 @@ def test_throughput():
# flops without available
throughput.available_flops = None
throughput.reset()
throughput.update(time=1, batches=1, samples=2, flops_per_batch=10, lengths=10)
throughput.update(time=2, batches=2, samples=4, flops_per_batch=10, lengths=20)
throughput.update(time=1, batches=1, samples=2, flops=10, lengths=10)
throughput.update(time=2, batches=2, samples=4, flops=10, lengths=20)
assert throughput.compute() == {
"time": 2,
"batches": 2,
@ -142,7 +142,7 @@ def mock_train_loop(monitor):
batches=iter_num,
samples=iter_num * micro_batch_size,
lengths=total_lengths,
flops_per_batch=10,
flops=10,
)
monitor.compute_and_log()