test memory printing

This commit is contained in:
William Falcon 2019-07-24 17:41:08 -04:00
parent ffdf11b7ed
commit 7f420c0cc2
1 changed file with 7 additions and 3 deletions

View File

@@ -6,6 +6,7 @@ from argparse import Namespace
from test_tube import Experiment
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.utils.debugging import MisconfigurationException
from pytorch_lightning.root_module import memory
import numpy as np
import warnings
import torch
@@ -128,6 +129,11 @@ def test_multi_gpu_model_dp():
run_gpu_model_test(trainer_options, model, hparams)
# test memory helper functions
memory.count_mem_items()
memory.print_mem_stack()
memory.get_gpu_memory_map()
def test_amp_gpu_dp():
"""
@@ -234,14 +240,12 @@ def test_ddp_sampler_error():
use_amp=True
)
# test memory gathering
trainer.count_mem_items()
with pytest.raises(MisconfigurationException):
trainer.get_dataloaders(model)
clear_save_dir()
# ------------------------------------------------------------------------
# UTILS
# ------------------------------------------------------------------------