renamed options

This commit is contained in:
William Falcon 2019-06-27 11:27:11 -04:00
parent b1fdde5daf
commit 7aaadad2c6
2 changed files with 17 additions and 6 deletions

View File

@ -17,7 +17,7 @@ Lightning automatically logs gpu usage to the test tube logs. It'll only do it a
#### Check which gradients are nan
This option prints a list of tensors with nan gradients.
``` {.python}
trainer = Trainer(check_grad_nans=False)
trainer = Trainer(print_nan_grads=False)
```
---
@ -31,4 +31,15 @@ trainer = Trainer(check_val_every_n_epoch=1)
#### Display metrics in progress bar
``` {.python}
trainer = Trainer(progress_bar=True)
```
```
---
#### Display the parameter count by layer
By default lightning prints a list of parameters *and submodules* when it starts training.
---
#### Force training for min or max epochs
It can be useful to force training for a minimum number of epochs or limit to a max number
``` {.python}
trainer = Trainer(min_nb_epochs=1, max_nb_epochs=1000)
```

View File

@ -43,12 +43,12 @@ class Trainer(TrainerIO):
check_val_every_n_epoch=1,
fast_dev_run=False,
accumulate_grad_batches=1,
enable_early_stop=True, max_nb_epochs=5, min_nb_epochs=1,
enable_early_stop=True, max_nb_epochs=1000, min_nb_epochs=1,
train_percent_check=1.0, val_percent_check=1.0, test_percent_check=1.0, val_check_interval=0.95,
log_save_interval=1, add_log_row_interval=1,
lr_scheduler_milestones=None,
use_amp=False,
check_grad_nans=False,
print_nan_grads=False,
amp_level='O2',
nb_sanity_val_steps=5):
@ -76,7 +76,7 @@ class Trainer(TrainerIO):
self.lr_scheduler_milestones = [] if lr_scheduler_milestones is None else [int(x.strip()) for x in lr_scheduler_milestones.split(',')]
self.lr_schedulers = []
self.amp_level = amp_level
self.check_grad_nans = check_grad_nans
self.print_nan_grads = print_nan_grads
self.data_parallel_device_ids = gpus
self.data_parallel = gpus is not None and len(gpus) > 0
@ -427,7 +427,7 @@ class Trainer(TrainerIO):
else:
loss.backward()
if self.check_grad_nans:
if self.print_nan_grads:
model = self.model.module if self.data_parallel else self.model
for param in model.parameters():
print(param.grad.float().sum())