added sample input for summary

This commit is contained in:
William Falcon 2019-07-24 16:27:16 -04:00
parent b8cc62ee52
commit 77a7f3e33e
2 changed files with 28 additions and 23 deletions

View File

@ -28,7 +28,8 @@ class LightningTemplateModel(LightningModule):
self.batch_size = hparams.batch_size
self.example_input_array = torch.rand(5, 3 * 28 * 28)
# if you specify an example input, the summary will show input/output for each layer
self.example_input_array = torch.rand(5, 28 * 28)
# build model
self.__build_model()

View File

@ -38,33 +38,37 @@ class ModelSummary(object):
if self.model.on_gpu:
input_ = input_.cuda(0)
if self.model.trainer.use_amp:
input_ = input_.half()
for i in range(1, len(mods)):
m = mods[i]
if type(input_) is list or type(input_) is tuple:
out = m(*input_)
else:
out = m(input_)
with torch.no_grad:
if type(input_) is tuple or type(input_) is list:
in_size = []
for x in input_:
if type(x) is list:
in_size.append(len(x))
else:
in_size.append(x.size())
else:
in_size = np.array(input_.size())
for i in range(1, len(mods)):
m = mods[i]
if type(input_) is list or type(input_) is tuple:
out = m(*input_)
else:
out = m(input_)
in_sizes.append(in_size)
if type(input_) is tuple or type(input_) is list:
in_size = []
for x in input_:
if type(x) is list:
in_size.append(len(x))
else:
in_size.append(x.size())
else:
in_size = np.array(input_.size())
if type(out) is tuple or type(out) is list:
out_size = np.asarray([x.size() for x in out])
else:
out_size = np.array(out.size())
in_sizes.append(in_size)
out_sizes.append(out_size)
input_ = out
if type(out) is tuple or type(out) is list:
out_size = np.asarray([x.size() for x in out])
else:
out_size = np.array(out.size())
out_sizes.append(out_size)
input_ = out
self.in_sizes = in_sizes
self.out_sizes = out_sizes