Fix mypy errors for model summary utilities (#13384)
This commit is contained in:
parent
511f1a6515
commit
7a3509decb
|
@ -15,7 +15,7 @@
|
|||
import contextlib
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, cast, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -120,7 +120,9 @@ class LayerSummary:
|
|||
@property
|
||||
def num_parameters(self) -> int:
|
||||
"""Returns the number of parameters in this module."""
|
||||
return sum(np.prod(p.shape) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
|
||||
return sum(
|
||||
cast(int, np.prod(p.shape)) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters()
|
||||
)
|
||||
|
||||
|
||||
class ModelSummary:
|
||||
|
|
Loading…
Reference in New Issue