Refactor model summary + generalize example input array (#1773)

* squash

  variant a
  variant b
  add test
  revert rename
  add changelog
  docs
  move changelog entry to top
  use hooks
  wip
  wipp
  layer summary
  clean up, refactor
  type hints
  rename
  remove obsolete code
  rename
  unused imports
  simplify formatting of table and increase readability
  doctest
  superclass object
  update examples
  print unknown sizes
  more docs and doctest
  testing
  unknown layers
  add rnn test
  remove main
  restore train mode
  test device wip
  device
  constant
  simplify model forward transfer
  return summary object in method
  extend tests
  fix summary for empty module
  extend tests
  refactor and added hook
  variant a
  variant b
  add test
  revert rename
  add changelog
  docs
  move changelog entry to top
  remove hardcoded string
  simplify
  test unknown shapes and all others
  comments for tests
  fix hparams attribute

* update default
* unused import
* clean up
* replace hardcoded strings
* fix doctest
* fix top/full
* black
* fix rnn test
* fix rnn
* update debugging docs

  update docs
  typo
  update docs
  update docs

* add changelog
* extract constant
* setter and getter
* move parity models to test folder
* parameterize mode
Adrian Wälchli 2020-06-15 23:05:58 +02:00, committed by GitHub
parent 22d9464e56
commit 7dc58bd286
13 changed files with 540 additions and 286 deletions


@@ -81,6 +81,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed mistake in parameters' grad norm tracking ([#2012](https://github.com/PyTorchLightning/pytorch-lightning/pull/2012))
- Fixed CPU and hanging GPU crash ([#2118](https://github.com/PyTorchLightning/pytorch-lightning/pull/2118))
- Fixed an issue with the model summary and `example_input_array` depending on a specific ordering of the submodules in a LightningModule ([#1773](https://github.com/PyTorchLightning/pytorch-lightning/pull/1773))
## [0.7.6] - 2020-05-16
### Added


@@ -1,77 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningModule
from tests.base.datasets import MNIST
class AverageDataset(Dataset):
def __init__(self, dataset_len=300, sequence_len=100):
self.dataset_len = dataset_len
self.sequence_len = sequence_len
self.input_seq = torch.randn(dataset_len, sequence_len, 10)
top, bottom = self.input_seq.chunk(2, -1)
self.output_seq = top + bottom.roll(shifts=1, dims=-1)
def __len__(self):
return self.dataset_len
def __getitem__(self, item):
return self.input_seq[item], self.output_seq[item]
class ParityModuleRNN(LightningModule):
def __init__(self):
super().__init__()
self.rnn = nn.LSTM(10, 20, batch_first=True)
self.linear_out = nn.Linear(in_features=20, out_features=5)
def forward(self, x):
seq, last = self.rnn(x)
return self.linear_out(seq)
def training_step(self, batch, batch_nb):
x, y = batch
y_hat = self(x)
loss = F.mse_loss(y_hat, y)
return {'loss': loss}
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)
def train_dataloader(self):
return DataLoader(AverageDataset(), batch_size=30)
class ParityModuleMNIST(LightningModule):
def __init__(self):
super().__init__()
self.c_d1 = nn.Linear(in_features=28 * 28, out_features=128)
self.c_d1_bn = nn.BatchNorm1d(128)
self.c_d1_drop = nn.Dropout(0.3)
self.c_d2 = nn.Linear(in_features=128, out_features=10)
def forward(self, x):
x = x.view(x.size(0), -1)
x = self.c_d1(x)
x = torch.tanh(x)
x = self.c_d1_bn(x)
x = self.c_d1_drop(x)
x = self.c_d2(x)
return x
def training_step(self, batch, batch_nb):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
return {'loss': loss}
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)
def train_dataloader(self):
return DataLoader(MNIST(train=True, download=True,),
batch_size=128)


@@ -5,8 +5,8 @@ import pytest
import torch
import tests.base.utils as tutils
from benchmarks.parity_modules import ParityModuleRNN, ParityModuleMNIST
from pytorch_lightning import Trainer, seed_everything
from tests.base.models import ParityModuleRNN, ParityModuleMNIST
@pytest.mark.parametrize('cls_model,max_diff', [


@@ -55,17 +55,32 @@ argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)
trainer = Trainer(overfit_pct=0.01)
Print the parameter count by layer
----------------------------------
Whenever the .fit() function gets called, the Trainer will print the weights summary for the lightningModule.
To disable this behavior, turn off this flag:
(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_summary`
argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)
Print a summary of your LightningModule
---------------------------------------
Whenever the ``.fit()`` function gets called, the Trainer will print the weights summary for the LightningModule.
By default it only prints the top-level modules. If you want to show all submodules in your network, use the
`'full'` option:
.. testcode::
trainer = Trainer(weights_summary=None)
trainer = Trainer(weights_summary='full')
You can also display the intermediate input- and output sizes of all your layers by setting the
``example_input_array`` attribute in your LightningModule. It will print a table like this
.. code-block:: text
| Name | Type | Params | In sizes | Out sizes
--------------------------------------------------------------
0 | net | Sequential | 132 K | [10, 256] | [10, 512]
1 | net.0 | Linear | 131 K | [10, 256] | [10, 512]
2 | net.1 | BatchNorm1d | 1 K | [10, 512] | [10, 512]
when you call ``.fit()`` on the Trainer. This can help you find bugs in the composition of your layers.
See Also:
- :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_summary` Trainer argument
- :class:`~pytorch_lightning.core.memory.ModelSummary`
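As a concrete starting point, a minimal sketch of a module wired up for this summary (the ``LitModel`` below is hypothetical and mirrors the table above):

.. testcode::

    import torch
    import torch.nn as nn
    from pytorch_lightning import LightningModule, Trainer

    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(256, 512), nn.BatchNorm1d(512))
            self.example_input_array = torch.zeros(10, 256)  # optional, enables the size columns

        def forward(self, x):
            return self.net(x)

    trainer = Trainer(weights_summary='full')  # prints the table when .fit() is called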
Set the number of validation sanity steps


@@ -93,6 +93,8 @@ class GAN(LightningModule):
self.validation_z = torch.randn(8, self.latent_dim)
self.example_input_array = torch.zeros(2, hparams.latent_dim)
def forward(self, z):
return self.generator(z)


@@ -66,6 +66,8 @@ class LightningTemplateModel(LightningModule):
self.c_d2 = nn.Linear(in_features=self.hidden_dim,
out_features=self.out_features)
self.example_input_array = torch.zeros(2, 1, 28, 28)
def forward(self, x):
"""
No special modification required for Lightning, define it as you normally would


@@ -52,7 +52,6 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
#: Pointer to the logger object
self.logger = None
self.example_input_array = None
#: True if using dp
self.use_dp = False
@@ -75,6 +74,17 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
#: device reference
self._device = torch.device('cpu')
# optionally can be set by user
self._example_input_array = None
@property
def example_input_array(self) -> Any:
return self._example_input_array
@example_input_array.setter
def example_input_array(self, example: Any) -> None:
self._example_input_array = example
@property
def on_gpu(self):
"""
@@ -1445,9 +1455,10 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
will have an argument ``dataset_idx`` which matches the order here.
"""
def summarize(self, mode: str) -> None:
def summarize(self, mode: str = ModelSummary.MODE_DEFAULT) -> ModelSummary:
model_summary = ModelSummary(self, mode=mode)
log.info('\n' + model_summary.__str__())
log.info('\n' + str(model_summary))
return model_summary
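A quick usage sketch of the changed API (``model`` stands for any LightningModule instance): ``summarize`` still logs the table, but now also returns the ``ModelSummary`` so it can be inspected programmatically:

summary = model.summarize(mode='full')
print(summary.layer_names)  # e.g. ['net', 'net.0', 'net.1']
print(summary.param_nums)   # parameter count per listed layer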
def freeze(self) -> None:
r"""


@@ -1,164 +1,243 @@
"""
Generates a summary of a model's layers and dimensionality
"""
import gc
import os
import subprocess
from collections import OrderedDict
from subprocess import PIPE
from typing import Tuple, Dict, Union, List
from typing import Tuple, Dict, Union, List, Any
import numpy as np
import torch
from torch.nn import Module
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning import _logger as log
PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"]
UNKNOWN_SIZE = "?"
class LayerSummary(object):
"""
Summary class for a single layer in a :class:`~pytorch_lightning.core.lightning.LightningModule`.
It collects the following information:
- Type of the layer (e.g. Linear, BatchNorm1d, ...)
- Input shape
- Output shape
- Number of parameters
The input and output shapes are only known after the example input array was
passed through the model.
Example::
>>> model = torch.nn.Conv2d(3, 8, 3)
>>> summary = LayerSummary(model)
>>> summary.num_parameters
224
>>> summary.layer_type
'Conv2d'
>>> output = model(torch.rand(1, 3, 5, 5))
>>> summary.in_size
[1, 3, 5, 5]
>>> summary.out_size
[1, 8, 3, 3]
Args:
module: A module to summarize
"""
def __init__(self, module: nn.Module):
super().__init__()
self._module = module
self._hook_handle = self._register_hook()
self._in_size = None
self._out_size = None
def _register_hook(self):
"""
Registers a hook on the module that computes the input- and output size(s)
on the first forward pass. The hook will remove itself from the module, meaning that
recursive models will only record their input- and output shapes once.
"""
def hook(module, inp, out):
if len(inp) == 1:
inp = inp[0]
self._in_size = parse_batch_shape(inp)
self._out_size = parse_batch_shape(out)
self._hook_handle.remove() # hook detaches itself from module
return self._module.register_forward_hook(hook)
@property
def in_size(self):
return self._in_size or UNKNOWN_SIZE
@property
def out_size(self):
return self._out_size or UNKNOWN_SIZE
@property
def layer_type(self) -> str:
""" Returns the class name of the module. """
return str(self._module.__class__.__name__)
@property
def num_parameters(self) -> int:
""" Returns the number of parameters in this module. """
return sum(np.prod(p.shape) for p in self._module.parameters())
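The self-detaching hook pattern above can be tried in isolation; a minimal standalone sketch (the names here are illustrative, not part of the diff):

import torch
import torch.nn as nn

sizes = {}

def hook(module, inp, out):
    # record shapes on the first forward pass, then detach
    sizes['linear'] = (list(inp[0].shape), list(out.shape))
    handle.remove()

layer = nn.Linear(4, 2)
handle = layer.register_forward_hook(hook)
layer(torch.rand(3, 4))
print(sizes)  # {'linear': ([3, 4], [3, 2])}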
class ModelSummary(object):
"""
Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.
def __init__(self, model: 'pl.LightningModule', mode: str = 'full'):
""" Generates summaries of model layers and dimensions. """
self.model = model
self.mode = mode
self.in_sizes = []
self.out_sizes = []
Args:
model: The model to summarize (also referred to as the root module)
mode: Can be one of
self.summarize()
- `top` (default): only the top-level modules will be recorded (the children of the root module)
- `full`: summarizes all layers and their submodules in the root module
def __str__(self):
return self.summary.__str__()
The string representation of this summary prints a table with columns containing
the name, type and number of parameters for each layer.
def __repr__(self):
return self.summary.__str__()
The root module may also have an attribute ``example_input_array`` as shown in the example below.
If present, the root module will be called with it as input to determine the
intermediate input- and output shapes of all layers. Supported are tensors and
nested lists and tuples of tensors. All other types of inputs will be skipped and show as `?`
in the summary table. The summary will also display `?` for layers not used in the forward pass.
def named_modules(self) -> List[Tuple[str, Module]]:
if self.mode == 'full':
mods = self.model.named_modules()
Example::
>>> class LitModel(pl.LightningModule):
...
... def __init__(self):
... super().__init__()
... self.net = nn.Sequential(nn.Linear(256, 512), nn.BatchNorm1d(512))
... self.example_input_array = torch.zeros(10, 256) # optional
...
... def forward(self, x):
... return self.net(x)
...
>>> model = LitModel()
>>> ModelSummary(model, mode='top') # doctest: +NORMALIZE_WHITESPACE
| Name | Type | Params | In sizes | Out sizes
------------------------------------------------------------
0 | net | Sequential | 132 K | [10, 256] | [10, 512]
>>> ModelSummary(model, mode='full') # doctest: +NORMALIZE_WHITESPACE
| Name | Type | Params | In sizes | Out sizes
--------------------------------------------------------------
0 | net | Sequential | 132 K | [10, 256] | [10, 512]
1 | net.0 | Linear | 131 K | [10, 256] | [10, 512]
2 | net.1 | BatchNorm1d | 1 K | [10, 512] | [10, 512]
"""
MODE_TOP = "top"
MODE_FULL = "full"
MODE_DEFAULT = MODE_TOP
MODES = [MODE_FULL, MODE_TOP]
def __init__(self, model: "pl.LightningModule", mode: str = MODE_DEFAULT):
self._model = model
self._mode = mode
self._layer_summary = self.summarize()
@property
def named_modules(self) -> List[Tuple[str, nn.Module]]:
if self._mode == ModelSummary.MODE_FULL:
mods = self._model.named_modules()
mods = list(mods)[1:] # do not include root module (LightningModule)
elif self.mode == 'top':
elif self._mode == ModelSummary.MODE_TOP:
# the children are the top-level modules
mods = self.model.named_children()
mods = self._model.named_children()
else:
mods = []
return list(mods)
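The two modes reduce to the difference between ``named_modules`` and ``named_children``; a quick illustration on a toy model:

import torch.nn as nn

net = nn.Sequential(nn.Sequential(nn.Linear(2, 2)))
print([name for name, _ in net.named_modules()])   # ['', '0', '0.0'] -> root plus all submodules
print([name for name, _ in net.named_children()])  # ['0'] -> top level only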
def get_variable_sizes(self) -> None:
""" Run sample input through each layer to get output sizes. """
mods = self.named_modules()
in_sizes = []
out_sizes = []
input_ = self.model.example_input_array
@property
def layer_names(self) -> List[str]:
return list(self._layer_summary.keys())
if self.model.on_gpu:
device = next(self.model.parameters()).get_device()
# test if input is a list or a tuple
if isinstance(input_, (list, tuple)):
input_ = [input_i.cuda(device) if torch.is_tensor(input_i) else input_i
for input_i in input_]
else:
input_ = input_.cuda(device)
@property
def layer_types(self) -> List[str]:
return [layer.layer_type for layer in self._layer_summary.values()]
if self.model.trainer.use_amp:
# test if it is not a list or a tuple
if isinstance(input_, (list, tuple)):
input_ = [input_i.half() if torch.is_tensor(input_i) else input_i
for input_i in input_]
else:
input_ = input_.half()
@property
def in_sizes(self) -> List:
return [layer.in_size for layer in self._layer_summary.values()]
@property
def out_sizes(self) -> List:
return [layer.out_size for layer in self._layer_summary.values()]
@property
def param_nums(self) -> List[int]:
return [layer.num_parameters for layer in self._layer_summary.values()]
def summarize(self) -> Dict[str, LayerSummary]:
summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
if self._model.example_input_array is not None:
self._forward_example_input()
return summary
def _forward_example_input(self) -> None:
""" Run the example input through each layer to get input- and output sizes. """
model = self._model
trainer = self._model.trainer
input_ = model.example_input_array
input_ = model.transfer_batch_to_device(input_, model.device)
input_ = apply_to_collection(input_, torch.Tensor, lambda x: x.type(model.dtype))
if trainer is not None and trainer.use_amp:
if model.use_native_amp:
model.forward = torch.cuda.amp.autocast()(model.forward)
mode = model.training
model.eval()
with torch.no_grad():
# let the model hooks collect the input- and output shapes
if isinstance(input_, (list, tuple)):
model(*input_)
elif isinstance(input_, dict):
model(**input_)
else:
model(input_)
model.train(mode) # restore mode of module
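``apply_to_collection`` (used above for the dtype cast) maps a function over every element of a given type inside nested containers and leaves everything else untouched; a quick illustration:

import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection

batch = (torch.zeros(2, 3), [torch.zeros(4)], "metadata")
cast = apply_to_collection(batch, torch.Tensor, lambda t: t.double())
# both tensors are now float64; the string passes through unchanged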
for _, m in mods:
if isinstance(input_, (list, tuple)): # pragma: no-cover
out = m(*input_)
else:
out = m(input_)
if isinstance(input_, (list, tuple)): # pragma: no-cover
in_size = []
for x in input_:
if isinstance(x, list):
in_size.append(len(x))
else:
in_size.append(x.size())
else:
in_size = np.array(input_.size())
in_sizes.append(in_size)
if isinstance(out, (list, tuple)): # pragma: no-cover
out_size = np.asarray([x.size() for x in out])
else:
out_size = np.array(out.size())
out_sizes.append(out_size)
input_ = out
self.in_sizes = in_sizes
self.out_sizes = out_sizes
assert len(in_sizes) == len(out_sizes)
def get_layer_names(self) -> None:
""" Collect Layer Names """
mods = self.named_modules()
names = []
layers = []
for name, m in mods:
names += [name]
layers += [str(m.__class__)]
layer_types = [x.split('.')[-1][:-2] for x in layers]
self.layer_names = names
self.layer_types = layer_types
def get_parameter_sizes(self) -> None:
""" Get sizes of all parameters in `model`. """
mods = self.named_modules()
sizes = []
for _, m in mods:
p = list(m.parameters())
modsz = [np.array(param.size()) for param in p]
sizes.append(modsz)
self.param_sizes = sizes
def get_parameter_nums(self) -> None:
""" Get number of parameters in each layer. """
param_nums = []
for mod in self.param_sizes:
all_params = 0
for p in mod:
all_params += np.prod(p)
param_nums.append(all_params)
self.param_nums = param_nums
def make_summary(self) -> None:
def __str__(self):
"""
Makes a summary listing with:
Layer Name, Layer Type, Input Size, Output Size, Number of Parameters
Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes
"""
arrays = [['Name', self.layer_names],
['Type', self.layer_types],
['Params', list(map(get_human_readable_count, self.param_nums))]]
if self.model.example_input_array is not None:
arrays.append(['In sizes', self.in_sizes])
arrays.append(['Out sizes', self.out_sizes])
arrays = [
[" ", list(map(str, range(len(self._layer_summary))))],
["Name", self.layer_names],
["Type", self.layer_types],
["Params", list(map(get_human_readable_count, self.param_nums))],
]
if self._model.example_input_array is not None:
arrays.append(["In sizes", self.in_sizes])
arrays.append(["Out sizes", self.out_sizes])
self.summary = _format_summary_table(*arrays)
return _format_summary_table(*arrays)
def summarize(self) -> None:
self.get_layer_names()
self.get_parameter_sizes()
self.get_parameter_nums()
def __repr__(self):
return str(self)
if self.model.example_input_array is not None:
self.get_variable_sizes()
self.make_summary()
def parse_batch_shape(batch: Any) -> Union[str, List]:
if hasattr(batch, "shape"):
return list(batch.shape)
if isinstance(batch, (list, tuple)):
shape = [parse_batch_shape(el) for el in batch]
return shape
return UNKNOWN_SIZE
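A few worked calls of ``parse_batch_shape`` for the supported input types (values chosen for illustration):

>>> parse_batch_shape(torch.zeros(2, 3))
[2, 3]
>>> parse_batch_shape((torch.zeros(2, 3), [torch.zeros(4, 5)]))
[[2, 3], [[4, 5]]]
>>> parse_batch_shape({'x': torch.zeros(1)})
'?'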
def _format_summary_table(*cols) -> str:
@@ -170,68 +249,29 @@ def _format_summary_table(*cols) -> str:
n_rows = len(cols[0][1])
n_cols = 1 + len(cols)
# Layer counter
counter = list(map(str, list(range(n_rows))))
counter_len = max([len(c) for c in counter])
# Get formatting length of each column
length = []
# Get formatting width of each column
col_widths = []
for c in cols:
str_l = len(c[0]) # default length is header length
for a in c[1]:
if isinstance(a, np.ndarray):
array_string = '[' + ', '.join([str(j) for j in a]) + ']'
str_l = max(len(array_string), str_l)
else:
str_l = max(len(a), str_l)
length.append(str_l)
col_width = max(len(str(a)) for a in c[1]) if n_rows else 0
col_width = max(col_width, len(c[0])) # minimum length is header length
col_widths.append(col_width)
# Formatting
s = '{:<{}}'
full_length = sum(length) + 3 * n_cols
header = [s.format(' ', counter_len)] + [s.format(c[0], l) for c, l in zip(cols, length)]
s = "{:<{}}"
total_width = sum(col_widths) + 3 * n_cols
header = [s.format(c[0], l) for c, l in zip(cols, col_widths)]
# Summary = header + divider + Rest of table
summary = ' | '.join(header) + '\n' + '-' * full_length
summary = " | ".join(header) + "\n" + "-" * total_width
for i in range(n_rows):
line = s.format(counter[i], counter_len)
for c, l in zip(cols, length):
if isinstance(c[1][i], np.ndarray):
array_string = '[' + ', '.join([str(j) for j in c[1][i]]) + ']'
line += ' | ' + array_string + ' ' * (l - len(array_string))
else:
line += ' | ' + s.format(c[1][i], l)
summary += '\n' + line
line = []
for c, l in zip(cols, col_widths):
line.append(s.format(str(c[1][i]), l))
summary += "\n" + " | ".join(line)
return summary
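A toy invocation of the rewritten ``_format_summary_table``, with columns assembled the same way ``__str__`` does above (output shown modulo trailing spaces):

cols = [
    [" ", ["0", "1"]],
    ["Name", ["net", "net.0"]],
    ["Params", ["132 K", "131 K"]],
]
print(_format_summary_table(*cols))
#   | Name  | Params
# ------------------------
# 0 | net   | 132 K
# 1 | net.0 | 131 K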
def print_mem_stack() -> None: # pragma: no-cover
for obj in gc.get_objects():
try:
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
log.info(type(obj), obj.size())
except Exception:
pass
def count_mem_items() -> Tuple[int, int]: # pragma: no-cover
num_params = 0
num_tensors = 0
for obj in gc.get_objects():
try:
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
obj_type = str(type(obj))
if 'parameter' in obj_type:
num_params += 1
else:
num_tensors += 1
except Exception:
pass
return num_params, num_tensors
def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]:
""" Get a profile of the current memory usage.
@@ -251,11 +291,11 @@ def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]:
"""
memory_map = get_gpu_memory_map()
if mode == 'min_max':
if mode == "min_max":
min_index, min_memory = min(memory_map.items(), key=lambda item: item[1])
max_index, max_memory = max(memory_map.items(), key=lambda item: item[1])
memory_map = {'min_gpu_mem': min_memory, 'max_gpu_mem': max_memory}
memory_map = {"min_gpu_mem": min_memory, "max_gpu_mem": max_memory}
return memory_map
@@ -268,18 +308,16 @@ def get_gpu_memory_map() -> Dict[str, int]:
values are memory usage as integers in MB.
"""
result = subprocess.run(
[
'nvidia-smi',
'--query-gpu=memory.used',
'--format=csv,nounits,noheader',
],
encoding='utf-8',
["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader",],
encoding="utf-8",
# capture_output=True, # valid for python version >=3.7
stdout=PIPE, stderr=PIPE, # for backward compatibility with python version 3.6
check=True)
stdout=PIPE,
stderr=PIPE, # for backward compatibility with python version 3.6
check=True,
)
# Convert lines into a dictionary
gpu_memory = [int(x) for x in result.stdout.strip().split(os.linesep)]
gpu_memory_map = {f'gpu_{index}': memory for index, memory in enumerate(gpu_memory)}
gpu_memory_map = {f"gpu_{index}": memory for index, memory in enumerate(gpu_memory)}
return gpu_memory_map
@@ -310,11 +348,11 @@ def get_human_readable_count(number: int) -> str:
"""
assert number >= 0
labels = [' ', 'K', 'M', 'B', 'T']
labels = PARAMETER_NUM_UNITS
num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
num_groups = int(np.ceil(num_digits / 3))
num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions
shift = -3 * (num_groups - 1)
number = number * (10 ** shift)
index = num_groups - 1
return f'{int(number):,d} {labels[index]}'
return f"{int(number):,d} {labels[index]}"


@@ -9,8 +9,9 @@ import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback, ProgressBarBase
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler, BaseProfiler
from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
@@ -110,7 +111,7 @@ class Trainer(
distributed_backend: Optional[str] = None,
precision: int = 32,
print_nan_grads: bool = False, # backward compatible, todo: remove in v0.9.0
weights_summary: Optional[str] = 'top',
weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
weights_save_path: Optional[str] = None,
num_sanity_val_steps: int = 2,
truncated_bptt_steps: Optional[int] = None,
@@ -945,12 +946,13 @@ class Trainer(
self.register_slurm_signal_handlers()
# print model summary
# TODO: remove self.testing condition because model.summarize() is wiping out the weights
if self.is_global_zero and self.weights_summary is not None and not self.testing:
if self.weights_summary in ['full', 'top']:
if self.weights_summary in ModelSummary.MODES:
ref_model.summarize(mode=self.weights_summary)
else:
raise MisconfigurationException("weights_summary can be None, 'full' or 'top'")
raise MisconfigurationException(
"weights_summary can be None, " + ", ".join(ModelSummary.MODES)
)
# track model now.
# if cluster resets state, the model will update with the saved weights


@@ -184,3 +184,19 @@ class TrialMNIST(MNIST):
data, targets = torch.load(path_fname)
data, targets = self._prepare_subset(data, targets, self.num_samples, self.digits)
torch.save((data, targets), os.path.join(self.cached_folder_path, fname))
class AverageDataset(Dataset):
def __init__(self, dataset_len=300, sequence_len=100):
self.dataset_len = dataset_len
self.sequence_len = sequence_len
self.input_seq = torch.randn(dataset_len, sequence_len, 10)
top, bottom = self.input_seq.chunk(2, -1)
self.output_seq = top + bottom.roll(shifts=1, dims=-1)
def __len__(self):
return self.dataset_len
def __getitem__(self, item):
return self.input_seq[item], self.output_seq[item]
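To make the target construction concrete, the chunk/roll arithmetic on a toy tensor (not part of the test suite):

import torch

x = torch.arange(4.).view(1, 1, 4)        # one batch element, one step, 4 features
top, bottom = x.chunk(2, -1)              # top = [[[0., 1.]]], bottom = [[[2., 3.]]]
y = top + bottom.roll(shifts=1, dims=-1)  # rolled bottom = [[[3., 2.]]], so y = [[[3., 3.]]]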


@@ -6,7 +6,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tests.base.datasets import TrialMNIST
from tests.base.datasets import TrialMNIST, AverageDataset, MNIST
try:
from test_tube import HyperOptArgumentParser
@@ -83,6 +83,8 @@ class TestGAN(LightningModule):
self.generated_imgs = None
self.last_imgs = None
self.example_input_array = torch.rand(2, self.hidden_dim)
def forward(self, z):
return self.generator(z)
@@ -154,3 +156,58 @@ class TestGAN(LightningModule):
def train_dataloader(self):
return DataLoader(TrialMNIST(train=True, download=True), batch_size=16)
class ParityModuleRNN(LightningModule):
def __init__(self):
super().__init__()
self.rnn = nn.LSTM(10, 20, batch_first=True)
self.linear_out = nn.Linear(in_features=20, out_features=5)
def forward(self, x):
seq, last = self.rnn(x)
return self.linear_out(seq)
def training_step(self, batch, batch_nb):
x, y = batch
y_hat = self(x)
loss = F.mse_loss(y_hat, y)
return {'loss': loss}
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)
def train_dataloader(self):
return DataLoader(AverageDataset(), batch_size=30)
class ParityModuleMNIST(LightningModule):
def __init__(self):
super().__init__()
self.c_d1 = nn.Linear(in_features=28 * 28, out_features=128)
self.c_d1_bn = nn.BatchNorm1d(128)
self.c_d1_drop = nn.Dropout(0.3)
self.c_d2 = nn.Linear(in_features=128, out_features=10)
def forward(self, x):
x = x.view(x.size(0), -1)
x = self.c_d1(x)
x = torch.tanh(x)
x = self.c_d1_bn(x)
x = self.c_d1_drop(x)
x = self.c_d2(x)
return x
def training_step(self, batch, batch_nb):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
return {'loss': loss}
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)
def train_dataloader(self):
return DataLoader(MNIST(train=True, download=True,),
batch_size=128)

tests/core/__init__.py (new file, 0 lines)

tests/core/test_memory.py (new file, 186 lines)

@@ -0,0 +1,186 @@
import pytest
import torch
import torch.nn as nn
from pytorch_lightning import LightningModule
from pytorch_lightning.core.memory import UNKNOWN_SIZE, ModelSummary
from tests.base.models import ParityModuleRNN
class EmptyModule(LightningModule):
""" A module that has no layers """
def __init__(self):
super().__init__()
self.parameter = torch.rand(3, 3, requires_grad=True)
self.example_input_array = torch.zeros(1, 2, 3, 4, 5)
def forward(self, *args, **kwargs):
return {'loss': self.parameter.sum()}
class UnorderedModel(LightningModule):
""" A model in which the layers not defined in order of execution """
def __init__(self):
super().__init__()
# note: the definition order is intentionally scrambled for this test
self.layer2 = nn.Linear(10, 2)
self.combine = nn.Linear(7, 9)
self.layer1 = nn.Linear(3, 5)
self.relu = nn.ReLU()
# this layer is unused, therefore input-/output shapes are unknown
self.unused = nn.Conv2d(1, 1, 1)
self.example_input_array = (torch.rand(2, 3), torch.rand(2, 10))
def forward(self, x, y):
out1 = self.layer1(x)
out2 = self.layer2(y)
out = self.relu(torch.cat((out1, out2), 1))
out = self.combine(out)
return out
@pytest.mark.parametrize(['mode'], [
pytest.param(ModelSummary.MODE_FULL),
pytest.param(ModelSummary.MODE_TOP),
])
def test_empty_model_summary_shapes(mode):
""" Test that the summary works for models that have no submodules. """
model = EmptyModule()
summary = model.summarize(mode=mode)
assert summary.in_sizes == []
assert summary.out_sizes == []
assert summary.param_nums == []
@pytest.mark.parametrize(['mode'], [
pytest.param(ModelSummary.MODE_FULL),
pytest.param(ModelSummary.MODE_TOP),
])
@pytest.mark.parametrize(['device', 'dtype'], [
pytest.param(torch.device('cpu'), torch.double),
pytest.param(torch.device('cuda', 0), torch.float),
pytest.param(torch.device('cuda', 0), torch.float16),
])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
def test_linear_model_summary_shapes(device, dtype, mode):
""" Test that the model summary correctly computes the input- and output shapes. """
model = UnorderedModel().type(dtype).to(device)
model.train()
summary = model.summarize(mode=mode)
assert summary.in_sizes == [
[2, 10], # layer 2
[2, 7], # combine
[2, 3], # layer 1
[2, 7], # relu
UNKNOWN_SIZE,
]
assert summary.out_sizes == [
[2, 2], # layer 2
[2, 9], # combine
[2, 5], # layer 1
[2, 7], # relu
UNKNOWN_SIZE,
]
assert model.training
assert model.dtype == dtype
assert model.device == device
@pytest.mark.parametrize(['mode'], [
pytest.param(ModelSummary.MODE_FULL),
pytest.param(ModelSummary.MODE_TOP),
])
def test_rnn_summary_shapes(mode):
""" Test that the model summary works for RNNs. """
model = ParityModuleRNN()
b = 3
t = 5
i = model.rnn.input_size
h = model.rnn.hidden_size
o = model.linear_out.out_features
model.example_input_array = torch.zeros(b, t, 10)
summary = model.summarize(mode=mode)
assert summary.in_sizes == [
[b, t, i], # rnn
[b, t, h], # linear
]
assert summary.out_sizes == [
[[b, t, h], [[1, b, h], [1, b, h]]], # rnn
[b, t, o] # linear
]
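The nested out size mirrors what ``nn.LSTM`` returns, namely ``(output, (h_n, c_n))``; a standalone check of those shapes:

import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, batch_first=True)
out, (h, c) = rnn(torch.zeros(3, 5, 10))
print(list(out.shape), list(h.shape), list(c.shape))
# [3, 5, 20] [1, 3, 20] [1, 3, 20]  -> [b, t, h] and [num_layers, b, h]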
@pytest.mark.parametrize(['mode'], [
pytest.param(ModelSummary.MODE_FULL),
pytest.param(ModelSummary.MODE_TOP),
])
def test_summary_parameter_count(mode):
""" Test that the summary counts the number of parameters in every submodule. """
model = UnorderedModel()
summary = model.summarize(mode=mode)
assert summary.param_nums == [
model.layer2.weight.numel() + model.layer2.bias.numel(),
model.combine.weight.numel() + model.combine.bias.numel(),
model.layer1.weight.numel() + model.layer1.bias.numel(),
0, # ReLU
model.unused.weight.numel() + model.unused.bias.numel(),
]
@pytest.mark.parametrize(['mode'], [
pytest.param(ModelSummary.MODE_FULL),
pytest.param(ModelSummary.MODE_TOP),
])
def test_summary_layer_types(mode):
""" Test that the summary displays the layer names correctly. """
model = UnorderedModel()
summary = model.summarize(mode=mode)
assert summary.layer_types == [
'Linear',
'Linear',
'Linear',
'ReLU',
'Conv2d',
]
@pytest.mark.parametrize(['mode'], [
pytest.param(ModelSummary.MODE_FULL),
pytest.param(ModelSummary.MODE_TOP),
])
@pytest.mark.parametrize(['example_input', 'expected_size'], [
pytest.param([], UNKNOWN_SIZE),
pytest.param((1, 2, 3), [UNKNOWN_SIZE] * 3),
pytest.param(torch.tensor(0), UNKNOWN_SIZE),
pytest.param(dict(tensor=torch.zeros(1, 2, 3)), UNKNOWN_SIZE),
pytest.param(torch.zeros(2, 3, 4), [2, 3, 4]),
pytest.param([torch.zeros(2, 3), torch.zeros(4, 5)], [[2, 3], [4, 5]]),
pytest.param((torch.zeros(2, 3), torch.zeros(4, 5)), [[2, 3], [4, 5]]),
])
def test_example_input_array_types(example_input, expected_size, mode):
""" Test the types of example inputs supported for display in the summary. """
class DummyModule(nn.Module):
def forward(self, *args, **kwargs):
return None
class DummyLightningModule(LightningModule):
def __init__(self):
super().__init__()
self.layer = DummyModule()
# this LightningModule and submodule accept any type of input
def forward(self, *args, **kwargs):
return self.layer(*args, **kwargs)
model = DummyLightningModule()
model.example_input_array = example_input
summary = model.summarize(mode=mode)
assert summary.in_sizes == [expected_size]