test memory printing
parent ffdf11b7ed
commit 7f420c0cc2
@@ -6,6 +6,7 @@ from argparse import Namespace
 from test_tube import Experiment
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.utils.debugging import MisconfigurationException
+from pytorch_lightning.root_module import memory
 import numpy as np
 import warnings
 import torch
@@ -128,6 +129,11 @@ def test_multi_gpu_model_dp():

     run_gpu_model_test(trainer_options, model, hparams)

+    # test memory helper functions
+    memory.count_mem_items()
+    memory.print_mem_stack()
+    memory.get_gpu_memory_map()
+

 def test_amp_gpu_dp():
     """
@@ -234,14 +240,12 @@ def test_ddp_sampler_error():
         use_amp=True
     )
-
-    # test memory gathering
-    trainer.count_mem_items()
+
     with pytest.raises(MisconfigurationException):
         trainer.get_dataloaders(model)

     clear_save_dir()


 # ------------------------------------------------------------------------
 # UTILS
 # ------------------------------------------------------------------------
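For context on the three helpers exercised in the second hunk: utilities of this kind usually walk Python's garbage collector to count and print live tensors, and shell out to nvidia-smi for a per-GPU memory map. The sketch below is only an illustrative approximation under those assumptions; it reuses the same function names for readability and is not the implementation behind pytorch_lightning.root_module.memory.

# Illustrative sketch only; assumes gc-based tensor tracking and an
# nvidia-smi binary on PATH. Not the actual library code.
import gc
import subprocess

import torch


def count_mem_items():
    """Count tensors currently tracked by the garbage collector."""
    count = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                count += 1
        except Exception:
            # some gc-tracked objects raise when inspected; skip them
            pass
    return count


def print_mem_stack():
    """Print the type and shape of every live tensor."""
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                print(type(obj), obj.size())
        except Exception:
            pass


def get_gpu_memory_map():
    """Return used memory (MiB) per visible GPU, keyed by index."""
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"],
        stdout=subprocess.PIPE,
        check=True,
        encoding="utf-8",
    )
    used = [int(x) for x in result.stdout.strip().split("\n")]
    return {"gpu_{}".format(i): mem for i, mem in enumerate(used)}

As in the test above, the calls are not asserted on; they simply exercise the helpers so that import errors or crashes surface when the GPU test runs.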