Changed the model size calculation using `ByteCounter` (#10123)
This commit is contained in:
parent
7ad0ac5509
commit
b77aa718de
|
@ -172,6 +172,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
- Changed default value of the `max_steps` Trainer argument from `None` to -1 ([#9460](https://github.com/PyTorchLightning/pytorch-lightning/pull/9460))
|
- Changed default value of the `max_steps` Trainer argument from `None` to -1 ([#9460](https://github.com/PyTorchLightning/pytorch-lightning/pull/9460))
|
||||||
- LightningModule now raises an error when calling `log(on_step=False, on_epoch=False)` ([#10227](https://github.com/PyTorchLightning/pytorch-lightning/pull/10227))
|
- LightningModule now raises an error when calling `log(on_step=False, on_epoch=False)` ([#10227](https://github.com/PyTorchLightning/pytorch-lightning/pull/10227))
|
||||||
- Quantization aware training observers are now disabled by default during validating/testing/predicting stages ([#8540](https://github.com/PyTorchLightning/pytorch-lightning/pull/8540))
|
- Quantization aware training observers are now disabled by default during validating/testing/predicting stages ([#8540](https://github.com/PyTorchLightning/pytorch-lightning/pull/8540))
|
||||||
|
- Changed the model size calculation using `ByteCounter` ([#10123](https://github.com/PyTorchLightning/pytorch-lightning/pull/10123))
|
||||||
|
|
||||||
|
|
||||||
- Enabled `on_load_checkpoint` for `LightningDataModule` for all `trainer_fn` ([#10238](https://github.com/PyTorchLightning/pytorch-lightning/pull/10238))
|
- Enabled `on_load_checkpoint` for `LightningDataModule` for all `trainer_fn` ([#10238](https://github.com/PyTorchLightning/pytorch-lightning/pull/10238))
|
||||||
|
|
|
@ -1991,6 +1991,11 @@ class LightningModule(
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_size(self) -> float:
|
def model_size(self) -> float:
|
||||||
|
"""Returns the model size in MegaBytes (MB)
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This property will not return correct value for Deepspeed (stage 3) and fully-sharded training.
|
||||||
|
"""
|
||||||
rank_zero_deprecation(
|
rank_zero_deprecation(
|
||||||
"The `LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7."
|
"The `LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7."
|
||||||
" Please use the `pytorch_lightning.utilities.memory.get_model_size_mb`.",
|
" Please use the `pytorch_lightning.utilities.memory.get_model_size_mb`.",
|
||||||
|
|
|
@ -16,7 +16,6 @@ import gc
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import uuid
|
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -25,6 +24,20 @@ from torch.nn import Module
|
||||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||||
|
|
||||||
|
|
||||||
|
class ByteCounter:
    """Accumulate and stores the total bytes of an object.

    Acts as a minimal write-only, file-like sink (``write``/``flush``):
    data handed to it is discarded and only its cumulative length is
    tracked, so it can stand in for a real file when just the serialized
    size is needed (e.g. as the target of ``torch.save``).
    """

    def __init__(self) -> None:
        # Running total of bytes "written" so far.
        self.nbytes: int = 0

    def write(self, data: bytes) -> None:
        """Stores the total bytes of the data."""
        self.nbytes = self.nbytes + len(data)

    def flush(self) -> None:
        """No-op; present only to satisfy the file-like protocol."""
|
||||||
|
|
||||||
|
|
||||||
def recursive_detach(in_dict: Any, to_cpu: bool = False) -> Any:
|
def recursive_detach(in_dict: Any, to_cpu: bool = False) -> Any:
|
||||||
"""Detach all tensors in `in_dict`.
|
"""Detach all tensors in `in_dict`.
|
||||||
|
|
||||||
|
@ -163,17 +176,15 @@ def get_gpu_memory_map() -> Dict[str, float]:
|
||||||
|
|
||||||
|
|
||||||
def get_model_size_mb(model: Module) -> float:
    """Calculates the size of a Module in megabytes.

    The computation includes everything in the :meth:`~torch.nn.Module.state_dict`,
    i.e., by default the parameters and buffers.

    Returns:
        Number of megabytes in the parameters of the input module.
    """
    # Serialize into a byte-counting sink rather than a temporary file:
    # only the size is needed, never the serialized payload itself.
    counter = ByteCounter()
    torch.save(model.state_dict(), counter)
    return counter.nbytes / 1e6
|
||||||
|
|
Loading…
Reference in New Issue