diff --git a/src/lightning/fabric/utilities/throughput.py b/src/lightning/fabric/utilities/throughput.py index c340686346..f483c274c3 100644 --- a/src/lightning/fabric/utilities/throughput.py +++ b/src/lightning/fabric/utilities/throughput.py @@ -296,7 +296,7 @@ def measure_flops( raise ImportError("`measure_flops` requires PyTorch >= 2.1.") from torch.utils.flop_counter import FlopCounterMode - flop_counter = FlopCounterMode(model, display=False) + flop_counter = FlopCounterMode(display=False) with flop_counter: if loss_fn is None: forward_fn()