Changed the model size calculation using `ByteCounter` (#10123)

Rohit Gupta 2021-11-01 22:12:14 +05:30 committed by GitHub
parent 7ad0ac5509
commit b77aa718de
3 changed files with 25 additions and 8 deletions

CHANGELOG.md

@@ -172,6 +172,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed default value of the `max_steps` Trainer argument from `None` to -1 ([#9460](https://github.com/PyTorchLightning/pytorch-lightning/pull/9460))
 - LightningModule now raises an error when calling `log(on_step=False, on_epoch=False)` ([#10227](https://github.com/PyTorchLightning/pytorch-lightning/pull/10227))
 - Quantization aware training observers are now disabled by default during validating/testing/predicting stages ([#8540](https://github.com/PyTorchLightning/pytorch-lightning/pull/8540))
+- Changed the model size calculation using `ByteCounter` ([#10123](https://github.com/PyTorchLightning/pytorch-lightning/pull/10123))
 - Enabled `on_load_checkpoint` for `LightningDataModule` for all `trainer_fn` ([#10238](https://github.com/PyTorchLightning/pytorch-lightning/pull/10238))

pytorch_lightning/core/lightning.py

@@ -1991,6 +1991,11 @@ class LightningModule(
     @property
     def model_size(self) -> float:
+        """Returns the model size in megabytes (MB).
+
+        Note:
+            This property will not return the correct value for DeepSpeed (stage 3) and fully-sharded training.
+        """
         rank_zero_deprecation(
             "The `LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7."
             " Please use the `pytorch_lightning.utilities.memory.get_model_size_mb`.",

pytorch_lightning/utilities/memory.py

@@ -16,7 +16,6 @@ import gc
 import os
 import shutil
 import subprocess
-import uuid
 from typing import Any, Dict

 import torch
@@ -25,6 +24,20 @@ from torch.nn import Module
 from pytorch_lightning.utilities.apply_func import apply_to_collection


+class ByteCounter:
+    """Accumulates and stores the total bytes of an object."""
+
+    def __init__(self) -> None:
+        self.nbytes: int = 0
+
+    def write(self, data: bytes) -> None:
+        """Adds the size of ``data`` to the running byte total."""
+        self.nbytes += len(data)
+
+    def flush(self) -> None:
+        pass
+
+
 def recursive_detach(in_dict: Any, to_cpu: bool = False) -> Any:
     """Detach all tensors in `in_dict`.
@@ -163,17 +176,15 @@ def get_gpu_memory_map() -> Dict[str, float]:
 def get_model_size_mb(model: Module) -> float:
-    """Calculates the size of a Module in megabytes by saving the model to a temporary file and reading its size.
+    """Calculates the size of a Module in megabytes.

     The computation includes everything in the :meth:`~torch.nn.Module.state_dict`,
-    i.e., by default the parameteters and buffers.
+    i.e., by default the parameters and buffers.

     Returns:
         Number of megabytes in the parameters of the input module.
     """
-    # TODO: Implement a method without needing to download the model
-    tmp_name = f"{uuid.uuid4().hex}.pt"
-    torch.save(model.state_dict(), tmp_name)
-    size_mb = os.path.getsize(tmp_name) / 1e6
-    os.remove(tmp_name)
+    model_size = ByteCounter()
+    torch.save(model.state_dict(), model_size)
+    size_mb = model_size.nbytes / 1e6
     return size_mb
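
For comparison (not part of the diff), a sketch of the replaced temporary-file approach next to the new call. File handling here is illustrative and uses `tempfile` rather than the removed `uuid`-based name; the model is illustrative as well:

```python
import os
import tempfile

import torch
from torch import nn

from pytorch_lightning.utilities.memory import get_model_size_mb

model = nn.Linear(128, 64)

# Old behaviour (before this commit): serialize to disk and stat the file.
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as tmp:
    torch.save(model.state_dict(), tmp)
    tmp_path = tmp.name
old_size_mb = os.path.getsize(tmp_path) / 1e6
os.remove(tmp_path)

# New behaviour: count the serialized bytes in memory via ByteCounter.
new_size_mb = get_model_size_mb(model)
print(f"old: {old_size_mb:.6f} MB, new: {new_size_mb:.6f} MB")  # expected to match
```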