Fix mypy errors for model summary utilities (#13384)

This commit is contained in:
Adrian Wälchli 2022-06-23 16:21:55 +02:00 committed by GitHub
parent 511f1a6515
commit 7a3509decb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 2 deletions

View File

@ -15,7 +15,7 @@
import contextlib
import logging
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, cast, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@ -120,7 +120,9 @@ class LayerSummary:
@property
def num_parameters(self) -> int:
"""Returns the number of parameters in this module."""
return sum(np.prod(p.shape) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
return sum(
cast(int, np.prod(p.shape)) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters()
)
class ModelSummary: