updated args
This commit is contained in:
parent
7814b2d449
commit
b59af1813b
|
@ -42,7 +42,6 @@ class ExampleModel(RootModule):
|
||||||
# ---------------------
|
# ---------------------
|
||||||
def forward(self, x, a):
|
def forward(self, x, a):
|
||||||
|
|
||||||
pdb.set_trace()
|
|
||||||
x = self.c_d1(x)
|
x = self.c_d1(x)
|
||||||
x = F.tanh(x)
|
x = F.tanh(x)
|
||||||
x = self.c_d1_bn(x)
|
x = self.c_d1_bn(x)
|
||||||
|
@ -63,7 +62,6 @@ class ExampleModel(RootModule):
|
||||||
:param data_batch:
|
:param data_batch:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
pdb.set_trace()
|
|
||||||
# forward pass
|
# forward pass
|
||||||
x, y = data_batch
|
x, y = data_batch
|
||||||
x = x.view(x.size(0), -1)
|
x = x.view(x.size(0), -1)
|
||||||
|
@ -81,7 +79,6 @@ class ExampleModel(RootModule):
|
||||||
:param data_batch:
|
:param data_batch:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
pdb.set_trace()
|
|
||||||
x, y = data_batch
|
x, y = data_batch
|
||||||
x = x.view(x.size(0), -1)
|
x = x.view(x.size(0), -1)
|
||||||
y_hat = self.forward(x)
|
y_hat = self.forward(x)
|
||||||
|
|
|
@ -188,6 +188,7 @@ class Trainer(TrainerIO):
|
||||||
# RUN VALIDATION STEP
|
# RUN VALIDATION STEP
|
||||||
# -----------------
|
# -----------------
|
||||||
output = model(data_batch, batch_i)
|
output = model(data_batch, batch_i)
|
||||||
|
pdb.set_trace()
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
||||||
# batch done
|
# batch done
|
||||||
|
|
Loading…
Reference in New Issue