Make print_nan_grads print grad (#208)

Printing only the parameters whose gradients actually contain NaN, along with the gradients themselves, is more useful for debugging than printing the sum of every gradient.
Alok Singh 2019-09-06 22:08:09 -07:00 committed by William Falcon
parent 9f9d38673e
commit 81df2259ef
1 changed file with 6 additions and 5 deletions


@@ -1091,10 +1091,10 @@ class Trainer(TrainerIO):
             torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip)
 
     def __print_nan_grads(self):
-        if self.print_nan_grads:
-            model = self.__get_model()
-            for param in model.parameters():
-                print(param.grad.float().sum())
+        model = self.__get_model()
+        for param in model.parameters():
+            if torch.isnan(param.grad.float()).any():
+                print(param, param.grad)
 
     def __run_tng_batch(self, data_batch, batch_nb):
         if data_batch is None:
@@ -1137,7 +1137,8 @@ class Trainer(TrainerIO):
             model_ref.on_after_backward()
 
             # nan grads
-            self.__print_nan_grads()
+            if self.print_nan_grads:
+                self.__print_nan_grads()
 
             # track total loss for logging (avoid mem leaks)
             self.batch_loss_value += loss.item()
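
For reference, a minimal standalone sketch of the new behaviour (a toy example using only public PyTorch APIs, not Lightning's own Trainer code): a NaN is forced into a small model's gradients, and only the parameters whose gradients actually contain NaN are printed, together with the gradient itself.

import torch

# Toy model; a NaN in the input propagates into the weight gradient on backward.
model = torch.nn.Linear(2, 1)
out = model(torch.tensor([[1.0, float("nan")]]))
out.sum().backward()

# Same check the commit introduces in __print_nan_grads: report a parameter
# only when its gradient actually contains NaN, and print the gradient too.
for param in model.parameters():
    if param.grad is not None and torch.isnan(param.grad.float()).any():
        print(param, param.grad)

Printing the offending parameter and its gradient directly, rather than a per-parameter sum, makes it immediately clear which weights are affected.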