Rename Throughput flops argument (#18924)
This commit is contained in:
parent
37cbee42c6
commit
2b6b594dab
|
@ -35,16 +35,17 @@ class Throughput:
|
|||
+------------------------+-------------------------------------------------------------------------------------+
|
||||
| Key | Value |
|
||||
+========================+=====================================================================================+
|
||||
| batches_per_sec | Rolling average (over ``window_size`` most recent batches) of the number of batches |
|
||||
| batches_per_sec | Rolling average (over ``window_size`` most recent updates) of the number of batches |
|
||||
| | processed per second |
|
||||
+--------------------------+-----------------------------------------------------------------------------------+
|
||||
| samples_per_sec | Rolling average (over ``window_size`` most recent batches) of the number of samples |
|
||||
| samples_per_sec | Rolling average (over ``window_size`` most recent updates) of the number of samples |
|
||||
| | processed per second |
|
||||
+--------------------------+-----------------------------------------------------------------------------------+
|
||||
| items_per_sec | Rolling average (over ``window_size`` most recent batches) of the number of items |
|
||||
| items_per_sec | Rolling average (over ``window_size`` most recent updates) of the number of items |
|
||||
| | processed per second |
|
||||
+--------------------------+-----------------------------------------------------------------------------------+
|
||||
| flops_per_sec | Estimates flops by flops_per_batch * batches_per_sec |
|
||||
| flops_per_sec | Rolling average (over ``window_size`` most recent updates) of the number of flops |
|
||||
| | processed per second |
|
||||
+--------------------------+-----------------------------------------------------------------------------------+
|
||||
| device/batches_per_sec | batches_per_sec divided by world size |
|
||||
+--------------------------+-----------------------------------------------------------------------------------+
|
||||
|
@ -116,7 +117,7 @@ class Throughput:
|
|||
batches: int,
|
||||
samples: int,
|
||||
lengths: Optional[int] = None,
|
||||
flops_per_batch: Optional[int] = None,
|
||||
flops: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Update throughput metrics.
|
||||
|
||||
|
@ -127,7 +128,8 @@ class Throughput:
|
|||
samples: Total samples seen per device. It should monotonically increase by the batch size with each call.
|
||||
lengths: Total length of the samples seen. It should monotonically increase by the lengths of a batch with
|
||||
each call.
|
||||
flops_per_batch: Flops per batch per device. You can easily compute this by using :func:`measure_flops`.
|
||||
flops: Flops elapsed per device since the last ``update()`` call. You can easily compute this by using
|
||||
:func:`measure_flops` and multiplying it by the number of batches that have been processed.
|
||||
The value might be different in each device if the batch size is not the same.
|
||||
|
||||
"""
|
||||
|
@ -145,9 +147,9 @@ class Throughput:
|
|||
f"If lengths are passed ({len(self._lengths)}), there needs to be the same number of samples"
|
||||
f" ({len(self._samples)})"
|
||||
)
|
||||
if flops_per_batch is not None:
|
||||
# sum of flops per batch across ranks
|
||||
self._flops.append(flops_per_batch * self.world_size)
|
||||
if flops is not None:
|
||||
# sum of flops across ranks
|
||||
self._flops.append(flops * self.world_size)
|
||||
|
||||
def compute(self) -> _THROUGHPUT_METRICS:
|
||||
"""Compute throughput metrics."""
|
||||
|
|
|
@ -149,7 +149,7 @@ class ThroughputMonitor(Callback):
|
|||
# this assumes that all iterations used the same batch size
|
||||
samples=iter_num * batch_size,
|
||||
lengths=None if self.length_fn is None else self._lengths[stage],
|
||||
flops_per_batch=flops_per_batch,
|
||||
flops=flops_per_batch,
|
||||
)
|
||||
|
||||
def _compute(self, trainer: "Trainer", iter_num: Optional[int] = None) -> None:
|
||||
|
|
|
@ -90,8 +90,8 @@ def test_throughput():
|
|||
|
||||
# flops
|
||||
throughput = Throughput(available_flops=50, window_size=2)
|
||||
throughput.update(time=1, batches=1, samples=2, flops_per_batch=10, lengths=10)
|
||||
throughput.update(time=2, batches=2, samples=4, flops_per_batch=10, lengths=20)
|
||||
throughput.update(time=1, batches=1, samples=2, flops=10, lengths=10)
|
||||
throughput.update(time=2, batches=2, samples=4, flops=10, lengths=20)
|
||||
assert throughput.compute() == {
|
||||
"time": 2,
|
||||
"batches": 2,
|
||||
|
@ -107,8 +107,8 @@ def test_throughput():
|
|||
# flops without available
|
||||
throughput.available_flops = None
|
||||
throughput.reset()
|
||||
throughput.update(time=1, batches=1, samples=2, flops_per_batch=10, lengths=10)
|
||||
throughput.update(time=2, batches=2, samples=4, flops_per_batch=10, lengths=20)
|
||||
throughput.update(time=1, batches=1, samples=2, flops=10, lengths=10)
|
||||
throughput.update(time=2, batches=2, samples=4, flops=10, lengths=20)
|
||||
assert throughput.compute() == {
|
||||
"time": 2,
|
||||
"batches": 2,
|
||||
|
@ -142,7 +142,7 @@ def mock_train_loop(monitor):
|
|||
batches=iter_num,
|
||||
samples=iter_num * micro_batch_size,
|
||||
lengths=total_lengths,
|
||||
flops_per_batch=10,
|
||||
flops=10,
|
||||
)
|
||||
monitor.compute_and_log()
|
||||
|
||||
|
|
Loading…
Reference in New Issue