added sample input for summary
This commit is contained in:
parent
b8cc62ee52
commit
77a7f3e33e
|
@ -28,7 +28,8 @@ class LightningTemplateModel(LightningModule):
|
|||
|
||||
self.batch_size = hparams.batch_size
|
||||
|
||||
self.example_input_array = torch.rand(5, 3 * 28 * 28)
|
||||
# if you specify an example input, the summary will show input/output for each layer
|
||||
self.example_input_array = torch.rand(5, 28 * 28)
|
||||
|
||||
# build model
|
||||
self.__build_model()
|
||||
|
|
|
@ -38,33 +38,37 @@ class ModelSummary(object):
|
|||
if self.model.on_gpu:
|
||||
input_ = input_.cuda(0)
|
||||
|
||||
if self.model.trainer.use_amp:
|
||||
input_ = input_.half()
|
||||
|
||||
for i in range(1, len(mods)):
|
||||
m = mods[i]
|
||||
if type(input_) is list or type(input_) is tuple:
|
||||
out = m(*input_)
|
||||
else:
|
||||
out = m(input_)
|
||||
with torch.no_grad:
|
||||
|
||||
if type(input_) is tuple or type(input_) is list:
|
||||
in_size = []
|
||||
for x in input_:
|
||||
if type(x) is list:
|
||||
in_size.append(len(x))
|
||||
else:
|
||||
in_size.append(x.size())
|
||||
else:
|
||||
in_size = np.array(input_.size())
|
||||
for i in range(1, len(mods)):
|
||||
m = mods[i]
|
||||
if type(input_) is list or type(input_) is tuple:
|
||||
out = m(*input_)
|
||||
else:
|
||||
out = m(input_)
|
||||
|
||||
in_sizes.append(in_size)
|
||||
if type(input_) is tuple or type(input_) is list:
|
||||
in_size = []
|
||||
for x in input_:
|
||||
if type(x) is list:
|
||||
in_size.append(len(x))
|
||||
else:
|
||||
in_size.append(x.size())
|
||||
else:
|
||||
in_size = np.array(input_.size())
|
||||
|
||||
if type(out) is tuple or type(out) is list:
|
||||
out_size = np.asarray([x.size() for x in out])
|
||||
else:
|
||||
out_size = np.array(out.size())
|
||||
in_sizes.append(in_size)
|
||||
|
||||
out_sizes.append(out_size)
|
||||
input_ = out
|
||||
if type(out) is tuple or type(out) is list:
|
||||
out_size = np.asarray([x.size() for x in out])
|
||||
else:
|
||||
out_size = np.array(out.size())
|
||||
|
||||
out_sizes.append(out_size)
|
||||
input_ = out
|
||||
|
||||
self.in_sizes = in_sizes
|
||||
self.out_sizes = out_sizes
|
||||
|
|
Loading…
Reference in New Issue