diff --git a/docs/source-fabric/index.rst b/docs/source-fabric/index.rst index 5051d9c1c0..d0aea2cc8a 100644 --- a/docs/source-fabric/index.rst +++ b/docs/source-fabric/index.rst @@ -211,6 +211,7 @@ Get Started Loggers Precision Strategies + Utilities .. toctree:: diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py index 6d3cb93022..71a85e431b 100644 --- a/src/lightning/pytorch/callbacks/throughput_monitor.py +++ b/src/lightning/pytorch/callbacks/throughput_monitor.py @@ -50,7 +50,7 @@ class ThroughputMonitor(Callback): model = MyModel() def sample_forward(): - batch = torch.randn(...) + batch = torch.randn(..., device="meta") return model(batch) self.flops_per_batch = measure_flops(model, sample_forward, loss_fn=torch.Tensor.sum)