diff --git a/src/pytorch_lightning/utilities/model_summary.py b/src/pytorch_lightning/utilities/model_summary.py index 2f246c7a66..e19cff40a2 100644 --- a/src/pytorch_lightning/utilities/model_summary.py +++ b/src/pytorch_lightning/utilities/model_summary.py @@ -15,7 +15,7 @@ import contextlib import logging from collections import OrderedDict -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, cast, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -120,7 +120,9 @@ class LayerSummary: @property def num_parameters(self) -> int: """Returns the number of parameters in this module.""" - return sum(np.prod(p.shape) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters()) + return sum( + cast(int, np.prod(p.shape)) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters() + ) class ModelSummary: