2020-08-20 02:03:22 +00:00
|
|
|
# Copyright The PyTorch Lightning team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2019-11-05 13:55:44 +00:00
|
|
|
import os
|
2020-09-04 10:02:16 +00:00
|
|
|
import shutil
|
2019-03-31 01:45:16 +00:00
|
|
|
import subprocess
|
2020-06-15 21:05:58 +00:00
|
|
|
from collections import OrderedDict
|
2021-03-15 02:17:42 +00:00
|
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
2019-10-22 08:32:40 +00:00
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
import numpy as np
|
2019-10-22 08:32:40 +00:00
|
|
|
import torch
|
2020-06-15 21:05:58 +00:00
|
|
|
import torch.nn as nn
|
2020-06-20 11:38:47 +00:00
|
|
|
from torch.utils.hooks import RemovableHandle
|
2020-03-12 16:47:23 +00:00
|
|
|
|
2021-01-12 10:22:37 +00:00
|
|
|
from pytorch_lightning.utilities import AMPType, DeviceType
|
2019-03-31 01:45:16 +00:00
|
|
|
|
2020-06-15 21:05:58 +00:00
|
|
|
# Unit suffixes used by ``get_human_readable_count``: no suffix, then thousands (K), millions (M),
# billions (B) and trillions (T).
PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"]
# Placeholder reported for input/output sizes that could not be determined.
UNKNOWN_SIZE = "?"
|
2020-03-17 22:44:00 +00:00
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
|
2020-06-15 21:05:58 +00:00
|
|
|
class LayerSummary(object):
    """
    Summary class for a single layer in a :class:`~pytorch_lightning.core.lightning.LightningModule`.
    It collects the following information:

    - Type of the layer (e.g. Linear, BatchNorm1d, ...)
    - Input shape
    - Output shape
    - Number of parameters

    The input and output shapes are only known after the example input array was
    passed through the model.

    Example::

        >>> model = torch.nn.Conv2d(3, 8, 3)
        >>> summary = LayerSummary(model)
        >>> summary.num_parameters
        224
        >>> summary.layer_type
        'Conv2d'
        >>> output = model(torch.rand(1, 3, 5, 5))
        >>> summary.in_size
        [1, 3, 5, 5]
        >>> summary.out_size
        [1, 8, 3, 3]

    Args:
        module: A module to summarize

    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self._module = module
        self._hook_handle = self._register_hook()
        self._in_size = None
        self._out_size = None

    def __del__(self):
        self.detach_hook()

    def _register_hook(self) -> Optional[RemovableHandle]:
        """
        Registers a hook on the module that computes the input- and output size(s) on the first forward pass.
        If the hook is called, it will remove itself from the module, meaning that
        recursive models will only record their input- and output shapes once.
        Registering hooks on :class:`~torch.jit.ScriptModule` is not supported.

        Return:
            A handle for the installed hook, or ``None`` if registering the hook is not possible.
        """

        def hook(module, inp, out):
            # ``inp`` is the tuple of positional forward arguments; unwrap the common single-input case
            if len(inp) == 1:
                inp = inp[0]
            self._in_size = parse_batch_shape(inp)
            self._out_size = parse_batch_shape(out)
            # one-shot: only the first call is recorded
            self._hook_handle.remove()

        handle = None
        if not isinstance(self._module, torch.jit.ScriptModule):
            handle = self._module.register_forward_hook(hook)
        return handle

    def detach_hook(self):
        """
        Removes the forward hook if it was not already removed in the forward pass.
        Will be called after the summary is created.
        """
        # ``getattr`` keeps ``__del__`` safe on an instance whose ``__init__`` failed before
        # ``_hook_handle`` was assigned.
        handle = getattr(self, "_hook_handle", None)
        if handle is not None:
            handle.remove()

    @property
    def in_size(self) -> Union[str, List]:
        return self._in_size or UNKNOWN_SIZE

    @property
    def out_size(self) -> Union[str, List]:
        return self._out_size or UNKNOWN_SIZE

    @property
    def layer_type(self) -> str:
        """ Returns the class name of the module. """
        return str(self._module.__class__.__name__)

    @property
    def num_parameters(self) -> int:
        """ Returns the number of parameters in this module. """
        # ``numel`` yields a plain Python int (1 for a 0-dim parameter). ``np.prod(p.shape)`` returned a NumPy
        # scalar, and the float ``1.0`` for 0-dim parameters, which turned the whole sum into a float.
        return sum(p.numel() for p in self._module.parameters())
|
|
|
|
|
|
|
|
|
|
|
|
class ModelSummary(object):
    """
    Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.

    Args:
        model: The model to summarize (also referred to as the root module)
        mode: Can be one of

             - `top` (default): only the top-level modules will be recorded (the children of the root module)
             - `full`: summarizes all layers and their submodules in the root module

            Any other value produces an empty summary.

    The string representation of this summary prints a table with columns containing
    the name, type and number of parameters for each layer.

    The root module may also have an attribute ``example_input_array`` as shown in the example below.
    If present, the root module will be called with it as input to determine the
    intermediate input- and output shapes of all layers. Supported are tensors and
    nested lists and tuples of tensors. All other types of inputs will be skipped and show as `?`
    in the summary table. The summary will also display `?` for layers not used in the forward pass.

    Example::

        >>> import pytorch_lightning as pl
        >>> class LitModel(pl.LightningModule):
        ...
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.net = nn.Sequential(nn.Linear(256, 512), nn.BatchNorm1d(512))
        ...         self.example_input_array = torch.zeros(10, 256)  # optional
        ...
        ...     def forward(self, x):
        ...         return self.net(x)
        ...
        >>> model = LitModel()
        >>> ModelSummary(model, mode='top')  # doctest: +NORMALIZE_WHITESPACE
          | Name | Type       | Params | In sizes  | Out sizes
        ------------------------------------------------------------
        0 | net  | Sequential | 132 K  | [10, 256] | [10, 512]
        ------------------------------------------------------------
        132 K     Trainable params
        0         Non-trainable params
        132 K     Total params
        0.530     Total estimated model params size (MB)
        >>> ModelSummary(model, mode='full')  # doctest: +NORMALIZE_WHITESPACE
          | Name  | Type        | Params | In sizes  | Out sizes
        --------------------------------------------------------------
        0 | net   | Sequential  | 132 K  | [10, 256] | [10, 512]
        1 | net.0 | Linear      | 131 K  | [10, 256] | [10, 512]
        2 | net.1 | BatchNorm1d | 1.0 K  | [10, 512] | [10, 512]
        --------------------------------------------------------------
        132 K     Trainable params
        0         Non-trainable params
        132 K     Total params
        0.530     Total estimated model params size (MB)
    """

    MODE_TOP = "top"
    MODE_FULL = "full"
    MODE_DEFAULT = MODE_TOP
    MODES = [MODE_FULL, MODE_TOP]

    def __init__(self, model, mode: str = MODE_DEFAULT):
        self._model = model
        self._mode = mode
        self._layer_summary = self.summarize()
        # 1 byte -> 8 bits
        # TODO: how do we compute precision_megabytes in case of mixed precision?
        precision = self._model.precision if isinstance(self._model.precision, int) else 32
        self._precision_megabytes = (precision / 8.0) * 1e-6

    @property
    def named_modules(self) -> List[Tuple[str, nn.Module]]:
        if self._mode == ModelSummary.MODE_FULL:
            mods = self._model.named_modules()
            mods = list(mods)[1:]  # do not include root module (LightningModule)
        elif self._mode == ModelSummary.MODE_TOP:
            # the children are the top-level modules
            mods = self._model.named_children()
        else:
            mods = []
        return list(mods)

    @property
    def layer_names(self) -> List[str]:
        return list(self._layer_summary.keys())

    @property
    def layer_types(self) -> List[str]:
        return [layer.layer_type for layer in self._layer_summary.values()]

    @property
    def in_sizes(self) -> List:
        return [layer.in_size for layer in self._layer_summary.values()]

    @property
    def out_sizes(self) -> List:
        return [layer.out_size for layer in self._layer_summary.values()]

    @property
    def param_nums(self) -> List[int]:
        return [layer.num_parameters for layer in self._layer_summary.values()]

    @property
    def total_parameters(self) -> int:
        return sum(p.numel() for p in self._model.parameters())

    @property
    def trainable_parameters(self) -> int:
        return sum(p.numel() for p in self._model.parameters() if p.requires_grad)

    @property
    def model_size(self) -> float:
        # todo: seems it does not work with quantized models - it returns 0.0
        return self.total_parameters * self._precision_megabytes

    def summarize(self) -> Dict[str, LayerSummary]:
        """Create one :class:`LayerSummary` per selected layer and, if an example input is set, record sizes."""
        summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
        try:
            if self._model.example_input_array is not None:
                self._forward_example_input()
        finally:
            # remove hooks that did not fire (e.g. layers unused in forward), even if the forward pass raised
            for layer in summary.values():
                layer.detach_hook()
        return summary

    def _forward_example_input(self) -> None:
        """ Run the example input through each layer to get input- and output sizes. """
        model = self._model
        trainer = self._model.trainer

        input_ = model.example_input_array
        input_ = model._apply_batch_transfer_handler(input_, model.device)

        use_native_amp = (
            trainer is not None and trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU
        )

        mode = model.training
        model.eval()
        try:
            with torch.no_grad():
                if use_native_amp:
                    # scoped autocast: wrapping ``model.forward`` instead would permanently patch the model
                    # (leaking into later calls, stacking on repeated summaries and breaking pickling)
                    with torch.cuda.amp.autocast():
                        self._run_forward(model, input_)
                else:
                    self._run_forward(model, input_)
        finally:
            model.train(mode)  # restore mode of module, also when the forward pass fails

    @staticmethod
    def _run_forward(model, input_) -> None:
        """Call ``model`` with ``input_`` so the layer hooks can collect the input- and output shapes."""
        if isinstance(input_, (list, tuple)):
            model(*input_)
        elif isinstance(input_, dict):
            model(**input_)
        else:
            model(input_)

    def __str__(self):
        """
        Makes a summary listing with:

        Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size
        """
        arrays = [
            [" ", list(map(str, range(len(self._layer_summary))))],
            ["Name", self.layer_names],
            ["Type", self.layer_types],
            ["Params", list(map(get_human_readable_count, self.param_nums))],
        ]
        if self._model.example_input_array is not None:
            arrays.append(["In sizes", self.in_sizes])
            arrays.append(["Out sizes", self.out_sizes])
        total_parameters = self.total_parameters
        trainable_parameters = self.trainable_parameters
        model_size = self.model_size

        return _format_summary_table(total_parameters, trainable_parameters, model_size, *arrays)

    def __repr__(self):
        return str(self)
|
2019-03-31 01:45:16 +00:00
|
|
|
|
|
|
|
|
2020-06-15 21:05:58 +00:00
|
|
|
def parse_batch_shape(batch: Any) -> Union[str, List]:
    """Describe the shape(s) of ``batch`` as (nested) lists.

    Objects exposing a ``shape`` attribute (e.g. tensors) yield ``list(batch.shape)``; lists and
    tuples are described element by element (always returned as a list); anything else yields
    ``UNKNOWN_SIZE``.
    """
    if hasattr(batch, "shape"):
        return list(batch.shape)
    if not isinstance(batch, (list, tuple)):
        return UNKNOWN_SIZE
    return list(map(parse_batch_shape, batch))
|
2019-03-31 01:45:16 +00:00
|
|
|
|
|
|
|
|
2021-01-25 08:35:29 +00:00
|
|
|
def _format_summary_table(total_parameters: int, trainable_parameters: int, model_size: float, *cols) -> str:
    """
    Render the summary table as a single newline-separated string.

    Each positional ``cols`` entry is a ``[header, values]`` pair describing one column. Every cell is
    left-aligned and padded to the widest cell (or header) of its column; the number of rows is taken
    from the first column. Below the table a footer lists the trainable, non-trainable and total
    parameter counts and the estimated model size in MB.
    """
    pad = "{:<{}}".format
    titles = [col[0] for col in cols]
    columns = [col[1] for col in cols]
    num_rows = len(columns[0])

    # Column width = longest cell, but never narrower than the header.
    widths = []
    for title, values in zip(titles, columns):
        longest = max(len(str(value)) for value in values) if num_rows else 0
        widths.append(max(longest, len(title)))

    # The divider budgets one column more than there are separators; kept as-is for output stability.
    divider = "-" * (sum(widths) + 3 * (len(cols) + 1))

    lines = [" | ".join(pad(title, width) for title, width in zip(titles, widths)), divider]
    for row in range(num_rows):
        lines.append(" | ".join(pad(str(values[row]), width) for values, width in zip(columns, widths)))
    lines.append(divider)

    footer = [
        (get_human_readable_count(trainable_parameters), "Trainable params"),
        (get_human_readable_count(total_parameters - trainable_parameters), "Non-trainable params"),
        (get_human_readable_count(total_parameters), "Total params"),
        (get_formatted_model_size(model_size), "Total estimated model params size (MB)"),
    ]
    lines.extend(pad(amount, 10) + label for amount, label in footer)

    return "\n".join(lines)
|
|
|
|
|
|
|
|
|
2020-03-12 16:47:23 +00:00
|
|
|
def get_memory_profile(mode: str) -> Dict[str, float]:
    """ Get a profile of the current memory usage.

    Args:
        mode: There are two modes:

            - 'all' means return memory for all gpus
            - 'min_max' means return memory for max and min

            Any value other than 'min_max' behaves like 'all'.

    Return:
        For 'all': the dictionary returned by :func:`get_gpu_memory_map`, i.e. one entry per GPU keyed
        ``"gpu_id: <id>/memory.used (MB)"`` with the used memory in MB as value.
        For 'min_max': a dictionary with exactly two keys:

            - 'min_gpu_mem': the minimum memory usage in MB
            - 'max_gpu_mem': the maximum memory usage in MB
    """
    memory_map = get_gpu_memory_map()

    if mode == "min_max":
        usages = memory_map.values()
        memory_map = {"min_gpu_mem": min(usages), "max_gpu_mem": max(usages)}

    return memory_map
|
|
|
|
|
|
|
|
|
2020-03-12 16:47:23 +00:00
|
|
|
def get_gpu_memory_map() -> Dict[str, float]:
    """
    Get the current gpu usage.

    Queries ``nvidia-smi`` for the used memory of every GPU it reports.

    Return:
        A dictionary with one entry per GPU, keyed ``"gpu_id: <id>/memory.used (MB)"`` where ``<id>`` is
        the position in ``nvidia-smi``'s output, and the used memory in MB (as a float) as value.

    Raises:
        FileNotFoundError:
            If the ``nvidia-smi`` executable cannot be found on the ``PATH``.
        subprocess.CalledProcessError:
            If ``nvidia-smi`` exits with a non-zero status.
    """
    nvidia_smi_path = shutil.which("nvidia-smi")
    if nvidia_smi_path is None:
        # without this check ``subprocess.run`` would fail with an obscure TypeError on ``None``
        raise FileNotFoundError("nvidia-smi: command not found")

    result = subprocess.run(
        [nvidia_smi_path, "--query-gpu=memory.used", "--format=csv,nounits,noheader"],
        encoding="utf-8",
        # capture_output=True,  # valid for python version >=3.7
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,  # for backward compatibility with python version 3.6
        check=True,
    )

    # Convert lines into a dictionary. The output is decoded in text mode, so line endings are already
    # normalized to "\n"; ``splitlines`` works everywhere, whereas ``split(os.linesep)`` does not split
    # multi-GPU output on Windows (where ``os.linesep`` is "\r\n").
    gpu_memory = [float(x) for x in result.stdout.strip().splitlines()]
    gpu_memory_map = {f"gpu_id: {gpu_id}/memory.used (MB)": memory for gpu_id, memory in enumerate(gpu_memory)}
    return gpu_memory_map
|
2019-10-08 19:30:06 +00:00
|
|
|
|
2021-01-25 19:31:38 +00:00
|
|
|
|
2021-01-25 08:35:29 +00:00
|
|
|
def get_formatted_model_size(total_model_size: float) -> str:
    """Format a model size (in MB) with thousands separators and three decimals, e.g. ``"1,234.568"``."""
    return f"{total_model_size:,.3f}"
|
2019-10-08 19:30:06 +00:00
|
|
|
|
2021-01-25 19:31:38 +00:00
|
|
|
|
2020-03-12 16:47:23 +00:00
|
|
|
def get_human_readable_count(number: int) -> str:
    """
    Abbreviates an integer number with K, M, B, T for thousands, millions,
    billions and trillions, respectively.

    Examples:
        >>> get_human_readable_count(123)
        '123  '
        >>> get_human_readable_count(1234)  # (one thousand)
        '1.2 K'
        >>> get_human_readable_count(2e6)   # (two million)
        '2.0 M'
        >>> get_human_readable_count(3e9)   # (three billion)
        '3.0 B'
        >>> get_human_readable_count(4e14)  # (four hundred trillion)
        '400 T'
        >>> get_human_readable_count(5e15)  # (more than trillion)
        '5,000 T'

    Args:
        number: a positive integer number

    Return:
        A string formatted according to the pattern described above.

    """
    assert number >= 0
    units = PARAMETER_NUM_UNITS

    digit_count = int(np.floor(np.log10(number)) + 1) if number > 0 else 1
    # groups of three digits, capped so we never abbreviate beyond trillions
    group_count = min(int(np.ceil(digit_count / 3)), len(units))
    unit_index = group_count - 1
    scaled = number * (10 ** (-3 * unit_index))

    # one decimal only for abbreviated values below 100 (e.g. "1.2 K"); otherwise a truncated integer
    if unit_index >= 1 and scaled < 100:
        return f"{scaled:,.1f} {units[unit_index]}"
    return f"{int(scaled):,d} {units[unit_index]}"
|