refactored model tests
This commit is contained in:
parent
ecb68b52f8
commit
c26d200c41
|
@ -54,10 +54,6 @@ class LightningTemplateModel(LightningModule):
|
|||
:param x:
|
||||
:return:
|
||||
"""
|
||||
print('-'*100)
|
||||
print('x: ', x.device)
|
||||
print('model: ', self.c_d1.weight.device)
|
||||
print('-'*100)
|
||||
|
||||
x = self.c_d1(x)
|
||||
x = torch.tanh(x)
|
||||
|
|
|
@ -107,7 +107,7 @@ def main():
|
|||
trainer = Trainer(
|
||||
experiment=exp,
|
||||
checkpoint_callback=checkpoint,
|
||||
progress_bar=False,
|
||||
progress_bar=True,
|
||||
max_nb_epochs=1,
|
||||
gpus=[0, 1],
|
||||
distributed_backend='dp',
|
||||
|
|
Loading…
Reference in New Issue