cleaned up progbar (#165)

* cleaned up progbar
* updated base files
* flake8
parent 2ad9a9708b
commit 4104a0fc47
@@ -5,7 +5,7 @@ Lighting offers a few options for logging information about model, gpu usage, et
 #### Display metrics in progress bar
 ``` {.python}
 # DEFAULT
-trainer = Trainer(progress_bar=True)
+trainer = Trainer(show_progress_bar=True)
 ```

 ---
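For reference, a minimal usage sketch of the renamed flag. `LitModel` is a hypothetical `LightningModule` stand-in and the top-level `Trainer` import path is assumed; neither is part of this diff.

```python
from pytorch_lightning import Trainer  # import path assumed; may differ by version

# LitModel is any LightningModule you have defined elsewhere (hypothetical here)
model = LitModel()

# the former `progress_bar` flag is now `show_progress_bar`
trainer = Trainer(show_progress_bar=True)  # pass False to silence the bar
trainer.fit(model)
```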
@@ -193,7 +193,7 @@ class LightningTemplateModel(LightningModule):
         train_sampler = None
         batch_size = self.hparams.batch_size

-        if self.trainer.use_ddp:
+        if self.use_ddp:
             train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
             batch_size = batch_size // self.trainer.world_size  # scale batch size
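The template above divides the configured batch size by the DDP world size so the effective batch across all processes stays what the user asked for. A small standalone sketch of that arithmetic, with illustrative numbers:

```python
# illustrative values, not taken from the diff
global_batch_size = 256  # desired effective batch size across the whole job
world_size = 4           # number of DDP processes (gpus per node * nodes)

# each process loads only its shard, so it needs only a fraction of the batch
per_process_batch_size = global_batch_size // world_size  # -> 64

assert per_process_batch_size * world_size == global_batch_size
```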
@@ -14,7 +14,6 @@ import torch.multiprocessing as mp
 import torch.distributed as dist
 from torch.optim.optimizer import Optimizer

 from pytorch_lightning.root_module.root_module import LightningModule
 from pytorch_lightning.root_module.memory import get_gpu_memory_map
 from pytorch_lightning.root_module.model_saving import TrainerIO
 from pytorch_lightning.pt_overrides.override_data_parallel import (
@@ -61,7 +60,7 @@ class Trainer(TrainerIO):
                  current_gpu_name=0,
                  nb_gpu_nodes=1,
                  gpus=None,
-                 progress_bar=True,
+                 show_progress_bar=True,
                  overfit_pct=0.0,
                  track_grad_norm=-1,
                  check_val_every_n_epoch=1,
@@ -92,7 +91,7 @@ class Trainer(TrainerIO):
         :param current_gpu_name:
         :param nb_gpu_nodes:
         :param gpus:
-        :param progress_bar:
+        :param show_progress_bar:
        :param overfit_pct:
        :param track_grad_norm:
        :param check_val_every_n_epoch:
@@ -122,7 +121,6 @@ class Trainer(TrainerIO):
         self.track_grad_norm = track_grad_norm
         self.fast_dev_run = fast_dev_run
         self.on_gpu = gpus is not None and torch.cuda.is_available()
-        self.progress_bar = progress_bar
         self.experiment = experiment
         self.exp_save_path = None
         if self.experiment is not None:
@@ -223,11 +221,14 @@ class Trainer(TrainerIO):

         # training state
         self.optimizers = None
-        self.prog_bar = None
         self.global_step = 0
         self.current_epoch = 0
         self.total_batches = 0

+        # can't init progress bar here because starting a new process
+        # means the prog_bar won't survive pickling
+        self.show_progress_bar = show_progress_bar
+
         # logging
         self.log_save_interval = log_save_interval
         self.val_check_interval = val_check_interval
@@ -439,8 +440,8 @@ class Trainer(TrainerIO):
             outputs.append(output)

             # batch done
-            if self.progress_bar and self.prog_bar is not None:
-                self.prog_bar.update(1)
+            if self.show_progress_bar:
+                self.progress_bar.update(1)

         # give model a chance to do something with the outputs (and method defined)
         val_results = {}
@@ -464,8 +465,8 @@ class Trainer(TrainerIO):
         :param model:
         :return:
         """
-        self.tng_dataloader = model.tng_dataloader

+        self.tng_dataloader = model.tng_dataloader
         self.test_dataloader = model.test_dataloader
         self.val_dataloader = model.val_dataloader
@@ -476,21 +477,21 @@ class Trainer(TrainerIO):

         if self.use_ddp and not isinstance(self.tng_dataloader.sampler, DistributedSampler):
             msg = """
-You're using multiple gpus and multiple nodes without using a DistributedSampler
-to assign a subset of your data to each process. To silence this warning, pass a
-DistributedSampler to your DataLoader.
+You're using multiple gpus and multiple nodes without using a DistributedSampler
+to assign a subset of your data to each process. To silence this warning, pass a
+DistributedSampler to your DataLoader.

-ie: this:
-dataset = myDataset()
-dataloader = Dataloader(dataset)
+ie: this:
+dataset = myDataset()
+dataloader = Dataloader(dataset)

-becomes:
-dataset = myDataset()
-dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
-dataloader = Dataloader(dataset, sampler=dist_sampler)
+becomes:
+dataset = myDataset()
+dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+dataloader = Dataloader(dataset, sampler=dist_sampler)

-If you want each process to load the full dataset, ignore this warning.
-"""
+If you want each process to load the full dataset, ignore this warning.
+"""
             warnings.warn(msg)

         if self.use_ddp and self.val_dataloader is not None:
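The warning above already spells out the fix; here is the same wiring as a runnable, self-contained sketch. The `TensorDataset` is only a stand-in, and `num_replicas`/`rank` are passed explicitly for illustration (in a real DDP job they are usually inferred from the initialized process group):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# stand-in dataset; use your own torch Dataset here
dataset = TensorDataset(torch.randn(1000, 10), torch.randint(0, 2, (1000,)))

# explicit replica count and rank for illustration only
dist_sampler = DistributedSampler(dataset, num_replicas=2, rank=0)

# each process now sees only its own shard of the data
dataloader = DataLoader(dataset, batch_size=32, sampler=dist_sampler)
```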
@@ -645,7 +646,7 @@ If you want each process to load the full dataset, ignore this warning.
             self.experiment = self.experiment.get_non_ddp_exp()

         # show progbar only on prog_rank 0
-        self.prog_bar = self.prog_bar and self.node_rank == 0 and gpu_nb == 0
+        self.show_progress_bar = self.show_progress_bar and self.node_rank == 0 and gpu_nb == 0

         # determine which process we are and world size
         self.proc_rank = self.node_rank * len(self.data_parallel_device_ids) + gpu_nb
@@ -736,6 +737,9 @@ If you want each process to load the full dataset, ignore this warning.

         # set local properties on the model
         ref_model.on_gpu = self.on_gpu
+        ref_model.use_dp = self.use_dp
+        ref_model.use_ddp = self.use_ddp
+        ref_model.use_amp = self.use_amp

         # transfer data loaders from model
         self.get_dataloaders(ref_model)
@@ -770,10 +774,19 @@ If you want each process to load the full dataset, ignore this warning.
         if self.cluster is not None:  # pragma: no cover
             self.enable_auto_hpc_walltime_manager()

+        # progress bar init
+        if self.show_progress_bar:
+            self.progress_bar = tqdm.tqdm(0, position=self.process_position)
+
         # run tiny validation (if validation defined) to make sure program won't crash during val
         ref_model.on_sanity_check_start()
         if self.val_dataloader is not None:
             for ds_i, dataloader in enumerate(self.val_dataloader):
+
+                # reset progress_bar limit for sanity check
+                if self.show_progress_bar:
+                    self.progress_bar.reset(self.nb_sanity_val_steps)
+
                 self.validate(model, dataloader, self.nb_sanity_val_steps, ds_i)

         # ---------------------------
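The cleanup keeps a single tqdm bar alive and re-targets it with `reset()` for the sanity check and again for training, rather than constructing a new bar each time. A minimal standalone sketch of that reuse pattern (the loop bounds and fake metric are illustrative, and `reset(total=...)` needs a reasonably recent tqdm):

```python
import time

import tqdm

# one bar, created once; its length is set later via reset()
bar = tqdm.tqdm(total=0, position=0)

# sanity-check phase: a handful of steps
bar.reset(total=5)
for _ in range(5):
    time.sleep(0.01)  # stand-in for a validation step
    bar.update(1)

# training phase: re-target the same bar for the full epoch
bar.reset(total=100)
for step in range(100):
    time.sleep(0.01)  # stand-in for a training step
    bar.update(1)
    bar.set_postfix(loss=1.0 / (step + 1))  # model-specific metrics go here

bar.close()
```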
@@ -793,10 +806,9 @@ If you want each process to load the full dataset, ignore this warning.
             self.total_batches = self.nb_tng_batches + self.nb_val_batches
             self.batch_loss_value = 0  # accumulated grads

-            # init progbar when requested
-            if self.progress_bar:
-                self.prog_bar = tqdm.tqdm(range(self.total_batches),
-                                          position=self.process_position)
+            # init progress_bar when requested
+            if self.show_progress_bar:
+                self.progress_bar.reset(self.total_batches)

             # -----------------
             # RUN TNG EPOCH
@@ -1025,8 +1037,8 @@ If you want each process to load the full dataset, ignore this warning.
             if response == -1:
                 return -1

-        if self.progress_bar:
-            self.prog_bar.update(1)
+        if self.show_progress_bar:
+            self.progress_bar.update(1)

         # call training_step once per optimizer
         for opt_idx, optimizer in enumerate(self.optimizers):
@@ -1075,10 +1087,10 @@ If you want each process to load the full dataset, ignore this warning.
                 self.avg_loss = np.mean(self.running_loss[-100:])

             # update progbar
-            if self.progress_bar:
+            if self.show_progress_bar:
                 # add model specific metrics
                 tqdm_metrics = self.__tng_tqdm_dic
-                self.prog_bar.set_postfix(**tqdm_metrics)
+                self.progress_bar.set_postfix(**tqdm_metrics)

             # activate batch end hook
             if self.__is_function_implemented('on_batch_end'):
@@ -1115,10 +1127,10 @@ If you want each process to load the full dataset, ignore this warning.
         model = self.__get_model()
         model.on_post_performance_check()

-        if self.progress_bar:
+        if self.show_progress_bar:
             # add model specific metrics
             tqdm_metrics = self.__tng_tqdm_dic
-            self.prog_bar.set_postfix(**tqdm_metrics)
+            self.progress_bar.set_postfix(**tqdm_metrics)

         # model checkpointing
         if self.proc_rank == 0 and self.checkpoint_callback is not None:
@@ -1,3 +1,5 @@
+import traceback
+

 def data_loader(fn):
     """
@@ -17,7 +19,9 @@ def data_loader(fn):
                 value = fn(self)  # Lazy evaluation, done only once.
             except AttributeError as e:
                 # Guard against AttributeError suppression. (Issue #142)
-                raise RuntimeError('An AttributeError was encountered: ' + str(e)) from e
+                traceback.print_exc()
+                error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e)
+                raise RuntimeError(error) from e
             setattr(self, attr_name, value)  # Memoize evaluation.
         return value
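For context on the change above: `data_loader` lazily evaluates a dataloader method once, memoizes the result, and now prints the traceback and re-raises any `AttributeError` as a `RuntimeError` so the error is not silently swallowed. A simplified, self-contained sketch of that pattern (illustrative only, not the library's exact implementation):

```python
import traceback


def data_loader(fn):
    """Sketch: evaluate `fn` lazily, cache the result, surface AttributeErrors loudly."""
    attr_name = '_lazy_' + fn.__name__

    @property
    def _data_loader(self):
        if not hasattr(self, attr_name):
            try:
                value = fn(self)  # lazy evaluation, done only once
            except AttributeError as e:
                # print and re-raise so the error is not hidden by attribute lookup
                traceback.print_exc()
                raise RuntimeError(f'{fn.__name__}: An AttributeError was encountered: {e}') from e
            setattr(self, attr_name, value)  # memoize the result
        return getattr(self, attr_name)

    return _data_loader
```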
@@ -23,6 +23,9 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):

         # track if gpu was requested for checkpointing
         self.on_gpu = False
+        self.use_dp = False
+        self.use_ddp = False
+        self.use_amp = False

     def forward(self, *args, **kwargs):
         """
@@ -209,7 +209,7 @@ class LightningTestModel(LightningModule):
         batch_size = self.hparams.batch_size

         try:
-            if self.on_gpu and not self.force_remove_distributed_sampler:
+            if self.use_ddp and not self.force_remove_distributed_sampler:
                 train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
                 batch_size = batch_size // self.trainer.world_size  # scale batch size
         except Exception:
@@ -181,7 +181,7 @@ class NoValEndTestModel(LightningModule):
         batch_size = self.hparams.batch_size

         try:
-            if self.on_gpu and not self.force_remove_distributed_sampler:
+            if self.use_ddp and not self.force_remove_distributed_sampler:
                 train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
                 batch_size = batch_size // self.trainer.world_size  # scale batch size
         except Exception:
@@ -138,7 +138,7 @@ class NoValModel(LightningModule):
         batch_size = self.hparams.batch_size

         try:
-            if self.on_gpu and not self.force_remove_distributed_sampler:
+            if self.use_ddp and not self.force_remove_distributed_sampler:
                 train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
                 batch_size = batch_size // self.trainer.world_size  # scale batch size
         except Exception:
@@ -11,6 +11,7 @@ import torch
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
 from torchvision.datasets import MNIST
+import numpy as np


 class CoolModel(pl.LightningModule):
@@ -136,41 +137,60 @@ def run_prediction(dataloader, trained_model):
     assert val_acc > 0.70, 'this model is expected to get > 0.7 in test set (it got %f)' % val_acc


-def main():
-
+def run_gpu_model_test(trainer_options, model, hparams, on_gpu=True):
     save_dir = init_save_dir()

     # exp file to get meta
     exp = get_exp(False)
     exp.argparse(hparams)
     exp.save()

     # exp file to get weights
     checkpoint = ModelCheckpoint(save_dir)

-    trainer = Trainer(
-        experiment=exp,
-        checkpoint_callback=checkpoint,
-        progress_bar=True,
-        max_nb_epochs=1,
-        gpus=[0, 1],
-        distributed_backend='dp',
-    )
-
-    model = CoolModel()
+    # add these to the trainer options
+    trainer_options['checkpoint_callback'] = checkpoint
+    trainer_options['experiment'] = exp
+
     # fit model
+    trainer = Trainer(**trainer_options)
     result = trainer.fit(model)

     # correct result and ok accuracy
     assert result == 1, 'amp + ddp model failed to complete'

     # test model loading
-    pretrained_model = load_model(exp, save_dir)
+    pretrained_model = load_model(exp, save_dir, on_gpu)

     # test model preds
     run_prediction(model.test_dataloader, pretrained_model)

+    if trainer.use_ddp:
+        # on hpc this would work fine... but need to hack it for the purpose of the test
+        trainer.model = pretrained_model
+        trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()
+
     # test HPC loading / saving
     trainer.hpc_save(save_dir, exp)
     trainer.hpc_load(save_dir, on_gpu=on_gpu)

     clear_save_dir()


+def main():
+
+    os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
+    model, hparams = get_model()
+    trainer_options = dict(
+        max_nb_epochs=1,
+        train_percent_check=0.4,
+        val_percent_check=0.2,
+        gpus=[0, 1],
+        distributed_backend='ddp'
+    )
+
+    run_gpu_model_test(trainer_options, model, hparams)
+
+
 if __name__ == '__main__':
     main()
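The refactor above funnels every GPU test through one helper by passing the Trainer configuration as a plain dict and unpacking it, so individual tests only declare what differs. A tiny sketch of the pattern (the helper name and option values are illustrative, and the top-level `Trainer` import is assumed):

```python
from pytorch_lightning import Trainer  # import path assumed; may differ by version


def run_with_options(trainer_options):
    # hypothetical helper in the spirit of run_gpu_model_test: shared setup can
    # inject or override options before the Trainer is built
    trainer_options.setdefault('show_progress_bar', False)
    return Trainer(**trainer_options)


trainer_options = dict(
    max_nb_epochs=1,
    train_percent_check=0.4,
    val_percent_check=0.2,
)
trainer = run_with_options(trainer_options)
```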
@@ -26,6 +26,33 @@ np.random.seed(SEED)
 # ------------------------------------------------------------------------
 # TESTS
 # ------------------------------------------------------------------------
+def test_multi_gpu_model_ddp():
+    """
+    Make sure DDP works
+    :return:
+    """
+    if not torch.cuda.is_available():
+        warnings.warn('test_multi_gpu_model_ddp cannot run.'
+                      ' Rerun on a GPU node to run this test')
+        return
+    if not torch.cuda.device_count() > 1:
+        warnings.warn('test_multi_gpu_model_ddp cannot run.'
+                      ' Rerun on a node with 2+ GPUs to run this test')
+        return
+
+    os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
+    model, hparams = get_model()
+    trainer_options = dict(
+        show_progress_bar=False,
+        max_nb_epochs=1,
+        train_percent_check=0.4,
+        val_percent_check=0.2,
+        gpus=[0, 1],
+        distributed_backend='ddp'
+    )
+
+    run_gpu_model_test(trainer_options, model, hparams)
+
+
 def test_optimizer_return_options():
@@ -118,7 +145,7 @@ def test_early_stopping_cpu_model():
         overfit_pct=0.20,
         track_grad_norm=2,
         print_nan_grads=True,
-        progress_bar=False,
+        show_progress_bar=False,
         experiment=get_exp(),
         train_percent_check=0.1,
         val_percent_check=0.1
@@ -265,7 +292,7 @@ def test_amp_single_gpu():
     model = LightningTestModel(hparams)

     trainer_options = dict(
-        progress_bar=True,
+        show_progress_bar=True,
         max_nb_epochs=1,
         gpus=[0],
         distributed_backend='dp',
@@ -361,7 +388,7 @@ def test_amp_gpu_ddp():
     model = LightningTestModel(hparams)

     trainer_options = dict(
-        progress_bar=True,
+        show_progress_bar=True,
         max_nb_epochs=1,
         gpus=[0, 1],
         distributed_backend='ddp',
@@ -581,7 +608,7 @@ def test_amp_gpu_ddp_slurm_managed():
     model = LightningTestModel(hparams)

     trainer_options = dict(
-        progress_bar=True,
+        show_progress_bar=True,
         max_nb_epochs=1,
         gpus=[0],
         distributed_backend='ddp',
@@ -646,7 +673,7 @@ def test_cpu_model_with_amp():
     """

     trainer_options = dict(
-        progress_bar=False,
+        show_progress_bar=False,
         experiment=get_exp(),
         max_nb_epochs=1,
         train_percent_check=0.4,
@@ -667,7 +694,7 @@ def test_cpu_model():
     """

     trainer_options = dict(
-        progress_bar=False,
+        show_progress_bar=False,
         experiment=get_exp(),
         max_nb_epochs=1,
         train_percent_check=0.4,
@@ -690,7 +717,7 @@ def test_all_features_cpu_model():
         overfit_pct=0.20,
         track_grad_norm=2,
         print_nan_grads=True,
-        progress_bar=False,
+        show_progress_bar=False,
         experiment=get_exp(),
         accumulate_grad_batches=2,
         max_nb_epochs=1,
@@ -714,7 +741,7 @@ def test_single_gpu_model():
     model, hparams = get_model()

     trainer_options = dict(
-        progress_bar=False,
+        show_progress_bar=False,
         max_nb_epochs=1,
         train_percent_check=0.1,
         val_percent_check=0.1,
@@ -739,7 +766,7 @@ def test_multi_gpu_model_dp():
         return
     model, hparams = get_model()
     trainer_options = dict(
-        progress_bar=False,
+        show_progress_bar=False,
         max_nb_epochs=1,
         train_percent_check=0.1,
         val_percent_check=0.1,
@@ -776,34 +803,6 @@ def test_amp_gpu_dp():
     run_gpu_model_test(trainer_options, model, hparams)


-def test_multi_gpu_model_ddp():
-    """
-    Make sure DDP works
-    :return:
-    """
-    if not torch.cuda.is_available():
-        warnings.warn('test_multi_gpu_model_ddp cannot run.'
-                      ' Rerun on a GPU node to run this test')
-        return
-    if not torch.cuda.device_count() > 1:
-        warnings.warn('test_multi_gpu_model_ddp cannot run.'
-                      ' Rerun on a node with 2+ GPUs to run this test')
-        return
-
-    os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
-    model, hparams = get_model()
-    trainer_options = dict(
-        progress_bar=False,
-        max_nb_epochs=1,
-        train_percent_check=0.4,
-        val_percent_check=0.2,
-        gpus=[0, 1],
-        distributed_backend='ddp'
-    )
-
-    run_gpu_model_test(trainer_options, model, hparams)
-
-
 def test_ddp_sampler_error():
     """
     Make sure DDP + AMP work
@@ -826,7 +825,7 @@ def test_ddp_sampler_error():

     trainer = Trainer(
         experiment=exp,
-        progress_bar=False,
+        show_progress_bar=False,
         max_nb_epochs=1,
         gpus=[0, 1],
         distributed_backend='ddp',