lightning/tests/tests_fabric/utilities/test_throughput.py

from unittest import mock
from unittest.mock import Mock, call
import pytest
import torch
from lightning.fabric import Fabric
from lightning.fabric.plugins import Precision
from lightning.fabric.utilities.throughput import (
    Throughput,
    ThroughputMonitor,
    _MonotonicWindow,
    get_available_flops,
    measure_flops,
)
from tests_fabric.test_fabric import BoringModel


def test_measure_flops():
    with torch.device("meta"):
        model = BoringModel()
        x = torch.randn(2, 32)
    model_fwd = lambda: model(x)
    model_loss = lambda y: y.sum()

    fwd_flops = measure_flops(model, model_fwd)
    assert isinstance(fwd_flops, int)

    fwd_and_bwd_flops = measure_flops(model, model_fwd, model_loss)
    assert isinstance(fwd_and_bwd_flops, int)
    assert fwd_flops < fwd_and_bwd_flops


def test_get_available_flops(xla_available):
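    # CUDA lookup keys off the `torch.cuda.get_device_name()` string; unknown names and
    # unsupported dtypes return None with a warning. TPU lookup goes through `torch_xla`,
    # which is a Mock here, so the reported TPU type can be stubbed via `get_tpu_env`.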
    with mock.patch("torch.cuda.get_device_name", return_value="NVIDIA H100 PCIe"):
        flops = get_available_flops(torch.device("cuda"), torch.bfloat16)
        assert flops == 756e12

    with pytest.warns(match="not found for 'CocoNut"), mock.patch("torch.cuda.get_device_name", return_value="CocoNut"):
        assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None

    with pytest.warns(match="t4' does not support torch.bfloat"), mock.patch(
        "torch.cuda.get_device_name", return_value="t4"
    ):
        assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None

    from torch_xla.experimental import tpu

    assert isinstance(tpu, Mock)
    tpu.get_tpu_env.return_value = {"TYPE": "V4"}
    flops = get_available_flops(torch.device("xla"), torch.bfloat16)
    assert flops == 275e12

    tpu.get_tpu_env.return_value = {"TYPE": "V1"}
    with pytest.warns(match="not found for TPU 'V1'"):
        assert get_available_flops(torch.device("xla"), torch.bfloat16) is None

    tpu.get_tpu_env.return_value = {"ACCELERATOR_TYPE": "v3-8"}
    flops = get_available_flops(torch.device("xla"), torch.bfloat16)
    assert flops == 123e12

    tpu.reset_mock()


@pytest.mark.parametrize(
"device_name",
[
# Hopper
"h100-nvl", # TODO: switch with `torch.cuda.get_device_name()` result
"h100-hbm3", # TODO: switch with `torch.cuda.get_device_name()` result
"NVIDIA H100 PCIe",
"h100-hbm2e", # TODO: switch with `torch.cuda.get_device_name()` result
# Ada
"NVIDIA GeForce RTX 4090",
"NVIDIA GeForce RTX 4080",
"Tesla L40",
"NVIDIA L4",
# Ampere
"NVIDIA A100 80GB PCIe",
"NVIDIA A100-SXM4-40GB",
"NVIDIA GeForce RTX 3090",
"NVIDIA GeForce RTX 3090 Ti",
"NVIDIA GeForce RTX 3080",
"NVIDIA GeForce RTX 3080 Ti",
"NVIDIA GeForce RTX 3070",
pytest.param("NVIDIA GeForce RTX 3070 Ti", marks=pytest.mark.xfail(raises=AssertionError)),
pytest.param("NVIDIA GeForce RTX 3060", marks=pytest.mark.xfail(raises=AssertionError)),
pytest.param("NVIDIA GeForce RTX 3060 Ti", marks=pytest.mark.xfail(raises=AssertionError)),
pytest.param("NVIDIA GeForce RTX 3050", marks=pytest.mark.xfail(raises=AssertionError)),
pytest.param("NVIDIA GeForce RTX 3050 Ti", marks=pytest.mark.xfail(raises=AssertionError)),
"NVIDIA A6000",
"NVIDIA A40",
"NVIDIA A10G",
# Turing
"NVIDIA GeForce RTX 2080 SUPER",
"NVIDIA GeForce RTX 2080 Ti",
"NVIDIA GeForce RTX 2080",
"NVIDIA GeForce RTX 2070 Super",
"Quadro RTX 5000 with Max-Q Design",
"Tesla T4",
"TITAN RTX",
# Volta
"Tesla V100-SXm2-32GB",
"Tesla V100-PCIE-32GB",
"Tesla V100S-PCIE-32GB",
],
)
@mock.patch("lightning.fabric.accelerators.cuda._is_ampere_or_later", return_value=False)
def test_get_available_flops_cuda_mapping_exists(_, device_name):
"""Tests `get_available_flops` against known device names."""
with mock.patch("lightning.fabric.utilities.throughput.torch.cuda.get_device_name", return_value=device_name):
assert get_available_flops(device=torch.device("cuda"), dtype=torch.float32) is not None
def test_throughput():
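    # `compute()` always returns the raw counters; the per-second "device/" rates are only
    # added once `window_size` updates have been recorded, and "device/mfu" additionally
    # requires `available_flops` (below: 10 flops/s / 50 available = 0.2)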
    # required args only
    throughput = Throughput()
    throughput.update(time=2.0, batches=1, samples=2)
    assert throughput.compute() == {"time": 2.0, "batches": 1, "samples": 2}

    # different lengths and samples
    with pytest.raises(RuntimeError, match="same number of samples"):
        throughput.update(time=2.1, batches=2, samples=3, lengths=4)

    # lengths and samples
    throughput = Throughput(window_size=2)
    throughput.update(time=2, batches=1, samples=2, lengths=4)
    throughput.update(time=2.5, batches=2, samples=4, lengths=8)
    assert throughput.compute() == {
        "time": 2.5,
        "batches": 2,
        "samples": 4,
        "lengths": 8,
        "device/batches_per_sec": 2.0,
        "device/samples_per_sec": 4.0,
        "device/items_per_sec": 8.0,
    }
    with pytest.raises(ValueError, match="Expected the value to increase"):
        throughput.update(time=2.5, batches=3, samples=2, lengths=4)

    # flops
    throughput = Throughput(available_flops=50, window_size=2)
    throughput.update(time=1, batches=1, samples=2, flops=10, lengths=10)
    throughput.update(time=2, batches=2, samples=4, flops=10, lengths=20)
    assert throughput.compute() == {
        "time": 2,
        "batches": 2,
        "samples": 4,
        "lengths": 20,
        "device/batches_per_sec": 1.0,
        "device/flops_per_sec": 10.0,
        "device/items_per_sec": 10.0,
        "device/mfu": 0.2,
        "device/samples_per_sec": 2.0,
    }

    # flops without available
    throughput.available_flops = None
    throughput.reset()
    throughput.update(time=1, batches=1, samples=2, flops=10, lengths=10)
    throughput.update(time=2, batches=2, samples=4, flops=10, lengths=20)
    assert throughput.compute() == {
        "time": 2,
        "batches": 2,
        "samples": 4,
        "lengths": 20,
        "device/batches_per_sec": 1.0,
        "device/flops_per_sec": 10.0,
        "device/items_per_sec": 10.0,
        "device/samples_per_sec": 2.0,
    }

    throughput = Throughput(window_size=2)
    with pytest.raises(ValueError, match=r"samples.*to be greater or equal than batches"):
        throughput.update(time=0, batches=2, samples=1)

    throughput = Throughput(window_size=2)
    with pytest.raises(ValueError, match=r"lengths.*to be greater or equal than samples"):
        throughput.update(time=0, batches=2, samples=2, lengths=1)


def mock_train_loop(monitor):
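    # each fake iteration advances time by exactly 1s and adds 1 batch, 3 samples, and 6 tokens,
    # so the per-device rates expected in the tests below settle at 1 batch/s, 3 samples/s, 6 items/s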
    # simulate lit-gpt style loop
    total_lengths = 0
    total_t0 = 0.0  # fake times
    micro_batch_size = 3
    for iter_num in range(1, 6):
        # forward + backward + step + zero_grad ...
        t1 = iter_num + 0.5
        total_lengths += 3 * 2
        monitor.update(
            time=t1 - total_t0,
            batches=iter_num,
            samples=iter_num * micro_batch_size,
            lengths=total_lengths,
            flops=10,
        )
        monitor.compute_and_log()


def test_throughput_monitor():
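    # with window_size=4, the "device|*" rate metrics only show up from the 4th logged step on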
    logger_mock = Mock()
    fabric = Fabric(devices=1, loggers=logger_mock)
    with mock.patch("lightning.fabric.utilities.throughput.get_available_flops", return_value=100):
        monitor = ThroughputMonitor(fabric, window_size=4, separator="|")
        mock_train_loop(monitor)

    assert logger_mock.log_metrics.mock_calls == [
        call(metrics={"time": 1.5, "batches": 1, "samples": 3, "lengths": 6}, step=0),
        call(metrics={"time": 2.5, "batches": 2, "samples": 6, "lengths": 12}, step=1),
        call(metrics={"time": 3.5, "batches": 3, "samples": 9, "lengths": 18}, step=2),
        call(
            metrics={
                "time": 4.5,
                "batches": 4,
                "samples": 12,
                "lengths": 24,
                "device|batches_per_sec": 1.0,
                "device|samples_per_sec": 3.0,
                "device|items_per_sec": 6.0,
                "device|flops_per_sec": 10.0,
                "device|mfu": 0.1,
            },
            step=3,
        ),
        call(
            metrics={
                "time": 5.5,
                "batches": 5,
                "samples": 15,
                "lengths": 30,
                "device|batches_per_sec": 1.0,
                "device|samples_per_sec": 3.0,
                "device|items_per_sec": 6.0,
                "device|flops_per_sec": 10.0,
                "device|mfu": 0.1,
            },
            step=4,
        ),
    ]


def test_throughput_monitor_step():
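    # `compute_and_log()` auto-increments the step unless an explicit `step=` is passed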
    fabric_mock = Mock()
    fabric_mock.world_size = 1
    fabric_mock.strategy.precision = Precision()
    monitor = ThroughputMonitor(fabric_mock)

    # automatic step increase
    assert monitor.step == -1
    monitor.update(time=0.5, batches=1, samples=3)
    metrics = monitor.compute_and_log()
    assert metrics == {"time": 0.5, "batches": 1, "samples": 3}
    assert monitor.step == 0

    # manual step
    monitor.update(time=1.5, batches=2, samples=4)
    metrics = monitor.compute_and_log(step=5)
    assert metrics == {"time": 1.5, "batches": 2, "samples": 4}
    assert monitor.step == 5

    assert fabric_mock.log_dict.mock_calls == [
        call(metrics={"time": 0.5, "batches": 1, "samples": 3}, step=0),
        call(metrics={"time": 1.5, "batches": 2, "samples": 4}, step=5),
    ]


def test_throughput_monitor_world_size():
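    # with 2 simulated devices, the un-prefixed global rates equal the "device/" rates scaled
    # by world_size, while "device/mfu" remains a per-device value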
    logger_mock = Mock()
    fabric = Fabric(devices=1, loggers=logger_mock)
    with mock.patch("lightning.fabric.utilities.throughput.get_available_flops", return_value=100):
        monitor = ThroughputMonitor(fabric, window_size=4)
        # simulate that there are 2 devices
        monitor.world_size = 2
        mock_train_loop(monitor)

    assert logger_mock.log_metrics.mock_calls == [
        call(metrics={"time": 1.5, "batches": 1, "samples": 3, "lengths": 6}, step=0),
        call(metrics={"time": 2.5, "batches": 2, "samples": 6, "lengths": 12}, step=1),
        call(metrics={"time": 3.5, "batches": 3, "samples": 9, "lengths": 18}, step=2),
        call(
            metrics={
                "time": 4.5,
                "batches": 4,
                "samples": 12,
                "lengths": 24,
                "device/batches_per_sec": 1.0,
                "device/samples_per_sec": 3.0,
                "batches_per_sec": 2.0,
                "samples_per_sec": 6.0,
                "items_per_sec": 12.0,
                "device/items_per_sec": 6.0,
                "flops_per_sec": 20.0,
                "device/flops_per_sec": 10.0,
                "device/mfu": 0.1,
            },
            step=3,
        ),
        call(
            metrics={
                "time": 5.5,
                "batches": 5,
                "samples": 15,
                "lengths": 30,
                "device/batches_per_sec": 1.0,
                "device/samples_per_sec": 3.0,
                "batches_per_sec": 2.0,
                "samples_per_sec": 6.0,
                "items_per_sec": 12.0,
                "device/items_per_sec": 6.0,
                "flops_per_sec": 20.0,
                "device/flops_per_sec": 10.0,
                "device/mfu": 0.1,
            },
            step=4,
        ),
    ]


def test_monotonic_window():
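    # _MonotonicWindow acts like a bounded list that rejects item assignment and
    # non-increasing appends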
    w = _MonotonicWindow(maxlen=3)
    assert w == []
    assert len(w) == 0

    w.append(1)
    w.append(2)
    w.append(3)
    assert w == [1, 2, 3]
    assert len(w) == 3
    assert w[1] == 2
    assert w[-2:] == [2, 3]

    with pytest.raises(NotImplementedError):
        w[1] = 123
    with pytest.raises(NotImplementedError):
        w[1:2] = [1, 2]

    with pytest.raises(ValueError, match="Expected the value to increase"):
        w.append(2)

    w.clear()
    w.append(2)