2019-11-05 13:55:44 +00:00
|
|
|
import os
|
2019-03-31 01:45:16 +00:00
|
|
|
import subprocess
|
2020-06-15 21:05:58 +00:00
|
|
|
from collections import OrderedDict
|
2020-01-20 19:50:57 +00:00
|
|
|
from subprocess import PIPE
|
2020-06-15 21:05:58 +00:00
|
|
|
from typing import Tuple, Dict, Union, List, Any
|
2019-10-22 08:32:40 +00:00
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
import numpy as np
|
2019-10-22 08:32:40 +00:00
|
|
|
import torch
|
2020-06-15 21:05:58 +00:00
|
|
|
import torch.nn as nn
|
2020-06-20 11:38:47 +00:00
|
|
|
from torch.utils.hooks import RemovableHandle
|
2020-03-12 16:47:23 +00:00
|
|
|
|
2020-08-08 09:07:32 +00:00
|
|
|
from pytorch_lightning.utilities import AMPType
|
2019-03-31 01:45:16 +00:00
|
|
|
|
2020-06-15 21:05:58 +00:00
|
|
|
# Unit suffixes for abbreviated counts, one per group of three digits:
# no suffix (a blank), thousands, millions, billions, trillions.
PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"]
|
|
|
|
# Placeholder reported in place of a shape that could not be determined.
UNKNOWN_SIZE = "?"
|
2020-03-17 22:44:00 +00:00
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
|
2020-06-15 21:05:58 +00:00
|
|
|
class LayerSummary(object):
    """
    Summary class for a single layer in a :class:`~pytorch_lightning.core.lightning.LightningModule`.
    It collects the following information:

    - Type of the layer (e.g. Linear, BatchNorm1d, ...)
    - Input shape
    - Output shape
    - Number of parameters

    The input and output shapes are only known after the example input array was
    passed through the model.

    Example::

        >>> model = torch.nn.Conv2d(3, 8, 3)
        >>> summary = LayerSummary(model)
        >>> summary.num_parameters
        224
        >>> summary.layer_type
        'Conv2d'
        >>> output = model(torch.rand(1, 3, 5, 5))
        >>> summary.in_size
        [1, 3, 5, 5]
        >>> summary.out_size
        [1, 8, 3, 3]

    Args:
        module: A module to summarize
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self._module = module
        # Give every attribute a value *before* registering the hook, so that
        # ``__del__`` -> ``detach_hook`` works even if registration fails.
        self._hook_handle = None
        self._in_size = None
        self._out_size = None
        self._hook_handle = self._register_hook()

    def __del__(self):
        self.detach_hook()

    def _register_hook(self) -> RemovableHandle:
        """
        Registers a hook on the module that computes the input- and output size(s) on the first forward pass.
        If the hook is called, it will remove itself from the module, meaning that
        recursive models will only record their input- and output shapes once.

        Return:
            A handle for the installed hook.
        """

        def hook(module, inp, out):
            # forward hooks receive the positional inputs as a tuple;
            # unwrap it when there is only a single input
            if len(inp) == 1:
                inp = inp[0]
            self._in_size = parse_batch_shape(inp)
            self._out_size = parse_batch_shape(out)
            # record shapes only once
            self._hook_handle.remove()

        return self._module.register_forward_hook(hook)

    def detach_hook(self):
        """
        Removes the forward hook if it was not already removed in the forward pass.
        Will be called after the summary is created.
        """
        if self._hook_handle is not None:
            self._hook_handle.remove()

    @property
    def in_size(self) -> Union[str, List]:
        """ Returns the recorded input shape(s), or ``UNKNOWN_SIZE`` if nothing (or an empty shape) was recorded. """
        return self._in_size or UNKNOWN_SIZE

    @property
    def out_size(self) -> Union[str, List]:
        """ Returns the recorded output shape(s), or ``UNKNOWN_SIZE`` if nothing (or an empty shape) was recorded. """
        return self._out_size or UNKNOWN_SIZE

    @property
    def layer_type(self) -> str:
        """ Returns the class name of the module. """
        return str(self._module.__class__.__name__)

    @property
    def num_parameters(self) -> int:
        """ Returns the number of parameters in this module. """
        # ``numel()`` is a plain Python int for every parameter (including
        # 0-dim ones), so the total really is an ``int`` as annotated.
        return sum(p.numel() for p in self._module.parameters())
|
|
|
|
|
|
|
|
|
|
|
|
class ModelSummary(object):
    """
    Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.

    Args:
        model: The model to summarize (also referred to as the root module)
        mode: Can be one of

             - `top` (default): only the top-level modules will be recorded (the children of the root module)
             - `full`: summarizes all layers and their submodules in the root module

    The string representation of this summary prints a table with columns containing
    the name, type and number of parameters for each layer.

    The root module may also have an attribute ``example_input_array`` as shown in the example below.
    If present, the root module will be called with it as input to determine the
    intermediate input- and output shapes of all layers. Supported are tensors and
    nested lists and tuples of tensors. All other types of inputs will be skipped and show as `?`
    in the summary table. The summary will also display `?` for layers not used in the forward pass.

    Example::

        >>> import pytorch_lightning as pl
        >>> class LitModel(pl.LightningModule):
        ...
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.net = nn.Sequential(nn.Linear(256, 512), nn.BatchNorm1d(512))
        ...         self.example_input_array = torch.zeros(10, 256)  # optional
        ...
        ...     def forward(self, x):
        ...         return self.net(x)
        ...
        >>> model = LitModel()
        >>> ModelSummary(model, mode='top')  # doctest: +NORMALIZE_WHITESPACE
          | Name | Type       | Params | In sizes  | Out sizes
        ------------------------------------------------------------
        0 | net  | Sequential | 132 K  | [10, 256] | [10, 512]
        >>> ModelSummary(model, mode='full')  # doctest: +NORMALIZE_WHITESPACE
          | Name  | Type        | Params | In sizes  | Out sizes
        --------------------------------------------------------------
        0 | net   | Sequential  | 132 K  | [10, 256] | [10, 512]
        1 | net.0 | Linear      | 131 K  | [10, 256] | [10, 512]
        2 | net.1 | BatchNorm1d | 1 K    | [10, 512] | [10, 512]
    """

    MODE_TOP = "top"
    MODE_FULL = "full"
    MODE_DEFAULT = MODE_TOP
    MODES = [MODE_FULL, MODE_TOP]

    def __init__(self, model, mode: str = MODE_DEFAULT):
        self._model = model
        self._mode = mode
        self._layer_summary = self.summarize()

    @property
    def named_modules(self) -> List[Tuple[str, nn.Module]]:
        """ Returns the ``(name, module)`` pairs to summarize for the selected mode (empty for an unknown mode). """
        if self._mode == ModelSummary.MODE_FULL:
            mods = self._model.named_modules()
            mods = list(mods)[1:]  # do not include root module (LightningModule)
        elif self._mode == ModelSummary.MODE_TOP:
            # the children are the top-level modules
            mods = self._model.named_children()
        else:
            mods = []
        return list(mods)

    @property
    def layer_names(self) -> List[str]:
        """ Returns the names of the summarized layers. """
        return list(self._layer_summary.keys())

    @property
    def layer_types(self) -> List[str]:
        """ Returns the class name of each summarized layer. """
        return [layer.layer_type for layer in self._layer_summary.values()]

    @property
    def in_sizes(self) -> List:
        """ Returns the recorded input size of each summarized layer. """
        return [layer.in_size for layer in self._layer_summary.values()]

    @property
    def out_sizes(self) -> List:
        """ Returns the recorded output size of each summarized layer. """
        return [layer.out_size for layer in self._layer_summary.values()]

    @property
    def param_nums(self) -> List[int]:
        """ Returns the number of parameters of each summarized layer. """
        return [layer.num_parameters for layer in self._layer_summary.values()]

    def _example_input(self) -> Any:
        """ Returns the root module's ``example_input_array``, or ``None`` if it has none. """
        return getattr(self._model, "example_input_array", None)

    def summarize(self) -> Dict[str, "LayerSummary"]:
        """
        Builds one :class:`LayerSummary` per module and, if the root module provides an
        example input, runs it through the model so the layers can record their shapes.

        Return:
            An ordered mapping from layer name to its summary.
        """
        summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
        try:
            if self._example_input() is not None:
                self._forward_example_input()
        finally:
            # never leave summary hooks on the user's model, even if the forward pass failed
            for layer in summary.values():
                layer.detach_hook()
        return summary

    @staticmethod
    def _call_model(model, input_) -> None:
        """ Calls the model, unpacking sequences as positional and dicts as keyword arguments. """
        if isinstance(input_, (list, tuple)):
            model(*input_)
        elif isinstance(input_, dict):
            model(**input_)
        else:
            model(input_)

    def _forward_example_input(self) -> None:
        """ Run the example input through each layer to get input- and output sizes. """
        model = self._model
        trainer = self._model.trainer

        input_ = model.example_input_array
        input_ = model.transfer_batch_to_device(input_, model.device)

        use_native_amp = trainer is not None and trainer.amp_type == AMPType.NATIVE and not trainer.use_tpu

        mode = model.training
        model.eval()
        try:
            with torch.no_grad():
                # let the model hooks collect the input- and output shapes
                if use_native_amp:
                    # autocast only for the duration of this call instead of
                    # permanently replacing ``model.forward`` with a wrapper
                    with torch.cuda.amp.autocast():
                        self._call_model(model, input_)
                else:
                    self._call_model(model, input_)
        finally:
            model.train(mode)  # restore mode of module, also when the forward pass raised

    def __str__(self):
        """
        Makes a summary listing with:

        Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes
        """
        arrays = [
            [" ", list(map(str, range(len(self._layer_summary))))],
            ["Name", self.layer_names],
            ["Type", self.layer_types],
            ["Params", list(map(get_human_readable_count, self.param_nums))],
        ]
        # the size columns are only shown when shapes could have been recorded
        if self._example_input() is not None:
            arrays.append(["In sizes", self.in_sizes])
            arrays.append(["Out sizes", self.out_sizes])

        return _format_summary_table(*arrays)

    def __repr__(self):
        return str(self)
|
2019-03-31 01:45:16 +00:00
|
|
|
|
|
|
|
|
2020-06-15 21:05:58 +00:00
|
|
|
def parse_batch_shape(batch: Any) -> Union[str, List]:
    """
    Extracts the shape(s) of a batch.

    Args:
        batch: an object with a ``shape`` attribute, or a (nested) list/tuple of such objects

    Return:
        The shape as a list, a list of shapes mirroring the nesting of a list/tuple input,
        or ``UNKNOWN_SIZE`` for anything else.
    """
    if not hasattr(batch, "shape"):
        if not isinstance(batch, (list, tuple)):
            return UNKNOWN_SIZE
        # recurse into the container, keeping its nesting structure
        return list(map(parse_batch_shape, batch))

    return list(batch.shape)
|
2019-03-31 01:45:16 +00:00
|
|
|
|
|
|
|
|
2020-03-12 16:47:23 +00:00
|
|
|
def _format_summary_table(*cols) -> str:
|
|
|
|
"""
|
2020-01-29 19:52:23 +00:00
|
|
|
Takes in a number of arrays, each specifying a column in
|
|
|
|
the summary table, and combines them all into one big
|
|
|
|
string defining the summary table that are nicely formatted.
|
2020-03-12 16:47:23 +00:00
|
|
|
"""
|
2020-01-29 19:52:23 +00:00
|
|
|
n_rows = len(cols[0][1])
|
|
|
|
n_cols = 1 + len(cols)
|
|
|
|
|
2020-06-15 21:05:58 +00:00
|
|
|
# Get formatting width of each column
|
|
|
|
col_widths = []
|
2020-01-29 19:52:23 +00:00
|
|
|
for c in cols:
|
2020-06-15 21:05:58 +00:00
|
|
|
col_width = max(len(str(a)) for a in c[1]) if n_rows else 0
|
|
|
|
col_width = max(col_width, len(c[0])) # minimum length is header length
|
|
|
|
col_widths.append(col_width)
|
2020-01-29 19:52:23 +00:00
|
|
|
|
|
|
|
# Formatting
|
2020-06-15 21:05:58 +00:00
|
|
|
s = "{:<{}}"
|
|
|
|
total_width = sum(col_widths) + 3 * n_cols
|
|
|
|
header = [s.format(c[0], l) for c, l in zip(cols, col_widths)]
|
2020-01-29 19:52:23 +00:00
|
|
|
|
|
|
|
# Summary = header + divider + Rest of table
|
2020-06-15 21:05:58 +00:00
|
|
|
summary = " | ".join(header) + "\n" + "-" * total_width
|
2020-01-29 19:52:23 +00:00
|
|
|
for i in range(n_rows):
|
2020-06-15 21:05:58 +00:00
|
|
|
line = []
|
|
|
|
for c, l in zip(cols, col_widths):
|
|
|
|
line.append(s.format(str(c[1][i]), l))
|
|
|
|
summary += "\n" + " | ".join(line)
|
2020-01-29 19:52:23 +00:00
|
|
|
|
|
|
|
return summary
|
|
|
|
|
|
|
|
|
2020-03-12 16:47:23 +00:00
|
|
|
def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]:
    """ Get a profile of the current memory usage.

    Args:
        mode: There are two modes:

            - 'all' means return memory for all gpus
            - 'min_max' means return memory for max and min

    Return:
        A dictionary mapping names to memory usage as integers in MB.

        If mode is 'min_max', the dictionary contains exactly two keys:

        - 'min_gpu_mem': the minimum memory usage in MB
        - 'max_gpu_mem': the maximum memory usage in MB

        For any other mode, the per-gpu map from :func:`get_gpu_memory_map`
        (keys of the form ``gpu_<index>``) is returned unchanged.
    """
    memory_map = get_gpu_memory_map()

    if mode == "min_max":
        # only the extreme values are reported; the per-gpu entries are dropped
        usage = memory_map.values()
        memory_map = {"min_gpu_mem": min(usage), "max_gpu_mem": max(usage)}

    return memory_map
|
|
|
|
|
|
|
|
|
2020-03-12 16:47:23 +00:00
|
|
|
def get_gpu_memory_map() -> Dict[str, int]:
    """Get the current gpu usage.

    Return:
        A dictionary in which the keys are strings of the form ``gpu_<index>``
        (in the order reported by ``nvidia-smi``) and values are memory usage
        as integers in MB.

    Raises:
        FileNotFoundError: if the ``nvidia-smi`` executable cannot be found.
        subprocess.CalledProcessError: if ``nvidia-smi`` exits with a non-zero status.
    """
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"],
        encoding="utf-8",
        # capture_output=True,  # valid for python version >=3.7
        stdout=PIPE,
        stderr=PIPE,  # for backward compatibility with python version 3.6
        check=True,
    )
    # The output is captured in text mode, where line endings are already
    # normalised to "\n" on every platform, so split on lines rather than on
    # ``os.linesep``; blank lines (e.g. empty output) are skipped.
    gpu_memory = [int(line) for line in result.stdout.splitlines() if line.strip()]
    # Convert lines into a dictionary
    gpu_memory_map = {f"gpu_{index}": memory for index, memory in enumerate(gpu_memory)}
    return gpu_memory_map
|
2019-10-08 19:30:06 +00:00
|
|
|
|
|
|
|
|
2020-03-12 16:47:23 +00:00
|
|
|
def get_human_readable_count(number: int) -> str:
    """
    Abbreviates an integer number with K, M, B, T for thousands, millions,
    billions and trillions, respectively.

    Examples:
        >>> get_human_readable_count(123)
        '123  '
        >>> get_human_readable_count(1234)  # (one thousand)
        '1 K'
        >>> get_human_readable_count(2e6)   # (two million)
        '2 M'
        >>> get_human_readable_count(3e9)   # (three billion)
        '3 B'
        >>> get_human_readable_count(4e12)  # (four trillion)
        '4 T'
        >>> get_human_readable_count(5e15)  # (more than trillion)
        '5,000 T'

    Args:
        number: a positive integer number

    Return:
        A string formatted according to the pattern described above.

    """
    assert number >= 0
    labels = PARAMETER_NUM_UNITS
    # Work on an exact integer: this avoids floating point rounding, keeps
    # values below one at "0" instead of wrapping around to the last label,
    # and has no upper limit on the size of the number.
    number = int(number)
    num_digits = len(str(number))
    num_groups = -(-num_digits // 3)  # ceil(num_digits / 3)
    num_groups = min(num_groups, len(labels))  # don't abbreviate beyond trillions
    index = num_groups - 1
    number //= 10 ** (3 * index)
    return f"{number:,d} {labels[index]}"
|