prog bar option

This commit is contained in:
William Falcon 2019-06-27 11:22:13 -04:00
parent 4f75515ca4
commit b1fdde5daf
4 changed files with 43 additions and 43 deletions

View File

@ -0,0 +1,34 @@
---
#### Accumulated gradients
Gradient accumulation runs K small batches of size N before doing a backward pass. The effect is a large effective batch size of KxN.
``` {.python}
# default: 1 (i.e. no gradient accumulation)
trainer = Trainer(accumulate_grad_batches=1)
```
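Under the hood this is equivalent to the standard PyTorch accumulation pattern. Below is a minimal, self-contained sketch of that pattern; the toy model, data, and value of K are illustrative, not Lightning internals.
``` {.python}
import torch
from torch import nn

# toy setup: a linear model on random batches of size N=8
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
batches = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(8)]

K = 4  # accumulate over K batches -> effective batch size KxN = 32
optimizer.zero_grad()
for i, (x, y) in enumerate(batches):
    loss = loss_fn(model(x), y)
    (loss / K).backward()      # scale so the accumulated gradient averages over KxN samples
    if (i + 1) % K == 0:
        optimizer.step()       # one optimizer update per K batches
        optimizer.zero_grad()
```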
---
#### Check GPU usage
Lightning automatically logs GPU usage to the test-tube logs. It does so only at the metric logging interval, so it doesn't slow down training.
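For reference, here is a rough sketch of the kind of per-GPU memory snapshot such logging can record. It assumes an experiment object with a `log()` method like test-tube's; the function name and metric keys are illustrative, not Lightning's internals.
``` {.python}
import torch

def log_gpu_memory(experiment):
    # record currently allocated memory (in MB) for each visible GPU
    gpu_stats = {}
    for i in range(torch.cuda.device_count()):
        gpu_stats['gpu_%d_mem_mb' % i] = torch.cuda.memory_allocated(i) / 1024 ** 2
    experiment.log(gpu_stats)
```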
---
#### Check which gradients are nan
When enabled, this option prints a list of the tensors whose gradients are NaN.
``` {.python}
# default False
trainer = Trainer(check_grad_nans=False)
```
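Such a check can be sketched in a few lines of plain PyTorch; Lightning's internal implementation may differ.
``` {.python}
import torch

def print_nan_gradients(model):
    # print every parameter whose gradient contains a NaN
    for name, param in model.named_parameters():
        if param.grad is not None and torch.isnan(param.grad).any():
            print('NaN gradient in:', name)
```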
---
#### Check validation every n epochs
If you have a small dataset, you may want to run the validation check only every n epochs.
``` {.python}
# default 1 (check validation every epoch)
trainer = Trainer(check_val_every_n_epoch=1)
```
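The scheduling amounts to a modulus check on the epoch number, roughly as sketched below (illustrative, not Lightning's exact internals).
``` {.python}
check_val_every_n_epoch = 10  # illustrative value
for epoch in range(1, 101):
    # ... run one training epoch here ...
    if epoch % check_val_every_n_epoch == 0:
        print('running validation at epoch %d' % epoch)
```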
---
#### Display metrics in progress bar
Metrics reported by the model (e.g. the running loss) are shown as a postfix on the tqdm progress bar.
``` {.python}
# default True
trainer = Trainer(progress_bar=True)
```
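As the trainer diff below shows, the metrics are attached to the bar via tqdm's `set_postfix`. A minimal standalone sketch of that mechanism (the loss values are dummy stand-ins for model-reported metrics):
``` {.python}
import tqdm

prog_bar = tqdm.tqdm(range(100))
for step in prog_bar:
    tqdm_metrics = {'loss': 1.0 / (step + 1)}  # stand-in for model metrics
    prog_bar.set_postfix(**tqdm_metrics)
```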

View File

@ -16,41 +16,6 @@ trainer.fit(model)
But of course the fun is in all the advanced things it can do:
``` {.python}
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from test_tube import Experiment, SlurmCluster
trainer = Trainer(
experiment=Experiment,
checkpoint_callback=ModelCheckpoint,
early_stop_callback=EarlyStopping,
cluster=SlurmCluster,
process_position=0,
current_gpu_name=0,
gpus=None,
enable_tqdm=True,
overfit_pct=0.0,
track_grad_norm=-1,
check_val_every_n_epoch=1,
fast_dev_run=False,
accumulate_grad_batches=1,
enable_early_stop=True,
max_nb_epochs=5,
min_nb_epochs=1,
train_percent_check=1.0,
val_percent_check=1.0,
test_percent_check=1.0,
val_check_interval=0.95,
log_save_interval=1,
add_log_row_interval=1,
lr_scheduler_milestones=None,
use_amp=False,
check_grad_nans=False,
amp_level='O2',
nb_sanity_val_steps=5
)
```
Things you can do with the trainer module:
**Training loop**
- Accumulate gradients
@ -58,6 +23,7 @@ Things you can do with the trainer module:
- Check which gradients are nan
- Check validation every n epochs
- Display metrics in progress bar
- Display the parameter count by layer
- Force training for min or max epochs
- Inspect gradient norms
- Learning rate annealing

View File

@ -37,7 +37,7 @@ class Trainer(TrainerIO):
process_position=0,
current_gpu_name=0,
gpus=None,
enable_tqdm=True,
progress_bar=True,
overfit_pct=0.0,
track_grad_norm=-1,
check_val_every_n_epoch=1,
@ -58,7 +58,7 @@ class Trainer(TrainerIO):
self.track_grad_norm = track_grad_norm
self.fast_dev_run = fast_dev_run
self.on_gpu = gpus is not None and torch.cuda.is_available()
self.enable_tqdm = enable_tqdm
self.progress_bar = progress_bar
self.experiment = experiment
self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
self.cluster = cluster
@ -206,7 +206,7 @@ class Trainer(TrainerIO):
outputs.append(output)
# batch done
if self.enable_tqdm and self.prog_bar is not None:
if self.progress_bar and self.prog_bar is not None:
self.prog_bar.update(1)
# give model a chance to do something with the outputs
@ -307,7 +307,7 @@ class Trainer(TrainerIO):
self.batch_loss_value = 0 # accumulated grads
# init progbar when requested
if self.enable_tqdm:
if self.progress_bar:
self.prog_bar = tqdm.tqdm(range(self.total_batches), position=self.process_position)
for batch_nb, data_batch in enumerate(self.tng_dataloader):
@ -403,7 +403,7 @@ class Trainer(TrainerIO):
if response == -1:
return -1
if self.enable_tqdm:
if self.progress_bar:
self.prog_bar.update(1)
# forward pass
@ -453,7 +453,7 @@ class Trainer(TrainerIO):
self.avg_loss = np.mean(self.running_loss[-100:])
# update progbar
if self.enable_tqdm:
if self.progress_bar:
# add model specific metrics
tqdm_metrics = self.__tng_tqdm_dic
self.prog_bar.set_postfix(**tqdm_metrics)
@ -495,7 +495,7 @@ class Trainer(TrainerIO):
print(e)
print(traceback.print_exc())
if self.enable_tqdm:
if self.progress_bar:
# add model specific metrics
tqdm_metrics = self.__tng_tqdm_dic
self.prog_bar.set_postfix(**tqdm_metrics)

View File

@ -97,7 +97,7 @@ def main(hparams, cluster, results_dict):
experiment=exp,
on_gpu=on_gpu,
cluster=cluster,
enable_tqdm=hparams.enable_tqdm,
progress_bar=hparams.enable_tqdm,
overfit_pct=hparams.overfit,
track_grad_norm=hparams.track_grad_norm,
fast_dev_run=hparams.fast_dev_run,