diff --git a/docs/Trainer/Logging.md b/docs/Trainer/Logging.md
index 5fc6cc4f8e..f3e93b1409 100644
--- a/docs/Trainer/Logging.md
+++ b/docs/Trainer/Logging.md
@@ -127,7 +127,13 @@ trainer = Trainer(row_log_interval=10)
 Logs GPU memory when metrics are logged.
 ``` {.python}
 # DEFAULT
-trainer = Trainer(log_gpu_memory=False)
+trainer = Trainer(log_gpu_memory=None)
+
+# log only the min/max utilization
+trainer = Trainer(log_gpu_memory='min_max')
+
+# log all the GPU memory (if on DDP, logs only that node)
+trainer = Trainer(log_gpu_memory='all')
 ```
 
 ---
diff --git a/pytorch_lightning/root_module/memory.py b/pytorch_lightning/root_module/memory.py
index b4fe297703..6aa4519ec0 100644
--- a/pytorch_lightning/root_module/memory.py
+++ b/pytorch_lightning/root_module/memory.py
@@ -178,6 +178,33 @@ def count_mem_items():  # pragma: no cover
     return nb_params, nb_tensors
 
 
+def get_memory_profile(mode):
+    """Get GPU memory usage, either per GPU or as a min/max summary.
+
+    :param mode: 'all' returns memory for every GPU;
+        'min_max' returns only the least- and most-used GPU
+    :return: dict mapping 'gpu_<id>' to memory used (MB)
+    """
+    memory_map = get_gpu_memory_map()
+
+    if mode == 'min_max':
+        min_mem = float('inf')
+        min_k = None
+        max_mem = 0
+        max_k = None
+        for k, v in memory_map.items():
+            if v > max_mem:
+                max_mem = v
+                max_k = k
+            if v < min_mem:
+                min_mem = v
+                min_k = k
+
+        memory_map = {min_k: min_mem, max_k: max_mem}
+
+    return memory_map
+
+
 def get_gpu_memory_map():
     """Get the current gpu usage.
 
@@ -196,6 +223,6 @@ def get_gpu_memory_map():
     gpu_memory = [int(x) for x in result.strip().split('\n')]
     gpu_memory_map = {}
     for k, v in zip(range(len(gpu_memory)), gpu_memory):
-        k = 'gpu_%i' % k
+        k = f'gpu_{k}'
         gpu_memory_map[k] = v
     return gpu_memory_map
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index c004a4cf8d..cb0e9b858c 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -15,7 +15,7 @@ import torch.distributed as dist
 from torch.optim.optimizer import Optimizer
 
 from pytorch_lightning.root_module.root_module import LightningModule
-from pytorch_lightning.root_module.memory import get_gpu_memory_map
+from pytorch_lightning.root_module import memory
 from pytorch_lightning.logging import TestTubeLogger
 from pytorch_lightning.trainer.trainer_io import TrainerIO
 from pytorch_lightning.pt_overrides.override_data_parallel import (
@@ -66,7 +66,7 @@ class Trainer(TrainerIO):
                  process_position=0,
                  nb_gpu_nodes=1,
                  gpus=None,
-                 log_gpu_memory=False,
+                 log_gpu_memory=None,
                  show_progress_bar=True,
                  overfit_pct=0.0,
                  track_grad_norm=-1,
@@ -98,7 +98,7 @@ class Trainer(TrainerIO):
         :param process_position: shown in the tqdm bar
         :param nb_gpu_nodes: number of GPU nodes
         :param gpus: int. (ie: 2 gpus) OR list to specify which GPUs [0, 1] or '0,1'
-        :param log_gpu_memory: Bool. If true, adds memory logs
+        :param log_gpu_memory: str. None, 'min_max', 'all'
         :param show_progress_bar: Bool. If true shows tqdm bar
         :param overfit_pct: float. uses this much of all datasets
         :param track_grad_norm: int. -1 no tracking. Otherwise tracks that norm
@@ -1080,8 +1080,8 @@ class Trainer(TrainerIO):
         metrics = self.__training_tqdm_dict
 
         # add gpu memory
-        if self.on_gpu and self.log_gpu_memory:
-            mem_map = get_gpu_memory_map()
+        if self.on_gpu and self.log_gpu_memory is not None:
+            mem_map = memory.get_memory_profile(mode=self.log_gpu_memory)
             metrics.update(mem_map)
 
         # add norms
diff --git a/tests/test_models.py b/tests/test_models.py
index c70663d0a6..9d9536e4bb 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -31,6 +31,10 @@ from pytorch_lightning.trainer import trainer_io
 from pytorch_lightning.logging import TestTubeLogger
 from examples import LightningTemplateModel
 
+# generate a list of random seeds for each test
+ROOT_SEED = 1234
+torch.manual_seed(ROOT_SEED)
+np.random.seed(ROOT_SEED)
 RANDOM_SEEDS = list(np.random.randint(0, 10000, 1000))
 
 
@@ -75,8 +79,8 @@ def test_lbfgs_cpu_model():
        overfit_pct=0.20,
        print_nan_grads=True,
        show_progress_bar=False,
-       train_percent_check=0.1,
-       val_percent_check=0.1
+       train_percent_check=0.2,
+       val_percent_check=0.2
    )
 
    model, hparams = get_model(use_test_model=True, lbfgs=True)
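
For reference, here is a minimal sketch of the new API end to end. It is not part of the diff: the example values in the comments are invented, and it assumes a machine where `get_gpu_memory_map()` can reach the GPU driver (the surrounding code suggests it shells out to `nvidia-smi`).

```python
# Sketch of the API this diff adds -- not part of the change itself.
# Assumes at least one visible GPU; the values in the comments below
# are illustrative, not real output.
from pytorch_lightning import Trainer
from pytorch_lightning.root_module import memory

# per-GPU memory used, in MB, e.g. {'gpu_0': 3147, 'gpu_1': 11}
all_gpus = memory.get_memory_profile(mode='all')

# only the least- and most-used GPU, e.g. {'gpu_1': 11, 'gpu_0': 3147}
min_max = memory.get_memory_profile(mode='min_max')

# the Trainer makes the same call when it logs metrics
trainer = Trainer(gpus=2, log_gpu_memory='min_max')
```

One consequence of the `min_max` implementation worth noting: if every GPU reports the same usage, `min_k` and `max_k` end up pointing at the same GPU, so the returned dict collapses to a single entry.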