lightning/pytorch_lightning/core/memory.py

359 lines
12 KiB
Python
Raw Normal View History

import os
2019-03-31 01:45:16 +00:00
import subprocess
from collections import OrderedDict
from subprocess import PIPE
from typing import Tuple, Dict, Union, List, Any
2019-03-31 01:45:16 +00:00
import numpy as np
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.utilities.apply_func import apply_to_collection
2019-03-31 01:45:16 +00:00
# Unit suffixes used by ``get_human_readable_count``, one per power of 1,000.
PARAMETER_NUM_UNITS = [
    " ",  # below one thousand: no suffix
    "K",  # thousands
    "M",  # millions
    "B",  # billions
    "T",  # trillions
]
# Placeholder shown in the summary table when an input/output size is not known.
UNKNOWN_SIZE = "?"
2019-03-31 01:45:16 +00:00
class LayerSummary(object):
    """
    Summary class for a single layer in a :class:`~pytorch_lightning.core.lightning.LightningModule`.
    It collects the following information:

    - Type of the layer (e.g. Linear, BatchNorm1d, ...)
    - Input shape
    - Output shape
    - Number of parameters

    The input and output shapes are only known after the example input array was
    passed through the model. The forward hook that records them removes itself after
    the first forward pass; call :meth:`detach_hook` to remove it explicitly when the
    layer may never be run.

    Example::

        >>> model = torch.nn.Conv2d(3, 8, 3)
        >>> summary = LayerSummary(model)
        >>> summary.num_parameters
        224
        >>> summary.layer_type
        'Conv2d'
        >>> output = model(torch.rand(1, 3, 5, 5))
        >>> summary.in_size
        [1, 3, 5, 5]
        >>> summary.out_size
        [1, 8, 3, 3]

    Args:
        module: A module to summarize

    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self._module = module
        self._in_size = None
        self._out_size = None
        self._hook_handle = self._register_hook()

    def _register_hook(self):
        """
        Registers a hook on the module that computes the input- and output size(s)
        on the first forward pass. The hook will remove itself from the module, meaning that
        recursive models will only record their input- and output shapes once.
        """

        def hook(module, inp, out):
            # a single positional input is recorded as its own shape, not as a one-element list
            if len(inp) == 1:
                inp = inp[0]
            self._in_size = parse_batch_shape(inp)
            self._out_size = parse_batch_shape(out)
            self._hook_handle.remove()  # hook detaches itself from module

        return self._module.register_forward_hook(hook)

    def detach_hook(self) -> None:
        """
        Removes the forward hook from the module if it is still attached.

        Safe to call more than once. Without it, a layer that never runs forward keeps the
        hook (whose closure references this summary) registered on the module.
        """
        if self._hook_handle is not None:
            self._hook_handle.remove()

    @property
    def in_size(self) -> Union[str, List]:
        """ The recorded input shape(s), or ``UNKNOWN_SIZE`` if nothing (or an empty shape) was recorded. """
        return self._in_size or UNKNOWN_SIZE

    @property
    def out_size(self) -> Union[str, List]:
        """ The recorded output shape(s), or ``UNKNOWN_SIZE`` if nothing (or an empty shape) was recorded. """
        return self._out_size or UNKNOWN_SIZE

    @property
    def layer_type(self) -> str:
        """ Returns the class name of the module. """
        return str(self._module.__class__.__name__)

    @property
    def num_parameters(self) -> int:
        """ Returns the number of parameters in this module (including its submodules). """
        # ``numel`` yields a plain int (and 1 for 0-dim parameters); ``np.prod(p.shape)`` yields a
        # NumPy scalar and the float ``1.0`` for an empty shape.
        return sum(p.numel() for p in self._module.parameters())


class ModelSummary(object):
    """
    Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.

    Args:
        model: The model to summarize (also referred to as the root module)
        mode: Can be one of

             - `top` (default): only the top-level modules will be recorded (the children of the root module)
             - `full`: summarizes all layers and their submodules in the root module

            Any other value results in an empty summary.

    The string representation of this summary prints a table with columns containing
    the name, type and number of parameters for each layer.

    The root module may also have an attribute ``example_input_array`` as shown in the example below.
    If present, the root module will be called with it as input to determine the
    intermediate input- and output shapes of all layers. Supported are tensors and
    nested lists and tuples of tensors. All other types of inputs will be skipped and show as `?`
    in the summary table. The summary will also display `?` for layers not used in the forward pass.
    All forward hooks used for this are removed again once the summary has been built.

    Example::

        >>> class LitModel(pl.LightningModule):
        ...
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.net = nn.Sequential(nn.Linear(256, 512), nn.BatchNorm1d(512))
        ...         self.example_input_array = torch.zeros(10, 256)  # optional
        ...
        ...     def forward(self, x):
        ...         return self.net(x)
        ...
        >>> model = LitModel()
        >>> ModelSummary(model, mode='top')  # doctest: +NORMALIZE_WHITESPACE
          | Name | Type       | Params | In sizes  | Out sizes
        ------------------------------------------------------------
        0 | net  | Sequential | 132 K  | [10, 256] | [10, 512]
        >>> ModelSummary(model, mode='full')  # doctest: +NORMALIZE_WHITESPACE
          | Name  | Type        | Params | In sizes  | Out sizes
        --------------------------------------------------------------
        0 | net   | Sequential  | 132 K  | [10, 256] | [10, 512]
        1 | net.0 | Linear      | 131 K  | [10, 256] | [10, 512]
        2 | net.1 | BatchNorm1d | 1 K    | [10, 512] | [10, 512]
    """

    MODE_TOP = "top"
    MODE_FULL = "full"
    MODE_DEFAULT = MODE_TOP
    MODES = [MODE_FULL, MODE_TOP]

    def __init__(self, model: "pl.LightningModule", mode: str = MODE_DEFAULT):
        self._model = model
        self._mode = mode
        self._layer_summary = self.summarize()

    @property
    def named_modules(self) -> List[Tuple[str, nn.Module]]:
        """ The ``(name, module)`` pairs selected by the summary mode. """
        if self._mode == ModelSummary.MODE_FULL:
            # do not include root module (LightningModule)
            return list(self._model.named_modules())[1:]
        if self._mode == ModelSummary.MODE_TOP:
            # the children are the top-level modules
            return list(self._model.named_children())
        return []

    @property
    def layer_names(self) -> List[str]:
        return list(self._layer_summary.keys())

    @property
    def layer_types(self) -> List[str]:
        return [layer.layer_type for layer in self._layer_summary.values()]

    @property
    def in_sizes(self) -> List:
        return [layer.in_size for layer in self._layer_summary.values()]

    @property
    def out_sizes(self) -> List:
        return [layer.out_size for layer in self._layer_summary.values()]

    @property
    def param_nums(self) -> List[int]:
        return [layer.num_parameters for layer in self._layer_summary.values()]

    def summarize(self) -> Dict[str, LayerSummary]:
        """
        Builds one :class:`LayerSummary` per selected module and, if the model defines an
        ``example_input_array``, runs it once so the layers record their input/output sizes.
        The forward hooks are always removed afterwards so the model is left unmodified.
        """
        summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
        try:
            if self._model.example_input_array is not None:
                self._forward_example_input()
        finally:
            # layers that were not run (or no example input at all) would otherwise keep their hooks
            for layer in summary.values():
                layer.detach_hook()
        return summary

    def _forward_example_input(self) -> None:
        """ Run the example input through each layer to get input- and output sizes. """
        model = self._model
        trainer = self._model.trainer

        input_ = model.example_input_array
        input_ = model.transfer_batch_to_device(input_, model.device)
        input_ = apply_to_collection(input_, torch.Tensor, lambda x: x.type(model.dtype))

        # Native AMP is applied with a context manager for the duration of this call only;
        # re-assigning ``model.forward`` to an autocast-wrapped function would leave the model
        # permanently patched (and wrapped once more each time a summary is created).
        use_autocast = trainer is not None and trainer.use_amp and model.use_native_amp

        mode = model.training
        model.eval()
        try:
            if use_autocast:
                with torch.cuda.amp.autocast():
                    self._run_model(input_)
            else:
                self._run_model(input_)
        finally:
            model.train(mode)  # restore mode of module, even if the forward pass raised

    def _run_model(self, input_: Any) -> None:
        """ Calls the model on ``input_`` without gradient tracking so the layer hooks record shapes. """
        with torch.no_grad():
            # lists/tuples are unpacked as positional arguments, dicts as keyword arguments
            if isinstance(input_, (list, tuple)):
                self._model(*input_)
            elif isinstance(input_, dict):
                self._model(**input_)
            else:
                self._model(input_)

    def __str__(self):
        """
        Makes a summary listing with:

        Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes

        The size columns are only included when the model has an ``example_input_array``.
        """
        arrays = [
            [" ", list(map(str, range(len(self._layer_summary))))],
            ["Name", self.layer_names],
            ["Type", self.layer_types],
            ["Params", list(map(get_human_readable_count, self.param_nums))],
        ]
        if self._model.example_input_array is not None:
            arrays.append(["In sizes", self.in_sizes])
            arrays.append(["Out sizes", self.out_sizes])

        return _format_summary_table(*arrays)

    def __repr__(self):
        return str(self)
2019-03-31 01:45:16 +00:00
def parse_batch_shape(batch: Any) -> Union[str, List]:
    """
    Describes the shape of ``batch`` for the summary table.

    Objects exposing a ``shape`` attribute (e.g. tensors) become the list of their
    dimensions; lists and tuples are described element by element, recursively;
    anything else becomes ``UNKNOWN_SIZE``.
    """
    if not hasattr(batch, "shape"):
        if not isinstance(batch, (list, tuple)):
            return UNKNOWN_SIZE
        return list(map(parse_batch_shape, batch))
    return list(batch.shape)
2019-03-31 01:45:16 +00:00
def _format_summary_table(*cols) -> str:
    """
    Renders column specifications as a plain-text table.

    Each positional argument is a ``(header, values)`` pair describing one column.
    Every column is left-aligned and padded to the width of its longest entry
    (header included), and cells are joined with ``" | "``. The output is a header
    line, a dashed divider, then one line per row; the number of rows is taken
    from the first column.
    """
    row_count = len(cols[0][1])
    cell_fmt = "{:<{}}"

    def column_width(header, values):
        # widest rendered value (only looked at when there are rows), never narrower than the header
        widest = max(len(str(value)) for value in values) if row_count else 0
        return max(widest, len(header))

    widths = [column_width(col[0], col[1]) for col in cols]

    def render(cells):
        return " | ".join(cell_fmt.format(cell, width) for cell, width in zip(cells, widths))

    # The divider is 3 * (n + 1) characters wider than the column widths for n columns,
    # i.e. 6 characters wider than a row; the ModelSummary doctests depend on this length.
    divider = "-" * (sum(widths) + 3 * (len(cols) + 1))

    lines = [render(col[0] for col in cols), divider]
    lines.extend(render(str(col[1][row]) for col in cols) for row in range(row_count))
    return "\n".join(lines)
def get_memory_profile(mode: str) -> Dict[str, int]:
    """ Get a profile of the current GPU memory usage, as reported by :func:`get_gpu_memory_map`.

    Args:
        mode: There are two modes:

            - 'all' means return memory for all gpus
            - 'min_max' means return memory for max and min

    Return:
        If ``mode`` is ``'min_max'``, a dictionary with exactly two keys (the per-GPU
        entries are not included):

        - 'min_gpu_mem': the minimum memory usage in MB
        - 'max_gpu_mem': the maximum memory usage in MB

        For any other ``mode``, the full map from :func:`get_gpu_memory_map`, whose keys
        are ``'gpu_<index>'`` strings and whose values are memory usage as integers in MB.
    """
    memory_map = get_gpu_memory_map()

    if mode == "min_max":
        memory_map = {
            "min_gpu_mem": min(memory_map.values()),
            "max_gpu_mem": max(memory_map.values()),
        }

    return memory_map
def get_gpu_memory_map() -> Dict[str, int]:
    """Get the current gpu usage by querying ``nvidia-smi``.

    Return:
        A dictionary in which the keys are ``'gpu_<index>'`` strings (``'gpu_0'``, ``'gpu_1'``, ...)
        and values are memory usage as integers in MB. Empty if ``nvidia-smi`` reports no devices.

    Raises:
        FileNotFoundError: if the ``nvidia-smi`` executable cannot be found.
        subprocess.CalledProcessError: if ``nvidia-smi`` exits with a non-zero status.
    """
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"],
        encoding="utf-8",
        # capture_output=True would be simpler but requires python >= 3.7;
        # explicit pipes keep backward compatibility with python 3.6
        stdout=PIPE,
        stderr=PIPE,
        check=True,
    )

    # Passing ``encoding`` opens the pipes in text mode with universal newlines, so every line
    # ending is already "\n"; splitting on ``os.linesep`` ("\r\n" on Windows) would not split at all.
    gpu_memory = [int(x) for x in result.stdout.strip().splitlines()]
    return {f"gpu_{index}": memory for index, memory in enumerate(gpu_memory)}
def get_human_readable_count(number: int) -> str:
    """
    Abbreviates an integer number with K, M, B, T for thousands, millions,
    billions and trillions, respectively. The value is truncated (not rounded)
    to a whole number of the chosen unit, and values of a thousand trillion or
    more are still expressed in trillions.

    Examples:
        >>> get_human_readable_count(123)
        '123 '
        >>> get_human_readable_count(1234)  # (one thousand)
        '1 K'
        >>> get_human_readable_count(2e6)   # (two million)
        '2 M'
        >>> get_human_readable_count(3e9)   # (three billion)
        '3 B'
        >>> get_human_readable_count(4e12)  # (four trillion)
        '4 T'
        >>> get_human_readable_count(5e15)  # (more than trillion)
        '5,000 T'

    Args:
        number: a non-negative number; integers are expected, floats are truncated
            to their integer part

    Return:
        A string formatted according to the pattern described above.
    """
    assert number >= 0
    labels = PARAMETER_NUM_UNITS
    # Digits are counted and scaled with exact integer arithmetic. A float ``log10`` based
    # count is non-positive for inputs below 1 and would index past the start of ``labels``.
    whole = int(number)
    num_digits = len(str(whole))
    num_groups = -(-num_digits // 3)  # ceil(num_digits / 3)
    num_groups = min(num_groups, len(labels))  # don't abbreviate beyond trillions
    index = num_groups - 1
    scaled = whole // 10 ** (3 * index)
    return f"{scaled:,d} {labels[index]}"