lightning/pytorch_lightning/core/memory.py

302 lines
9.6 KiB
Python
Raw Normal View History

"""
Generates a summary of a model's layers and dimensionality
"""
2019-03-31 01:45:16 +00:00
import gc
import os
2019-03-31 01:45:16 +00:00
import subprocess
from subprocess import PIPE
from typing import Tuple, Dict, Union, List
2019-03-31 01:45:16 +00:00
import numpy as np
import torch
from torch.nn import Module
import pytorch_lightning as pl
2019-03-31 01:45:16 +00:00
from pytorch_lightning import _logger as log
2019-03-31 01:45:16 +00:00
class ModelSummary(object):
    """Generates summaries of model layers and dimensions.

    The summary is computed eagerly in ``__init__`` and stored as a formatted
    string in ``self.summary``; ``str()`` and ``repr()`` both return it.

    :param model: the module to summarize. It is expected to expose
        ``example_input_array`` and, when that is not ``None``, ``on_gpu``.
    :param mode: ``'full'`` lists every sub-module (excluding the root),
        ``'top'`` lists only the direct children; any other value yields an
        empty listing.
    """

    def __init__(self, model: 'pl.LightningModule', mode: str = 'full'):
        """ Store the model and mode, then build the summary right away. """
        self.model = model
        self.mode = mode
        self.in_sizes = []
        self.out_sizes = []
        self.summarize()

    def __str__(self):
        return self.summary.__str__()

    def __repr__(self):
        return self.summary.__str__()

    @staticmethod
    def _map_tensors(input_, fn):
        """ Apply ``fn`` to ``input_``, or to every tensor inside it when it is a list/tuple.

        Non-tensor entries of a list/tuple are passed through untouched; a
        list/tuple input always comes back as a list.
        """
        if isinstance(input_, (list, tuple)):
            return [fn(item) if torch.is_tensor(item) else item for item in input_]
        return fn(input_)

    def named_modules(self) -> List[Tuple[str, Module]]:
        """ Return the ``(name, module)`` pairs selected by ``self.mode``. """
        if self.mode == 'full':
            # do not include root module (LightningModule)
            mods = list(self.model.named_modules())[1:]
        elif self.mode == 'top':
            # the children are the top-level modules
            mods = self.model.named_children()
        else:
            mods = []
        return list(mods)

    def get_variable_sizes(self) -> None:
        """ Run sample input through each layer to get output sizes.

        The listed modules are called one after another, the output of one
        being fed as the input of the next, starting from
        ``model.example_input_array``. Results are stored in
        ``self.in_sizes`` and ``self.out_sizes``.
        """
        mods = self.named_modules()
        in_sizes = []
        out_sizes = []
        input_ = self.model.example_input_array

        if self.model.on_gpu:
            device = next(self.model.parameters()).get_device()
            input_ = self._map_tensors(input_, lambda t: t.cuda(device))

        # The model may be summarized before any trainer is attached, so
        # only consult ``use_amp`` when a trainer actually exists.
        trainer = getattr(self.model, 'trainer', None)
        if trainer is not None and trainer.use_amp:
            input_ = self._map_tensors(input_, lambda t: t.half())

        with torch.no_grad():
            for _, m in mods:
                if isinstance(input_, (list, tuple)):  # pragma: no-cover
                    out = m(*input_)
                    # nested lists are recorded by length, tensors by size
                    in_size = [len(x) if isinstance(x, list) else x.size() for x in input_]
                else:
                    out = m(input_)
                    in_size = np.array(input_.size())
                in_sizes.append(in_size)

                if isinstance(out, (list, tuple)):  # pragma: no-cover
                    out_size = np.asarray([x.size() for x in out])
                else:
                    out_size = np.array(out.size())
                out_sizes.append(out_size)

                # chain: this layer's output is the next layer's input
                input_ = out

        self.in_sizes = in_sizes
        self.out_sizes = out_sizes
        assert len(in_sizes) == len(out_sizes)

    def get_layer_names(self) -> None:
        """ Collect layer names and their class names into ``self.layer_names`` / ``self.layer_types``. """
        mods = self.named_modules()
        self.layer_names = [name for name, _ in mods]
        # Use the class name directly instead of slicing ``str(cls)``
        self.layer_types = [type(m).__name__ for _, m in mods]

    def get_parameter_sizes(self) -> None:
        """ Get sizes of all parameters in `model`, one list of shape arrays per listed module. """
        mods = self.named_modules()
        self.param_sizes = [
            [np.array(param.size()) for param in m.parameters()]
            for _, m in mods
        ]

    def get_parameter_nums(self) -> None:
        """ Get number of parameters in each layer (requires ``self.param_sizes``). """
        self.param_nums = [
            sum(np.prod(shape) for shape in module_shapes)
            for module_shapes in self.param_sizes
        ]

    def make_summary(self) -> None:
        """
        Makes a summary listing with:
        Layer Name, Layer Type, Number of Parameters and, when the model
        provides an example input, Input Size and Output Size.
        """
        arrays = [['Name', self.layer_names],
                  ['Type', self.layer_types],
                  ['Params', list(map(get_human_readable_count, self.param_nums))]]
        if self.model.example_input_array is not None:
            arrays.append(['In sizes', self.in_sizes])
            arrays.append(['Out sizes', self.out_sizes])

        self.summary = _format_summary_table(*arrays)

    def summarize(self) -> None:
        """ Collect all per-layer information and render ``self.summary``. """
        self.get_layer_names()
        self.get_parameter_sizes()
        self.get_parameter_nums()

        # in/out sizes can only be traced when an example input is available
        if self.model.example_input_array is not None:
            self.get_variable_sizes()
        self.make_summary()
def _format_summary_table(*cols) -> str:
"""
Takes in a number of arrays, each specifying a column in
the summary table, and combines them all into one big
string defining the summary table that are nicely formatted.
"""
n_rows = len(cols[0][1])
n_cols = 1 + len(cols)
# Layer counter
counter = list(map(str, list(range(n_rows))))
counter_len = max([len(c) for c in counter])
# Get formatting length of each column
length = []
for c in cols:
str_l = len(c[0]) # default length is header length
for a in c[1]:
if isinstance(a, np.ndarray):
array_string = '[' + ', '.join([str(j) for j in a]) + ']'
str_l = max(len(array_string), str_l)
else:
str_l = max(len(a), str_l)
length.append(str_l)
# Formatting
s = '{:<{}}'
full_length = sum(length) + 3 * n_cols
header = [s.format(' ', counter_len)] + [s.format(c[0], l) for c, l in zip(cols, length)]
# Summary = header + divider + Rest of table
summary = ' | '.join(header) + '\n' + '-' * full_length
for i in range(n_rows):
line = s.format(counter[i], counter_len)
for c, l in zip(cols, length):
if isinstance(c[1][i], np.ndarray):
array_string = '[' + ', '.join([str(j) for j in c[1][i]]) + ']'
line += ' | ' + array_string + ' ' * (l - len(array_string))
else:
line += ' | ' + s.format(c[1][i], l)
summary += '\n' + line
return summary
def print_mem_stack() -> None:  # pragma: no-cover
    """ Log the type and size of every tensor-like object the garbage collector tracks.

    An object counts as tensor-like when it is a tensor itself or when it has
    a ``data`` attribute that is a tensor.
    """
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                # The first argument of a logging call is the format string;
                # pass the values as lazy %-arguments so they are rendered.
                log.info('%s %s', type(obj), obj.size())
        except Exception:
            # Deliberately best-effort: probing arbitrary gc-tracked objects
            # can raise, and one bad object must not abort the dump.
            pass
def count_mem_items() -> Tuple[int, int]:  # pragma: no-cover
    """ Count the tensor-like objects currently tracked by the garbage collector.

    An object is considered when it is a tensor or has a tensor ``data``
    attribute. It is tallied as a parameter if the string form of its type
    contains ``'parameter'``, otherwise as a plain tensor.

    :return: ``(num_params, num_tensors)``
    """
    tally = {True: 0, False: 0}  # keyed by "looks like a parameter"

    for candidate in gc.get_objects():
        try:
            tensor_like = torch.is_tensor(candidate) or (
                hasattr(candidate, 'data') and torch.is_tensor(candidate.data)
            )
            if not tensor_like:
                continue
            type_name = str(type(candidate))
            tally['parameter' in type_name] += 1
        except Exception:
            # probing arbitrary objects may fail; skip those silently
            pass

    return tally[True], tally[False]
2019-03-31 01:45:16 +00:00
def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]:
    """ Get a profile of the current memory usage.

    :param mode: There are two modes:
        - 'all' means return memory for all gpus
        - 'min_max' means return memory for max and min
    :return: for 'min_max', a dict with the keys ``'min_gpu_mem'`` and
        ``'max_gpu_mem'``; for any other mode, the full map returned by
        ``get_gpu_memory_map()``.
    """
    memory_map = get_gpu_memory_map()

    if mode != 'min_max':
        return memory_map

    # only the extreme values are reported, not which device they belong to
    usages = memory_map.values()
    return {'min_gpu_mem': min(usages), 'max_gpu_mem': max(usages)}
def get_gpu_memory_map() -> Dict[str, int]:
    """Get the current gpu usage.

    Runs ``nvidia-smi --query-gpu=memory.used`` and parses one value per
    output line.

    Return:
        A dictionary in which the keys are strings of the form
        ``'gpu_<index>'`` (index = position of the line in the
        ``nvidia-smi`` output) and values are memory usage as integers,
        in the unit reported by ``nvidia-smi`` (presumably MB - confirm).

    Raises:
        subprocess.CalledProcessError: if ``nvidia-smi`` exits with a
            non-zero status (``check=True``).
    """
    result = subprocess.run(
        [
            'nvidia-smi',
            '--query-gpu=memory.used',
            '--format=csv,nounits,noheader',
        ],
        encoding='utf-8',
        # capture_output=True,  # valid for python version >=3.7
        stdout=PIPE, stderr=PIPE,  # for backward compatibility with python version 3.6
        check=True)

    # Convert lines into a dictionary. Because ``encoding`` is given, the
    # output is read in text mode with newlines normalised to '\n', so split
    # with ``splitlines`` rather than on the platform's ``os.linesep``.
    gpu_memory = [int(line) for line in result.stdout.splitlines() if line.strip()]
    gpu_memory_map = {f'gpu_{index}': memory for index, memory in enumerate(gpu_memory)}
    return gpu_memory_map
def get_human_readable_count(number: int) -> str:
    """
    Abbreviates an integer number with K, M, B, T for thousands, millions,
    billions and trillions, respectively.

    Examples:
        123     -> 123
        1234    -> 1 K       (one thousand)
        2e6     -> 2 M       (two million)
        3e9     -> 3 B       (three billion)
        4e12    -> 4 T       (four trillion)
        5e15    -> 5,000 T

    The value is truncated (not rounded) to the chosen unit, and numbers
    without a unit end with two spaces (separator plus a blank label), so
    all results share the ``'<digits> <label>'`` layout.

    :param number: a positive integer number (floats are truncated to int)
    :return: a string formatted according to the pattern described above.
    """
    assert number >= 0
    labels = [' ', 'K', 'M', 'B', 'T']

    # Work on an exact integer: float log10/multiplication loses precision
    # for large counts and misbehaves for fractional input.
    whole = int(number)
    num_digits = len(str(whole))
    num_groups = (num_digits + 2) // 3  # ceil(num_digits / 3)
    num_groups = min(num_groups, len(labels))  # don't abbreviate beyond trillions

    index = num_groups - 1
    whole //= 10 ** (3 * index)
    return f'{whole:,d} {labels[index]}'