Add typing for utilities/memory.py (#11545)

Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
This commit is contained in:
DuYicong515 2022-02-02 18:34:05 -08:00 committed by GitHub
parent 72f0e5bfae
commit 0816a1997e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 17 deletions

View File

@ -95,7 +95,6 @@ module = [
"pytorch_lightning.utilities.auto_restart",
"pytorch_lightning.utilities.data",
"pytorch_lightning.utilities.distributed",
"pytorch_lightning.utilities.memory",
"pytorch_lightning.utilities.meta",
]
ignore_errors = "True"

View File

@ -17,6 +17,7 @@ import gc
import os
import shutil
import subprocess
from io import BytesIO
from typing import Any, Dict
import torch
@ -25,20 +26,6 @@ from torch.nn import Module
from pytorch_lightning.utilities.apply_func import apply_to_collection
class _ByteCounter:
"""Accumulate and stores the total bytes of an object."""
def __init__(self) -> None:
self.nbytes: int = 0
def write(self, data: bytes) -> None:
"""Stores the total bytes of the data."""
self.nbytes += len(data)
def flush(self) -> None:
pass
def recursive_detach(in_dict: Any, to_cpu: bool = False) -> Any:
"""Detach all tensors in `in_dict`.
@ -183,7 +170,7 @@ def get_model_size_mb(model: Module) -> float:
Returns:
Number of megabytes in the parameters of the input module.
"""
model_size = _ByteCounter()
model_size = BytesIO()
torch.save(model.state_dict(), model_size)
size_mb = model_size.nbytes / 1e6
size_mb = model_size.getbuffer().nbytes / 1e6
return size_mb