diff --git a/CHANGELOG.md b/CHANGELOG.md index f5fa80353c..89ab2d7d3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -81,6 +81,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed mistake in parameters' grad norm tracking ([#2012](https://github.com/PyTorchLightning/pytorch-lightning/pull/2012)) - Fixed CPU and hanging GPU crash ([#2118](https://github.com/PyTorchLightning/pytorch-lightning/pull/2118)) +- Fixed an issue with the model summary and `example_input_array` depending on a specific ordering of the submodules in a LightningModule ([#1773](https://github.com/PyTorchLightning/pytorch-lightning/pull/1773)) + ## [0.7.6] - 2020-05-16 ### Added diff --git a/benchmarks/parity_modules.py b/benchmarks/parity_modules.py deleted file mode 100644 index 344debe46a..0000000000 --- a/benchmarks/parity_modules.py +++ /dev/null @@ -1,77 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.utils.data import Dataset, DataLoader - -from pytorch_lightning import LightningModule -from tests.base.datasets import MNIST - - -class AverageDataset(Dataset): - def __init__(self, dataset_len=300, sequence_len=100): - self.dataset_len = dataset_len - self.sequence_len = sequence_len - self.input_seq = torch.randn(dataset_len, sequence_len, 10) - top, bottom = self.input_seq.chunk(2, -1) - self.output_seq = top + bottom.roll(shifts=1, dims=-1) - - def __len__(self): - return self.dataset_len - - def __getitem__(self, item): - return self.input_seq[item], self.output_seq[item] - - -class ParityModuleRNN(LightningModule): - def __init__(self): - super().__init__() - self.rnn = nn.LSTM(10, 20, batch_first=True) - self.linear_out = nn.Linear(in_features=20, out_features=5) - - def forward(self, x): - seq, last = self.rnn(x) - return self.linear_out(seq) - - def training_step(self, batch, batch_nb): - x, y = batch - y_hat = self(x) - loss = F.mse_loss(y_hat, y) - return {'loss': loss} - - def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=0.02) - - def train_dataloader(self): - return DataLoader(AverageDataset(), batch_size=30) - - -class ParityModuleMNIST(LightningModule): - - def __init__(self): - super().__init__() - self.c_d1 = nn.Linear(in_features=28 * 28, out_features=128) - self.c_d1_bn = nn.BatchNorm1d(128) - self.c_d1_drop = nn.Dropout(0.3) - self.c_d2 = nn.Linear(in_features=128, out_features=10) - - def forward(self, x): - x = x.view(x.size(0), -1) - x = self.c_d1(x) - x = torch.tanh(x) - x = self.c_d1_bn(x) - x = self.c_d1_drop(x) - x = self.c_d2(x) - return x - - def training_step(self, batch, batch_nb): - x, y = batch - y_hat = self(x) - loss = F.cross_entropy(y_hat, y) - return {'loss': loss} - - def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=0.02) - - def train_dataloader(self): - return DataLoader(MNIST(train=True, download=True,), - batch_size=128) diff --git a/benchmarks/test_parity.py b/benchmarks/test_parity.py index 186dc57dac..f38cab56dc 100644 --- a/benchmarks/test_parity.py +++ b/benchmarks/test_parity.py @@ -5,8 +5,8 @@ import pytest import torch import tests.base.utils as tutils -from benchmarks.parity_modules import ParityModuleRNN, ParityModuleMNIST from pytorch_lightning import Trainer, seed_everything +from tests.base.models import ParityModuleRNN, ParityModuleMNIST @pytest.mark.parametrize('cls_model,max_diff', [ diff --git a/docs/source/debugging.rst b/docs/source/debugging.rst index 412b6d613e..741f94c524 100644 --- a/docs/source/debugging.rst +++ b/docs/source/debugging.rst @@ -55,17 +55,32 @@ argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`) trainer = Trainer(overfit_pct=0.01) -Print the parameter count by layer ----------------------------------- -Whenever the .fit() function gets called, the Trainer will print the weights summary for the lightningModule. -To disable this behavior, turn off this flag: - -(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_summary` -argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`) +Print a summary of your LightningModule +--------------------------------------- +Whenever the ``.fit()`` function gets called, the Trainer will print the weights summary for the LightningModule. +By default it only prints the top-level modules. If you want to show all submodules in your network, use the +`'full'` option: .. testcode:: - trainer = Trainer(weights_summary=None) + trainer = Trainer(weights_summary='full') + +You can also display the intermediate input- and output sizes of all your layers by setting the +``example_input_array`` attribute in your LightningModule. It will print a table like this + +.. code-block:: text + + | Name | Type | Params | In sizes | Out sizes + -------------------------------------------------------------- + 0 | net | Sequential | 132 K | [10, 256] | [10, 512] + 1 | net.0 | Linear | 131 K | [10, 256] | [10, 512] + 2 | net.1 | BatchNorm1d | 1 K | [10, 512] | [10, 512] + +when you call ``.fit()`` on the Trainer. This can help you find bugs in the composition of your layers. + +See Also: + - :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_summary` Trainer argument + - :class:`~pytorch_lightning.core.memory.ModelSummary` Set the number of validation sanity steps diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 4417e5e02c..427312ecaf 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -93,6 +93,8 @@ class GAN(LightningModule): self.validation_z = torch.randn(8, self.latent_dim) + self.example_input_array = torch.zeros(2, hparams.latent_dim) + def forward(self, z): return self.generator(z) diff --git a/pl_examples/models/lightning_template.py b/pl_examples/models/lightning_template.py index b309094254..dba605f5ca 100644 --- a/pl_examples/models/lightning_template.py +++ b/pl_examples/models/lightning_template.py @@ -66,6 +66,8 @@ class LightningTemplateModel(LightningModule): self.c_d2 = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features) + self.example_input_array = torch.zeros(2, 1, 28, 28) + def forward(self, x): """ No special modification required for Lightning, define it as you normally would diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 75b08ddc7a..e5c470e3cf 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -52,7 +52,6 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod #: Pointer to the logger object self.logger = None - self.example_input_array = None #: True if using dp self.use_dp = False @@ -75,6 +74,17 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod #: device reference self._device = torch.device('cpu') + # optionally can be set by user + self._example_input_array = None + + @property + def example_input_array(self) -> Any: + return self._example_input_array + + @example_input_array.setter + def example_input_array(self, example: Any) -> None: + self._example_input_array = example + @property def on_gpu(self): """ @@ -1445,9 +1455,10 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod will have an argument ``dataset_idx`` which matches the order here. """ - def summarize(self, mode: str) -> None: + def summarize(self, mode: str = ModelSummary.MODE_DEFAULT) -> ModelSummary: model_summary = ModelSummary(self, mode=mode) - log.info('\n' + model_summary.__str__()) + log.info('\n' + str(model_summary)) + return model_summary def freeze(self) -> None: r""" diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index 8cbd281d6b..738184b9c0 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -1,164 +1,243 @@ -""" -Generates a summary of a model's layers and dimensionality -""" - -import gc import os import subprocess +from collections import OrderedDict from subprocess import PIPE -from typing import Tuple, Dict, Union, List +from typing import Tuple, Dict, Union, List, Any import numpy as np import torch -from torch.nn import Module +import torch.nn as nn import pytorch_lightning as pl +from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning import _logger as log +PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"] +UNKNOWN_SIZE = "?" + + +class LayerSummary(object): + """ + Summary class for a single layer in a :class:`~pytorch_lightning.core.lightning.LightningModule`. + It collects the following information: + + - Type of the layer (e.g. Linear, BatchNorm1d, ...) + - Input shape + - Output shape + - Number of parameters + + The input and output shapes are only known after the example input array was + passed through the model. + + Example:: + + >>> model = torch.nn.Conv2d(3, 8, 3) + >>> summary = LayerSummary(model) + >>> summary.num_parameters + 224 + >>> summary.layer_type + 'Conv2d' + >>> output = model(torch.rand(1, 3, 5, 5)) + >>> summary.in_size + [1, 3, 5, 5] + >>> summary.out_size + [1, 8, 3, 3] + + Args: + module: A module to summarize + + """ + + def __init__(self, module: nn.Module): + super().__init__() + self._module = module + self._hook_handle = self._register_hook() + self._in_size = None + self._out_size = None + + def _register_hook(self): + """ + Registers a hook on the module that computes the input- and output size(s) + on the first forward pass. The hook will remove itself from the module, meaning that + recursive models will only record their input- and output shapes once. + """ + + def hook(module, inp, out): + if len(inp) == 1: + inp = inp[0] + self._in_size = parse_batch_shape(inp) + self._out_size = parse_batch_shape(out) + self._hook_handle.remove() # hook detaches itself from module + + return self._module.register_forward_hook(hook) + + @property + def in_size(self): + return self._in_size or UNKNOWN_SIZE + + @property + def out_size(self): + return self._out_size or UNKNOWN_SIZE + + @property + def layer_type(self) -> str: + """ Returns the class name of the module. """ + return str(self._module.__class__.__name__) + + @property + def num_parameters(self) -> int: + """ Returns the number of parameters in this module. """ + return sum(np.prod(p.shape) for p in self._module.parameters()) class ModelSummary(object): + """ + Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`. - def __init__(self, model: 'pl.LightningModule', mode: str = 'full'): - """ Generates summaries of model layers and dimensions. """ - self.model = model - self.mode = mode - self.in_sizes = [] - self.out_sizes = [] + Args: + model: The model to summarize (also referred to as the root module) + mode: Can be one of - self.summarize() + - `top` (default): only the top-level modules will be recorded (the children of the root module) + - `full`: summarizes all layers and their submodules in the root module - def __str__(self): - return self.summary.__str__() + The string representation of this summary prints a table with columns containing + the name, type and number of parameters for each layer. - def __repr__(self): - return self.summary.__str__() + The root module may also have an attribute ``example_input_array`` as shown in the example below. + If present, the root module will be called with it as input to determine the + intermediate input- and output shapes of all layers. Supported are tensors and + nested lists and tuples of tensors. All other types of inputs will be skipped and show as `?` + in the summary table. The summary will also display `?` for layers not used in the forward pass. - def named_modules(self) -> List[Tuple[str, Module]]: - if self.mode == 'full': - mods = self.model.named_modules() + Example:: + + >>> class LitModel(pl.LightningModule): + ... + ... def __init__(self): + ... super().__init__() + ... self.net = nn.Sequential(nn.Linear(256, 512), nn.BatchNorm1d(512)) + ... self.example_input_array = torch.zeros(10, 256) # optional + ... + ... def forward(self, x): + ... return self.net(x) + ... + >>> model = LitModel() + >>> ModelSummary(model, mode='top') # doctest: +NORMALIZE_WHITESPACE + | Name | Type | Params | In sizes | Out sizes + ------------------------------------------------------------ + 0 | net | Sequential | 132 K | [10, 256] | [10, 512] + >>> ModelSummary(model, mode='full') # doctest: +NORMALIZE_WHITESPACE + | Name | Type | Params | In sizes | Out sizes + -------------------------------------------------------------- + 0 | net | Sequential | 132 K | [10, 256] | [10, 512] + 1 | net.0 | Linear | 131 K | [10, 256] | [10, 512] + 2 | net.1 | BatchNorm1d | 1 K | [10, 512] | [10, 512] + """ + + MODE_TOP = "top" + MODE_FULL = "full" + MODE_DEFAULT = MODE_TOP + MODES = [MODE_FULL, MODE_TOP] + + def __init__(self, model: "pl.LightningModule", mode: str = MODE_DEFAULT): + self._model = model + self._mode = mode + self._layer_summary = self.summarize() + + @property + def named_modules(self) -> List[Tuple[str, nn.Module]]: + if self._mode == ModelSummary.MODE_FULL: + mods = self._model.named_modules() mods = list(mods)[1:] # do not include root module (LightningModule) - elif self.mode == 'top': + elif self._mode == ModelSummary.MODE_TOP: # the children are the top-level modules - mods = self.model.named_children() + mods = self._model.named_children() else: mods = [] return list(mods) - def get_variable_sizes(self) -> None: - """ Run sample input through each layer to get output sizes. """ - mods = self.named_modules() - in_sizes = [] - out_sizes = [] - input_ = self.model.example_input_array + @property + def layer_names(self) -> List[str]: + return list(self._layer_summary.keys()) - if self.model.on_gpu: - device = next(self.model.parameters()).get_device() - # test if input is a list or a tuple - if isinstance(input_, (list, tuple)): - input_ = [input_i.cuda(device) if torch.is_tensor(input_i) else input_i - for input_i in input_] - else: - input_ = input_.cuda(device) + @property + def layer_types(self) -> List[str]: + return [layer.layer_type for layer in self._layer_summary.values()] - if self.model.trainer.use_amp: - # test if it is not a list or a tuple - if isinstance(input_, (list, tuple)): - input_ = [input_i.half() if torch.is_tensor(input_i) else input_i - for input_i in input_] - else: - input_ = input_.half() + @property + def in_sizes(self) -> List: + return [layer.in_size for layer in self._layer_summary.values()] + @property + def out_sizes(self) -> List: + return [layer.out_size for layer in self._layer_summary.values()] + + @property + def param_nums(self) -> List[int]: + return [layer.num_parameters for layer in self._layer_summary.values()] + + def summarize(self) -> Dict[str, LayerSummary]: + summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules) + if self._model.example_input_array is not None: + self._forward_example_input() + return summary + + def _forward_example_input(self) -> None: + """ Run the example input through each layer to get input- and output sizes. """ + model = self._model + trainer = self._model.trainer + + input_ = model.example_input_array + input_ = model.transfer_batch_to_device(input_, model.device) + input_ = apply_to_collection(input_, torch.Tensor, lambda x: x.type(model.dtype)) + + if trainer is not None and trainer.use_amp: + if model.use_native_amp: + model.forward = torch.cuda.amp.autocast()(model.forward) + + mode = model.training + model.eval() with torch.no_grad(): + # let the model hooks collect the input- and output shapes + if isinstance(input_, (list, tuple)): + model(*input_) + elif isinstance(input_, dict): + model(**input_) + else: + model(input_) + model.train(mode) # restore mode of module - for _, m in mods: - if isinstance(input_, (list, tuple)): # pragma: no-cover - out = m(*input_) - else: - out = m(input_) - - if isinstance(input_, (list, tuple)): # pragma: no-cover - in_size = [] - for x in input_: - if isinstance(x, list): - in_size.append(len(x)) - else: - in_size.append(x.size()) - else: - in_size = np.array(input_.size()) - - in_sizes.append(in_size) - - if isinstance(out, (list, tuple)): # pragma: no-cover - out_size = np.asarray([x.size() for x in out]) - else: - out_size = np.array(out.size()) - - out_sizes.append(out_size) - input_ = out - - self.in_sizes = in_sizes - self.out_sizes = out_sizes - assert len(in_sizes) == len(out_sizes) - - def get_layer_names(self) -> None: - """ Collect Layer Names """ - mods = self.named_modules() - names = [] - layers = [] - for name, m in mods: - names += [name] - layers += [str(m.__class__)] - - layer_types = [x.split('.')[-1][:-2] for x in layers] - - self.layer_names = names - self.layer_types = layer_types - - def get_parameter_sizes(self) -> None: - """ Get sizes of all parameters in `model`. """ - mods = self.named_modules() - sizes = [] - for _, m in mods: - p = list(m.parameters()) - modsz = [np.array(param.size()) for param in p] - sizes.append(modsz) - - self.param_sizes = sizes - - def get_parameter_nums(self) -> None: - """ Get number of parameters in each layer. """ - param_nums = [] - for mod in self.param_sizes: - all_params = 0 - for p in mod: - all_params += np.prod(p) - param_nums.append(all_params) - self.param_nums = param_nums - - def make_summary(self) -> None: + def __str__(self): """ Makes a summary listing with: - Layer Name, Layer Type, Input Size, Output Size, Number of Parameters + Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes """ - arrays = [['Name', self.layer_names], - ['Type', self.layer_types], - ['Params', list(map(get_human_readable_count, self.param_nums))]] - if self.model.example_input_array is not None: - arrays.append(['In sizes', self.in_sizes]) - arrays.append(['Out sizes', self.out_sizes]) + arrays = [ + [" ", list(map(str, range(len(self._layer_summary))))], + ["Name", self.layer_names], + ["Type", self.layer_types], + ["Params", list(map(get_human_readable_count, self.param_nums))], + ] + if self._model.example_input_array is not None: + arrays.append(["In sizes", self.in_sizes]) + arrays.append(["Out sizes", self.out_sizes]) - self.summary = _format_summary_table(*arrays) + return _format_summary_table(*arrays) - def summarize(self) -> None: - self.get_layer_names() - self.get_parameter_sizes() - self.get_parameter_nums() + def __repr__(self): + return str(self) - if self.model.example_input_array is not None: - self.get_variable_sizes() - self.make_summary() + +def parse_batch_shape(batch: Any) -> Union[str, List]: + if hasattr(batch, "shape"): + return list(batch.shape) + + if isinstance(batch, (list, tuple)): + shape = [parse_batch_shape(el) for el in batch] + return shape + + return UNKNOWN_SIZE def _format_summary_table(*cols) -> str: @@ -170,68 +249,29 @@ def _format_summary_table(*cols) -> str: n_rows = len(cols[0][1]) n_cols = 1 + len(cols) - # Layer counter - counter = list(map(str, list(range(n_rows)))) - counter_len = max([len(c) for c in counter]) - - # Get formatting length of each column - length = [] + # Get formatting width of each column + col_widths = [] for c in cols: - str_l = len(c[0]) # default length is header length - for a in c[1]: - if isinstance(a, np.ndarray): - array_string = '[' + ', '.join([str(j) for j in a]) + ']' - str_l = max(len(array_string), str_l) - else: - str_l = max(len(a), str_l) - length.append(str_l) + col_width = max(len(str(a)) for a in c[1]) if n_rows else 0 + col_width = max(col_width, len(c[0])) # minimum length is header length + col_widths.append(col_width) # Formatting - s = '{:<{}}' - full_length = sum(length) + 3 * n_cols - header = [s.format(' ', counter_len)] + [s.format(c[0], l) for c, l in zip(cols, length)] + s = "{:<{}}" + total_width = sum(col_widths) + 3 * n_cols + header = [s.format(c[0], l) for c, l in zip(cols, col_widths)] # Summary = header + divider + Rest of table - summary = ' | '.join(header) + '\n' + '-' * full_length + summary = " | ".join(header) + "\n" + "-" * total_width for i in range(n_rows): - line = s.format(counter[i], counter_len) - for c, l in zip(cols, length): - if isinstance(c[1][i], np.ndarray): - array_string = '[' + ', '.join([str(j) for j in c[1][i]]) + ']' - line += ' | ' + array_string + ' ' * (l - len(array_string)) - else: - line += ' | ' + s.format(c[1][i], l) - summary += '\n' + line + line = [] + for c, l in zip(cols, col_widths): + line.append(s.format(str(c[1][i]), l)) + summary += "\n" + " | ".join(line) return summary -def print_mem_stack() -> None: # pragma: no-cover - for obj in gc.get_objects(): - try: - if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): - log.info(type(obj), obj.size()) - except Exception: - pass - - -def count_mem_items() -> Tuple[int, int]: # pragma: no-cover - num_params = 0 - num_tensors = 0 - for obj in gc.get_objects(): - try: - if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): - obj_type = str(type(obj)) - if 'parameter' in obj_type: - num_params += 1 - else: - num_tensors += 1 - except Exception: - pass - - return num_params, num_tensors - - def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]: """ Get a profile of the current memory usage. @@ -251,11 +291,11 @@ def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]: """ memory_map = get_gpu_memory_map() - if mode == 'min_max': + if mode == "min_max": min_index, min_memory = min(memory_map.items(), key=lambda item: item[1]) max_index, max_memory = max(memory_map.items(), key=lambda item: item[1]) - memory_map = {'min_gpu_mem': min_memory, 'max_gpu_mem': max_memory} + memory_map = {"min_gpu_mem": min_memory, "max_gpu_mem": max_memory} return memory_map @@ -268,18 +308,16 @@ def get_gpu_memory_map() -> Dict[str, int]: values are memory usage as integers in MB. """ result = subprocess.run( - [ - 'nvidia-smi', - '--query-gpu=memory.used', - '--format=csv,nounits,noheader', - ], - encoding='utf-8', + ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader",], + encoding="utf-8", # capture_output=True, # valid for python version >=3.7 - stdout=PIPE, stderr=PIPE, # for backward compatibility with python version 3.6 - check=True) + stdout=PIPE, + stderr=PIPE, # for backward compatibility with python version 3.6 + check=True, + ) # Convert lines into a dictionary gpu_memory = [int(x) for x in result.stdout.strip().split(os.linesep)] - gpu_memory_map = {f'gpu_{index}': memory for index, memory in enumerate(gpu_memory)} + gpu_memory_map = {f"gpu_{index}": memory for index, memory in enumerate(gpu_memory)} return gpu_memory_map @@ -310,11 +348,11 @@ def get_human_readable_count(number: int) -> str: """ assert number >= 0 - labels = [' ', 'K', 'M', 'B', 'T'] + labels = PARAMETER_NUM_UNITS num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1) num_groups = int(np.ceil(num_digits / 3)) num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions shift = -3 * (num_groups - 1) number = number * (10 ** shift) index = num_groups - 1 - return f'{int(number):,d} {labels[index]}' + return f"{int(number):,d} {labels[index]}" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 36035bbf97..3112cc3054 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -9,8 +9,9 @@ import torch.multiprocessing as mp from torch.utils.data import DataLoader from pytorch_lightning import _logger as log -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback, ProgressBarBase +from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler, BaseProfiler from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin @@ -110,7 +111,7 @@ class Trainer( distributed_backend: Optional[str] = None, precision: int = 32, print_nan_grads: bool = False, # backward compatible, todo: remove in v0.9.0 - weights_summary: Optional[str] = 'top', + weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT, weights_save_path: Optional[str] = None, num_sanity_val_steps: int = 2, truncated_bptt_steps: Optional[int] = None, @@ -945,12 +946,13 @@ class Trainer( self.register_slurm_signal_handlers() # print model summary - # TODO: remove self.testing condition because model.summarize() is wiping out the weights if self.is_global_zero and self.weights_summary is not None and not self.testing: - if self.weights_summary in ['full', 'top']: + if self.weights_summary in ModelSummary.MODES: ref_model.summarize(mode=self.weights_summary) else: - raise MisconfigurationException("weights_summary can be None, 'full' or 'top'") + raise MisconfigurationException( + "weights_summary can be None, " + ", ".join(ModelSummary.MODES) + ) # track model now. # if cluster resets state, the model will update with the saved weights diff --git a/tests/base/datasets.py b/tests/base/datasets.py index af6cb062dc..35edbfc948 100644 --- a/tests/base/datasets.py +++ b/tests/base/datasets.py @@ -184,3 +184,19 @@ class TrialMNIST(MNIST): data, targets = torch.load(path_fname) data, targets = self._prepare_subset(data, targets, self.num_samples, self.digits) torch.save((data, targets), os.path.join(self.cached_folder_path, fname)) + + +class AverageDataset(Dataset): + + def __init__(self, dataset_len=300, sequence_len=100): + self.dataset_len = dataset_len + self.sequence_len = sequence_len + self.input_seq = torch.randn(dataset_len, sequence_len, 10) + top, bottom = self.input_seq.chunk(2, -1) + self.output_seq = top + bottom.roll(shifts=1, dims=-1) + + def __len__(self): + return self.dataset_len + + def __getitem__(self, item): + return self.input_seq[item], self.output_seq[item] diff --git a/tests/base/models.py b/tests/base/models.py index 77deb0766b..7f295da59f 100644 --- a/tests/base/models.py +++ b/tests/base/models.py @@ -6,7 +6,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader -from tests.base.datasets import TrialMNIST +from tests.base.datasets import TrialMNIST, AverageDataset, MNIST try: from test_tube import HyperOptArgumentParser @@ -83,6 +83,8 @@ class TestGAN(LightningModule): self.generated_imgs = None self.last_imgs = None + self.example_input_array = torch.rand(2, self.hidden_dim) + def forward(self, z): return self.generator(z) @@ -154,3 +156,58 @@ class TestGAN(LightningModule): def train_dataloader(self): return DataLoader(TrialMNIST(train=True, download=True), batch_size=16) + + +class ParityModuleRNN(LightningModule): + def __init__(self): + super().__init__() + self.rnn = nn.LSTM(10, 20, batch_first=True) + self.linear_out = nn.Linear(in_features=20, out_features=5) + + def forward(self, x): + seq, last = self.rnn(x) + return self.linear_out(seq) + + def training_step(self, batch, batch_nb): + x, y = batch + y_hat = self(x) + loss = F.mse_loss(y_hat, y) + return {'loss': loss} + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.02) + + def train_dataloader(self): + return DataLoader(AverageDataset(), batch_size=30) + + +class ParityModuleMNIST(LightningModule): + + def __init__(self): + super().__init__() + self.c_d1 = nn.Linear(in_features=28 * 28, out_features=128) + self.c_d1_bn = nn.BatchNorm1d(128) + self.c_d1_drop = nn.Dropout(0.3) + self.c_d2 = nn.Linear(in_features=128, out_features=10) + + def forward(self, x): + x = x.view(x.size(0), -1) + x = self.c_d1(x) + x = torch.tanh(x) + x = self.c_d1_bn(x) + x = self.c_d1_drop(x) + x = self.c_d2(x) + return x + + def training_step(self, batch, batch_nb): + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + return {'loss': loss} + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.02) + + def train_dataloader(self): + return DataLoader(MNIST(train=True, download=True,), + batch_size=128) diff --git a/tests/core/__init__.py b/tests/core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/core/test_memory.py b/tests/core/test_memory.py new file mode 100644 index 0000000000..6ddbb18ecf --- /dev/null +++ b/tests/core/test_memory.py @@ -0,0 +1,186 @@ +import pytest +import torch +import torch.nn as nn + +from pytorch_lightning import LightningModule +from pytorch_lightning.core.memory import UNKNOWN_SIZE, ModelSummary +from tests.base.models import ParityModuleRNN + + +class EmptyModule(LightningModule): + """ A module that has no layers """ + + def __init__(self): + super().__init__() + self.parameter = torch.rand(3, 3, requires_grad=True) + self.example_input_array = torch.zeros(1, 2, 3, 4, 5) + + def forward(self, *args, **kwargs): + return {'loss': self.parameter.sum()} + + +class UnorderedModel(LightningModule): + """ A model in which the layers not defined in order of execution """ + + def __init__(self): + super().__init__() + # note: the definition order is intentionally scrambled for this test + self.layer2 = nn.Linear(10, 2) + self.combine = nn.Linear(7, 9) + self.layer1 = nn.Linear(3, 5) + self.relu = nn.ReLU() + # this layer is unused, therefore input-/output shapes are unknown + self.unused = nn.Conv2d(1, 1, 1) + + self.example_input_array = (torch.rand(2, 3), torch.rand(2, 10)) + + def forward(self, x, y): + out1 = self.layer1(x) + out2 = self.layer2(y) + out = self.relu(torch.cat((out1, out2), 1)) + out = self.combine(out) + return out + + +@pytest.mark.parametrize(['mode'], [ + pytest.param(ModelSummary.MODE_FULL), + pytest.param(ModelSummary.MODE_TOP), +]) +def test_empty_model_summary_shapes(mode): + """ Test that the summary works for models that have no submodules. """ + model = EmptyModule() + summary = model.summarize(mode=mode) + assert summary.in_sizes == [] + assert summary.out_sizes == [] + assert summary.param_nums == [] + + +@pytest.mark.parametrize(['mode'], [ + pytest.param(ModelSummary.MODE_FULL), + pytest.param(ModelSummary.MODE_TOP), +]) +@pytest.mark.parametrize(['device', 'dtype'], [ + pytest.param(torch.device('cpu'), torch.double), + pytest.param(torch.device('cuda', 0), torch.float), + pytest.param(torch.device('cuda', 0), torch.float16), +]) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.") +def test_linear_model_summary_shapes(device, dtype, mode): + """ Test that the model summary correctly computes the input- and output shapes. """ + model = UnorderedModel().type(dtype).to(device) + model.train() + summary = model.summarize(mode=mode) + assert summary.in_sizes == [ + [2, 10], # layer 2 + [2, 7], # combine + [2, 3], # layer 1 + [2, 7], # relu + UNKNOWN_SIZE, + ] + assert summary.out_sizes == [ + [2, 2], # layer 2 + [2, 9], # combine + [2, 5], # layer 1 + [2, 7], # relu + UNKNOWN_SIZE, + ] + assert model.training + assert model.dtype == dtype + assert model.device == device + + +@pytest.mark.parametrize(['mode'], [ + pytest.param(ModelSummary.MODE_FULL), + pytest.param(ModelSummary.MODE_TOP), +]) +def test_rnn_summary_shapes(mode): + """ Test that the model summary works for RNNs. """ + model = ParityModuleRNN() + + b = 3 + t = 5 + i = model.rnn.input_size + h = model.rnn.hidden_size + o = model.linear_out.out_features + + model.example_input_array = torch.zeros(b, t, 10) + + summary = model.summarize(mode=mode) + assert summary.in_sizes == [ + [b, t, i], # rnn + [b, t, h], # linear + ] + assert summary.out_sizes == [ + [[b, t, h], [[1, b, h], [1, b, h]]], # rnn + [b, t, o] # linear + ] + + +@pytest.mark.parametrize(['mode'], [ + pytest.param(ModelSummary.MODE_FULL), + pytest.param(ModelSummary.MODE_TOP), +]) +def test_summary_parameter_count(mode): + """ Test that the summary counts the number of parameters in every submodule. """ + model = UnorderedModel() + summary = model.summarize(mode=mode) + assert summary.param_nums == [ + model.layer2.weight.numel() + model.layer2.bias.numel(), + model.combine.weight.numel() + model.combine.bias.numel(), + model.layer1.weight.numel() + model.layer1.bias.numel(), + 0, # ReLU + model.unused.weight.numel() + model.unused.bias.numel(), + ] + + +@pytest.mark.parametrize(['mode'], [ + pytest.param(ModelSummary.MODE_FULL), + pytest.param(ModelSummary.MODE_TOP), +]) +def test_summary_layer_types(mode): + """ Test that the summary displays the layer names correctly. """ + model = UnorderedModel() + summary = model.summarize(mode=mode) + assert summary.layer_types == [ + 'Linear', + 'Linear', + 'Linear', + 'ReLU', + 'Conv2d', + ] + + +@pytest.mark.parametrize(['mode'], [ + pytest.param(ModelSummary.MODE_FULL), + pytest.param(ModelSummary.MODE_TOP), +]) +@pytest.mark.parametrize(['example_input', 'expected_size'], [ + pytest.param([], UNKNOWN_SIZE), + pytest.param((1, 2, 3), [UNKNOWN_SIZE] * 3), + pytest.param(torch.tensor(0), UNKNOWN_SIZE), + pytest.param(dict(tensor=torch.zeros(1, 2, 3)), UNKNOWN_SIZE), + pytest.param(torch.zeros(2, 3, 4), [2, 3, 4]), + pytest.param([torch.zeros(2, 3), torch.zeros(4, 5)], [[2, 3], [4, 5]]), + pytest.param((torch.zeros(2, 3), torch.zeros(4, 5)), [[2, 3], [4, 5]]), +]) +def test_example_input_array_types(example_input, expected_size, mode): + """ Test the types of example inputs supported for display in the summary. """ + + class DummyModule(nn.Module): + def forward(self, *args, **kwargs): + return None + + class DummyLightningModule(LightningModule): + + def __init__(self): + super().__init__() + self.layer = DummyModule() + + # this LightningModule and submodule accept any type of input + def forward(self, *args, **kwargs): + return self.layer(*args, **kwargs) + + model = DummyLightningModule() + model.example_input_array = example_input + summary = model.summarize(mode=mode) + assert summary.in_sizes == [expected_size]