Make print_nan_grads print grad (#208)
This seems more useful for debugging.
parent 9f9d38673e
commit 81df2259ef
@@ -1091,10 +1091,10 @@ class Trainer(TrainerIO):
             torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip)
 
     def __print_nan_grads(self):
-        if self.print_nan_grads:
-            model = self.__get_model()
-            for param in model.parameters():
-                print(param.grad.float().sum())
+        model = self.__get_model()
+        for param in model.parameters():
+            if torch.isnan(param.grad.float()).any():
+                print(param, param.grad)
 
     def __run_tng_batch(self, data_batch, batch_nb):
         if data_batch is None:
@@ -1137,7 +1137,8 @@ class Trainer(TrainerIO):
             model_ref.on_after_backward()
 
             # nan grads
-            self.__print_nan_grads()
+            if self.print_nan_grads:
+                self.__print_nan_grads()
 
             # track total loss for logging (avoid mem leaks)
             self.batch_loss_value += loss.item()
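For context, the new check relies on torch.isnan over each parameter's gradient and only prints parameters whose gradients actually contain NaNs, instead of printing every gradient sum. A minimal standalone sketch of the same pattern follows; the toy Linear model and dummy forward/backward pass are illustrative placeholders, not part of this commit, which applies the check to the Trainer's module instead.

import torch
import torch.nn as nn

# Hypothetical toy model and backward pass to populate .grad fields.
model = nn.Linear(4, 1)
loss = model(torch.randn(8, 4)).sum()
loss.backward()

# Same detection pattern as the updated __print_nan_grads:
# report the parameter and its gradient only when the gradient has NaNs.
for param in model.parameters():
    if param.grad is not None and torch.isnan(param.grad.float()).any():
        print(param, param.grad)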