cleaned up progbar (#165)

* cleaned up progbar

* updated base files

* flake 8
William Falcon 2019-08-23 21:23:27 -04:00 committed by GitHub
parent 2ad9a9708b
commit 4104a0fc47
10 changed files with 126 additions and 88 deletions

View File

@ -5,7 +5,7 @@ Lighting offers a few options for logging information about model, gpu usage, et
#### Display metrics in progress bar
``` {.python}
# DEFAULT
trainer = Trainer(progress_bar=True)
trainer = Trainer(show_progress_bar=True)
```
---
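A quick usage sketch of the renamed flag (not taken from the docs change above; the import path may differ in this early release):

``` {.python}
# Minimal sketch of the renamed flag (assumes the Trainer shown in the docs above).
from pytorch_lightning import Trainer  # import path may vary in early releases

trainer = Trainer(show_progress_bar=True)   # default: show the tqdm bar
trainer = Trainer(show_progress_bar=False)  # silence the bar, e.g. for CI logs
```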

View File

@ -193,7 +193,7 @@ class LightningTemplateModel(LightningModule):
train_sampler = None
batch_size = self.hparams.batch_size
if self.trainer.use_ddp:
if self.use_ddp:
train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
batch_size = batch_size // self.trainer.world_size # scale batch size

View File

@ -14,7 +14,6 @@ import torch.multiprocessing as mp
import torch.distributed as dist
from torch.optim.optimizer import Optimizer
from pytorch_lightning.root_module.root_module import LightningModule
from pytorch_lightning.root_module.memory import get_gpu_memory_map
from pytorch_lightning.root_module.model_saving import TrainerIO
from pytorch_lightning.pt_overrides.override_data_parallel import (
@ -61,7 +60,7 @@ class Trainer(TrainerIO):
current_gpu_name=0,
nb_gpu_nodes=1,
gpus=None,
progress_bar=True,
show_progress_bar=True,
overfit_pct=0.0,
track_grad_norm=-1,
check_val_every_n_epoch=1,
@ -92,7 +91,7 @@ class Trainer(TrainerIO):
:param current_gpu_name:
:param nb_gpu_nodes:
:param gpus:
:param progress_bar:
:param show_progress_bar:
:param overfit_pct:
:param track_grad_norm:
:param check_val_every_n_epoch:
@ -122,7 +121,6 @@ class Trainer(TrainerIO):
self.track_grad_norm = track_grad_norm
self.fast_dev_run = fast_dev_run
self.on_gpu = gpus is not None and torch.cuda.is_available()
self.progress_bar = progress_bar
self.experiment = experiment
self.exp_save_path = None
if self.experiment is not None:
@ -223,11 +221,14 @@ class Trainer(TrainerIO):
# training state
self.optimizers = None
self.prog_bar = None
self.global_step = 0
self.current_epoch = 0
self.total_batches = 0
# can't init progress bar here because starting a new process
# means the prog_bar won't survive pickling
self.show_progress_bar = show_progress_bar
# logging
self.log_save_interval = log_save_interval
self.val_check_interval = val_check_interval
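The comment above is the heart of this hunk: a live tqdm handle generally does not survive pickling when DDP spawns new processes, so only the boolean is stored at construction time and the bar is built later inside each process. A rough sketch of that idea, with hypothetical names:

``` {.python}
# Rough sketch (hypothetical names): keep only picklable config on the object,
# and build the tqdm bar lazily inside the spawned process.
import pickle
import tqdm

class BarOwner:
    def __init__(self, show_progress_bar=True):
        self.show_progress_bar = show_progress_bar  # plain bool -> pickles fine
        self.progress_bar = None                    # built after the process starts

    def setup_in_process(self, total):
        if self.show_progress_bar:
            self.progress_bar = tqdm.tqdm(total=total)

owner = BarOwner()
pickle.dumps(owner)          # works: no live tqdm handle on the object yet
owner.setup_in_process(100)  # each DDP process creates its own bar
```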
@ -439,8 +440,8 @@ class Trainer(TrainerIO):
outputs.append(output)
# batch done
if self.progress_bar and self.prog_bar is not None:
self.prog_bar.update(1)
if self.show_progress_bar:
self.progress_bar.update(1)
# give model a chance to do something with the outputs (and method defined)
val_results = {}
@ -464,8 +465,8 @@ class Trainer(TrainerIO):
:param model:
:return:
"""
self.tng_dataloader = model.tng_dataloader
self.test_dataloader = model.test_dataloader
self.val_dataloader = model.val_dataloader
@ -476,21 +477,21 @@ class Trainer(TrainerIO):
if self.use_ddp and not isinstance(self.tng_dataloader.sampler, DistributedSampler):
msg = """
You're using multiple gpus and multiple nodes without using a DistributedSampler
to assign a subset of your data to each process. To silence this warning, pass a
DistributedSampler to your DataLoader.
ie: this:
dataset = myDataset()
dataloader = Dataloader(dataset)
becomes:
dataset = myDataset()
dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
dataloader = Dataloader(dataset, sampler=dist_sampler)
If you want each process to load the full dataset, ignore this warning.
"""
warnings.warn(msg)
if self.use_ddp and self.val_dataloader is not None:
@ -645,7 +646,7 @@ If you want each process to load the full dataset, ignore this warning.
self.experiment = self.experiment.get_non_ddp_exp()
# show progbar only on prog_rank 0
self.prog_bar = self.prog_bar and self.node_rank == 0 and gpu_nb == 0
self.show_progress_bar = self.show_progress_bar and self.node_rank == 0 and gpu_nb == 0
# determine which process we are and world size
self.proc_rank = self.node_rank * len(self.data_parallel_device_ids) + gpu_nb
@ -736,6 +737,9 @@ If you want each process to load the full dataset, ignore this warning.
# set local properties on the model
ref_model.on_gpu = self.on_gpu
ref_model.use_dp = self.use_dp
ref_model.use_ddp = self.use_ddp
ref_model.use_amp = self.use_amp
# transfer data loaders from model
self.get_dataloaders(ref_model)
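With use_dp / use_ddp / use_amp copied onto the model here, user dataloaders can branch on the flags directly instead of reaching through the trainer, which is the pattern the template and test models below switch to. A hedged sketch of such a dataloader body (the dataset helper is hypothetical):

``` {.python}
# Illustrative sketch of a LightningModule dataloader body that branches on the new flag.
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def tng_dataloader(self):
    dataset = build_my_dataset()                 # hypothetical helper
    batch_size = self.hparams.batch_size
    train_sampler = None
    if self.use_ddp:                             # set by the Trainer before fit()
        train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
        batch_size = batch_size // self.trainer.world_size  # per-process batch size
    return DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
```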
@ -770,10 +774,19 @@ If you want each process to load the full dataset, ignore this warning.
if self.cluster is not None: # pragma: no cover
self.enable_auto_hpc_walltime_manager()
# progress bar init
if self.show_progress_bar:
self.progress_bar = tqdm.tqdm(0, position=self.process_position)
# run tiny validation (if validation defined) to make sure program won't crash during val
ref_model.on_sanity_check_start()
if self.val_dataloader is not None:
for ds_i, dataloader in enumerate(self.val_dataloader):
# reset progress_bar limit for sanity check
if self.show_progress_bar:
self.progress_bar.reset(self.nb_sanity_val_steps)
self.validate(model, dataloader, self.nb_sanity_val_steps, ds_i)
# ---------------------------
@ -793,10 +806,9 @@ If you want each process to load the full dataset, ignore this warning.
self.total_batches = self.nb_tng_batches + self.nb_val_batches
self.batch_loss_value = 0 # accumulated grads
# init progbar when requested
if self.progress_bar:
self.prog_bar = tqdm.tqdm(range(self.total_batches),
position=self.process_position)
# init progress_bar when requested
if self.show_progress_bar:
self.progress_bar.reset(self.total_batches)
# -----------------
# RUN TNG EPOCH
@ -1025,8 +1037,8 @@ If you want each process to load the full dataset, ignore this warning.
if response == -1:
return -1
if self.progress_bar:
self.prog_bar.update(1)
if self.show_progress_bar:
self.progress_bar.update(1)
# call training_step once per optimizer
for opt_idx, optimizer in enumerate(self.optimizers):
@ -1075,10 +1087,10 @@ If you want each process to load the full dataset, ignore this warning.
self.avg_loss = np.mean(self.running_loss[-100:])
# update progbar
if self.progress_bar:
if self.show_progress_bar:
# add model specific metrics
tqdm_metrics = self.__tng_tqdm_dic
self.prog_bar.set_postfix(**tqdm_metrics)
self.progress_bar.set_postfix(**tqdm_metrics)
# activate batch end hook
if self.__is_function_implemented('on_batch_end'):
@ -1115,10 +1127,10 @@ If you want each process to load the full dataset, ignore this warning.
model = self.__get_model()
model.on_post_performance_check()
if self.progress_bar:
if self.show_progress_bar:
# add model specific metrics
tqdm_metrics = self.__tng_tqdm_dic
self.prog_bar.set_postfix(**tqdm_metrics)
self.progress_bar.set_postfix(**tqdm_metrics)
# model checkpointing
if self.proc_rank == 0 and self.checkpoint_callback is not None:
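Taken together, these hunks switch to a single bar per process that is driven everywhere through reset / update / set_postfix instead of being rebuilt for each phase. A standalone sketch of that lifecycle (totals and metric values are made up):

``` {.python}
# Standalone sketch of the single-bar lifecycle used above (values are made up).
import tqdm

progress_bar = tqdm.tqdm(total=0, position=0)   # created once per process

progress_bar.reset(5)                           # e.g. sanity-check val steps
for _ in range(5):
    progress_bar.update(1)

progress_bar.reset(100)                         # e.g. train + val batches for an epoch
for step in range(100):
    progress_bar.update(1)
    progress_bar.set_postfix(loss=1.0 / (step + 1), v_nb=0)  # model-specific metrics
progress_bar.close()
```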

View File

@ -1,3 +1,5 @@
import traceback
def data_loader(fn):
"""
@ -17,7 +19,9 @@ def data_loader(fn):
value = fn(self) # Lazy evaluation, done only once.
except AttributeError as e:
# Guard against AttributeError suppression. (Issue #142)
raise RuntimeError('An AttributeError was encountered: ' + str(e)) from e
traceback.print_exc()
error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e)
raise RuntimeError(error) from e
setattr(self, attr_name, value) # Memoize evaluation.
return value
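For context, the function this hunk patches is a lazy, memoizing wrapper: the decorated dataloader builder runs once, its result is cached on the instance, and an AttributeError raised inside it is re-raised as a RuntimeError (with a traceback) so it is not silently swallowed. The hunk does not show how the wrapper is exposed; the @property form below is just one way to support the attribute-style access the Trainer uses (`model.tng_dataloader`), so treat it as a hedged sketch rather than the library's implementation:

``` {.python}
# Hedged, self-contained sketch of the lazy, memoizing shape patched above.
import traceback

def data_loader_sketch(fn):
    attr_name = '_lazy_' + fn.__name__

    @property
    def _lazy_loader(self):
        if not hasattr(self, attr_name):
            try:
                value = fn(self)  # lazy evaluation, done only once
            except AttributeError as e:
                # surface the real failure instead of letting it be swallowed
                traceback.print_exc()
                raise RuntimeError(f'{fn.__name__}: An AttributeError was encountered: {e}') from e
            setattr(self, attr_name, value)  # memoize the result
        return getattr(self, attr_name)

    return _lazy_loader

class Demo:
    @data_loader_sketch
    def tng_dataloader(self):
        return [1, 2, 3]

demo = Demo()
print(demo.tng_dataloader)  # built on first access
print(demo.tng_dataloader)  # served from the cache afterwards
```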

View File

@ -23,6 +23,9 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
# track if gpu was requested for checkpointing
self.on_gpu = False
self.use_dp = False
self.use_ddp = False
self.use_amp = False
def forward(self, *args, **kwargs):
"""

View File

@ -209,7 +209,7 @@ class LightningTestModel(LightningModule):
batch_size = self.hparams.batch_size
try:
if self.on_gpu and not self.force_remove_distributed_sampler:
if self.use_ddp and not self.force_remove_distributed_sampler:
train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
batch_size = batch_size // self.trainer.world_size # scale batch size
except Exception:

View File

@ -181,7 +181,7 @@ class NoValEndTestModel(LightningModule):
batch_size = self.hparams.batch_size
try:
if self.on_gpu and not self.force_remove_distributed_sampler:
if self.use_ddp and not self.force_remove_distributed_sampler:
train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
batch_size = batch_size // self.trainer.world_size # scale batch size
except Exception:

View File

@ -138,7 +138,7 @@ class NoValModel(LightningModule):
batch_size = self.hparams.batch_size
try:
if self.on_gpu and not self.force_remove_distributed_sampler:
if self.use_ddp and not self.force_remove_distributed_sampler:
train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
batch_size = batch_size // self.trainer.world_size # scale batch size
except Exception:

View File

@ -11,6 +11,7 @@ import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import numpy as np
class CoolModel(pl.LightningModule):
@ -136,41 +137,60 @@ def run_prediction(dataloader, trained_model):
assert val_acc > 0.70, 'this model is expected to get > 0.7 in test set (it got %f)' % val_acc
def main():
def run_gpu_model_test(trainer_options, model, hparams, on_gpu=True):
save_dir = init_save_dir()
# exp file to get meta
exp = get_exp(False)
exp.argparse(hparams)
exp.save()
# exp file to get weights
checkpoint = ModelCheckpoint(save_dir)
trainer = Trainer(
experiment=exp,
checkpoint_callback=checkpoint,
progress_bar=True,
max_nb_epochs=1,
gpus=[0, 1],
distributed_backend='dp',
)
model = CoolModel()
# add these to the trainer options
trainer_options['checkpoint_callback'] = checkpoint
trainer_options['experiment'] = exp
# fit model
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
# correct result and ok accuracy
assert result == 1, 'amp + ddp model failed to complete'
# test model loading
pretrained_model = load_model(exp, save_dir)
pretrained_model = load_model(exp, save_dir, on_gpu)
# test model preds
run_prediction(model.test_dataloader, pretrained_model)
if trainer.use_ddp:
# on hpc this would work fine... but need to hack it for the purpose of the test
trainer.model = pretrained_model
trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()
# test HPC loading / saving
trainer.hpc_save(save_dir, exp)
trainer.hpc_load(save_dir, on_gpu=on_gpu)
clear_save_dir()
def main():
os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
model, hparams = get_model()
trainer_options = dict(
max_nb_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
gpus=[0, 1],
distributed_backend='ddp'
)
run_gpu_model_test(trainer_options, model, hparams)
if __name__ == '__main__':
main()
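The new main() picks a random MASTER_PORT so repeated DDP test runs on the same node do not collide on a stale port. A small sketch of the idea (port range copied from the test above):

``` {.python}
# Sketch: randomize the DDP rendezvous port so back-to-back test runs don't collide.
import os
import numpy as np

os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
# torch.distributed's env:// initialization reads MASTER_PORT at startup
```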

View File

@ -26,6 +26,33 @@ np.random.seed(SEED)
# ------------------------------------------------------------------------
# TESTS
# ------------------------------------------------------------------------
def test_multi_gpu_model_ddp():
"""
Make sure DDP works
:return:
"""
if not torch.cuda.is_available():
warnings.warn('test_multi_gpu_model_ddp cannot run.'
' Rerun on a GPU node to run this test')
return
if not torch.cuda.device_count() > 1:
warnings.warn('test_multi_gpu_model_ddp cannot run.'
' Rerun on a node with 2+ GPUs to run this test')
return
os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
model, hparams = get_model()
trainer_options = dict(
show_progress_bar=False,
max_nb_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
gpus=[0, 1],
distributed_backend='ddp'
)
run_gpu_model_test(trainer_options, model, hparams)
def test_optimizer_return_options():
@ -118,7 +145,7 @@ def test_early_stopping_cpu_model():
overfit_pct=0.20,
track_grad_norm=2,
print_nan_grads=True,
progress_bar=False,
show_progress_bar=False,
experiment=get_exp(),
train_percent_check=0.1,
val_percent_check=0.1
@ -265,7 +292,7 @@ def test_amp_single_gpu():
model = LightningTestModel(hparams)
trainer_options = dict(
progress_bar=True,
show_progress_bar=True,
max_nb_epochs=1,
gpus=[0],
distributed_backend='dp',
@ -361,7 +388,7 @@ def test_amp_gpu_ddp():
model = LightningTestModel(hparams)
trainer_options = dict(
progress_bar=True,
show_progress_bar=True,
max_nb_epochs=1,
gpus=[0, 1],
distributed_backend='ddp',
@ -581,7 +608,7 @@ def test_amp_gpu_ddp_slurm_managed():
model = LightningTestModel(hparams)
trainer_options = dict(
progress_bar=True,
show_progress_bar=True,
max_nb_epochs=1,
gpus=[0],
distributed_backend='ddp',
@ -646,7 +673,7 @@ def test_cpu_model_with_amp():
"""
trainer_options = dict(
progress_bar=False,
show_progress_bar=False,
experiment=get_exp(),
max_nb_epochs=1,
train_percent_check=0.4,
@ -667,7 +694,7 @@ def test_cpu_model():
"""
trainer_options = dict(
progress_bar=False,
show_progress_bar=False,
experiment=get_exp(),
max_nb_epochs=1,
train_percent_check=0.4,
@ -690,7 +717,7 @@ def test_all_features_cpu_model():
overfit_pct=0.20,
track_grad_norm=2,
print_nan_grads=True,
progress_bar=False,
show_progress_bar=False,
experiment=get_exp(),
accumulate_grad_batches=2,
max_nb_epochs=1,
@ -714,7 +741,7 @@ def test_single_gpu_model():
model, hparams = get_model()
trainer_options = dict(
progress_bar=False,
show_progress_bar=False,
max_nb_epochs=1,
train_percent_check=0.1,
val_percent_check=0.1,
@ -739,7 +766,7 @@ def test_multi_gpu_model_dp():
return
model, hparams = get_model()
trainer_options = dict(
progress_bar=False,
show_progress_bar=False,
max_nb_epochs=1,
train_percent_check=0.1,
val_percent_check=0.1,
@ -776,34 +803,6 @@ def test_amp_gpu_dp():
run_gpu_model_test(trainer_options, model, hparams)
def test_multi_gpu_model_ddp():
"""
Make sure DDP works
:return:
"""
if not torch.cuda.is_available():
warnings.warn('test_multi_gpu_model_ddp cannot run.'
' Rerun on a GPU node to run this test')
return
if not torch.cuda.device_count() > 1:
warnings.warn('test_multi_gpu_model_ddp cannot run.'
' Rerun on a node with 2+ GPUs to run this test')
return
os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
model, hparams = get_model()
trainer_options = dict(
progress_bar=False,
max_nb_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
gpus=[0, 1],
distributed_backend='ddp'
)
run_gpu_model_test(trainer_options, model, hparams)
def test_ddp_sampler_error():
"""
Make sure DDP + AMP work
@ -826,7 +825,7 @@ def test_ddp_sampler_error():
trainer = Trainer(
experiment=exp,
progress_bar=False,
show_progress_bar=False,
max_nb_epochs=1,
gpus=[0, 1],
distributed_backend='ddp',
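Across these tests the options are collected in a trainer_options dict and expanded into the Trainer. A minimal sketch of that pattern (values are placeholders; the model would come from a helper such as get_model above, and the import path may differ in this early release):

``` {.python}
# Illustrative sketch of the trainer_options pattern the tests above standardize on.
from pytorch_lightning import Trainer  # import path may vary in early releases

def run_options_sketch(model):
    trainer_options = dict(
        show_progress_bar=False,   # renamed from progress_bar in this PR
        max_nb_epochs=1,
        train_percent_check=0.4,
        val_percent_check=0.2,
    )
    trainer = Trainer(**trainer_options)
    return trainer.fit(model)      # returns 1 on success, matching the asserts above
```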