From 78b7a39e72f4c13abf6b9bbe5b49f4414f4ba859 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 6 Feb 2024 22:26:10 +0100 Subject: [PATCH] Update throughput docs (#19415) --- docs/source-fabric/index.rst | 1 + src/lightning/pytorch/callbacks/throughput_monitor.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source-fabric/index.rst b/docs/source-fabric/index.rst index 5051d9c1c0..d0aea2cc8a 100644 --- a/docs/source-fabric/index.rst +++ b/docs/source-fabric/index.rst @@ -211,6 +211,7 @@ Get Started Loggers Precision Strategies + Utilities .. toctree:: diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py index 6d3cb93022..71a85e431b 100644 --- a/src/lightning/pytorch/callbacks/throughput_monitor.py +++ b/src/lightning/pytorch/callbacks/throughput_monitor.py @@ -50,7 +50,7 @@ class ThroughputMonitor(Callback): model = MyModel() def sample_forward(): - batch = torch.randn(...) + batch = torch.randn(..., device="meta") return model(batch) self.flops_per_batch = measure_flops(model, sample_forward, loss_fn=torch.Tensor.sum)