diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py index e810c70b34..b80a2b94bd 100644 --- a/pytorch_lightning/utilities/memory.py +++ b/pytorch_lightning/utilities/memory.py @@ -24,7 +24,7 @@ from torch.nn import Module from pytorch_lightning.utilities.apply_func import apply_to_collection -class ByteCounter: +class _ByteCounter: """Accumulate and stores the total bytes of an object.""" def __init__(self) -> None: @@ -184,7 +184,7 @@ def get_model_size_mb(model: Module) -> float: Returns: Number of megabytes in the parameters of the input module. """ - model_size = ByteCounter() + model_size = _ByteCounter() torch.save(model.state_dict(), model_size) size_mb = model_size.nbytes / 1e6 return size_mb