speed-up testing (#504)
* extend CI timeout
* add short MNIST
* lower dataset size and early-stopping threshold
* refactor imports
* formatting
* early stop
* tune test params
* minor refactoring
# Conflicts:
# pytorch_lightning/testing/__init__.py
# pytorch_lightning/testing/lm_test_module.py
# pytorch_lightning/testing/lm_test_module_base.py
# pytorch_lightning/testing/lm_test_module_mixins.py
# pytorch_lightning/testing/model.py
# pytorch_lightning/testing/model_base.py
# pytorch_lightning/testing/model_mixins.py
# pytorch_lightning/testing/test_module.py
# pytorch_lightning/testing/test_module_base.py
# pytorch_lightning/testing/test_module_mixins.py
* typo
Co-Authored-By: Ir1dXD <sirius.caffrey@gmail.com>
* Revert "refactor imports"
This reverts commit b86aee92
* update imports
parent 9785a3e78e
commit 47659daa5f
@@ -22,6 +22,7 @@ references:
       command: |
         python --version ; pip --version ; pip list
         py.test pytorch_lightning tests pl_examples -v --doctest-modules --junitxml=test-reports/pytest_junit.xml --flake8
+      no_output_timeout: 15m

 jobs:

@@ -1,6 +1,6 @@
-from .test_module import LightningTestModel
-from .test_module_base import LightningTestModelBase
-from .test_module_mixins import (
+from .model import LightningTestModel
+from .model_base import LightningTestModelBase
+from .model_mixins import (
     LightningValidationStepMixin,
     LightningValidationMixin,
     LightningValidationStepMultipleDataloadersMixin,
@@ -1,7 +1,7 @@
 import torch

-from .test_module_base import LightningTestModelBase
-from .test_module_mixins import LightningValidationMixin, LightningTestMixin
+from .model_base import LightningTestModelBase
+from .model_mixins import LightningValidationMixin, LightningTestMixin


 class LightningTestModel(LightningValidationMixin, LightningTestMixin, LightningTestModelBase):
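Note: the test model above is assembled by mixing validation/test hooks into a shared base module, and the same composition pattern is used ad hoc throughout the test files in this diff. A minimal sketch of that pattern (the class name CurrentTestModel is illustrative, hparams would come from the shared test utilities):

from pytorch_lightning.testing import LightningTestMixin, LightningTestModelBase

# compose a throwaway model that only implements the hooks this test needs
class CurrentTestModel(LightningTestMixin, LightningTestModelBase):
    pass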
@@ -19,6 +19,22 @@ from pytorch_lightning import data_loader
 from pytorch_lightning.core.lightning import LightningModule


+class TestingMNIST(MNIST):
+
+    def __init__(self, root, train=True, transform=None, target_transform=None,
+                 download=False, num_samples=8000):
+        super(TestingMNIST, self).__init__(
+            root,
+            train=train,
+            transform=transform,
+            target_transform=target_transform,
+            download=download
+        )
+        # take just a subset of MNIST dataset
+        self.data = self.data[:num_samples]
+        self.targets = self.targets[:num_samples]
+
+
 class LightningTestModelBase(LightningModule):
     """
     Base LightningModule for testing. Implements only the required
@@ -137,8 +153,8 @@ class LightningTestModelBase(LightningModule):
         # init data generators
         transform = transforms.Compose([transforms.ToTensor(),
                                         transforms.Normalize((0.5,), (1.0,))])
-        dataset = MNIST(root=self.hparams.data_root, train=train,
-                        transform=transform, download=True)
+        dataset = TestingMNIST(root=self.hparams.data_root, train=train,
+                               transform=transform, download=True, num_samples=2000)

         # when using multi-node we need to add the datasampler
         train_sampler = None
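Note: the data-path speed-up comes from training on a 2000-sample slice of MNIST rather than the full training set; TestingMNIST only truncates data/targets after the parent class finishes downloading. A rough usage sketch, assuming the module path shown in this diff (the data root is illustrative):

from torch.utils.data import DataLoader
from torchvision import transforms

from pytorch_lightning.testing.model_base import TestingMNIST

# mirror what the test LightningModule now builds in its dataloaders
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (1.0,))])
dataset = TestingMNIST(root='./mnist', train=True, transform=transform,
                       download=True, num_samples=2000)  # 2000 of the 60000 images
loader = DataLoader(dataset, batch_size=32, shuffle=True)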
@@ -9,7 +9,7 @@ from pytorch_lightning.testing import (
     LightningTestModel,
 )
 from pytorch_lightning.utilities.debugging import MisconfigurationException
-from . import testing_utils
+import tests.utils as tutils


 def test_amp_single_gpu():
@@ -17,12 +17,12 @@ def test_amp_single_gpu():
     Make sure DDP + AMP work
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

-    if not testing_utils.can_run_gpu_test():
+    if not tutils.can_run_gpu_test():
         return

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

     trainer_options = dict(
@@ -33,7 +33,7 @@ def test_amp_single_gpu():
         use_amp=True
     )

-    testing_utils.run_gpu_model_test(trainer_options, model, hparams)
+    tutils.run_model_test(trainer_options, model, hparams)


 def test_no_amp_single_gpu():
@@ -41,12 +41,12 @@ def test_no_amp_single_gpu():
     Make sure DDP + AMP work
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

-    if not testing_utils.can_run_gpu_test():
+    if not tutils.can_run_gpu_test():
         return

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

     trainer_options = dict(
@@ -58,7 +58,7 @@ def test_no_amp_single_gpu():
     )

     with pytest.raises((MisconfigurationException, ModuleNotFoundError)):
-        testing_utils.run_gpu_model_test(trainer_options, model, hparams)
+        tutils.run_model_test(trainer_options, model, hparams)


 def test_amp_gpu_ddp():
@@ -66,13 +66,13 @@ def test_amp_gpu_ddp():
     Make sure DDP + AMP work
     :return:
     """
-    if not testing_utils.can_run_gpu_test():
+    if not tutils.can_run_gpu_test():
         return

-    testing_utils.reset_seed()
-    testing_utils.set_random_master_port()
+    tutils.reset_seed()
+    tutils.set_random_master_port()

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

     trainer_options = dict(
@@ -83,7 +83,7 @@ def test_amp_gpu_ddp():
         use_amp=True
     )

-    testing_utils.run_gpu_model_test(trainer_options, model, hparams)
+    tutils.run_model_test(trainer_options, model, hparams)


 def test_amp_gpu_ddp_slurm_managed():
@@ -91,16 +91,16 @@ def test_amp_gpu_ddp_slurm_managed():
     Make sure DDP + AMP work
     :return:
     """
-    if not testing_utils.can_run_gpu_test():
+    if not tutils.can_run_gpu_test():
         return

-    testing_utils.reset_seed()
+    tutils.reset_seed()

     # simulate setting slurm flags
-    testing_utils.set_random_master_port()
+    tutils.set_random_master_port()
     os.environ['SLURM_LOCALID'] = str(0)

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

     trainer_options = dict(
@@ -111,13 +111,13 @@ def test_amp_gpu_ddp_slurm_managed():
         use_amp=True
     )

-    save_dir = testing_utils.init_save_dir()
+    save_dir = tutils.init_save_dir()

     # exp file to get meta
-    logger = testing_utils.get_test_tube_logger(False)
+    logger = tutils.get_test_tube_logger(False)

     # exp file to get weights
-    checkpoint = testing_utils.init_checkpoint_callback(logger)
+    checkpoint = tutils.init_checkpoint_callback(logger)

     # add these to the trainer options
     trainer_options['checkpoint_callback'] = checkpoint
@@ -138,12 +138,11 @@ def test_amp_gpu_ddp_slurm_managed():
     assert trainer.resolve_root_node_address('abc[23-24, 45-40, 40]') == 'abc23'

     # test model loading with a map_location
-    pretrained_model = testing_utils.load_model(logger.experiment,
-                                                trainer.checkpoint_callback.filepath)
+    pretrained_model = tutils.load_model(logger.experiment, trainer.checkpoint_callback.filepath)

     # test model preds
     for dataloader in trainer.get_test_dataloaders():
-        testing_utils.run_prediction(dataloader, pretrained_model)
+        tutils.run_prediction(dataloader, pretrained_model)

     if trainer.use_ddp:
         # on hpc this would work fine... but need to hack it for the purpose of the test
@@ -158,7 +157,7 @@ def test_amp_gpu_ddp_slurm_managed():
     model.freeze()
     model.unfreeze()

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_cpu_model_with_amp():
@@ -166,21 +165,21 @@ def test_cpu_model_with_amp():
     Make sure model trains on CPU
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

     trainer_options = dict(
         show_progress_bar=False,
-        logger=testing_utils.get_test_tube_logger(),
+        logger=tutils.get_test_tube_logger(),
         max_nb_epochs=1,
         train_percent_check=0.4,
         val_percent_check=0.4,
         use_amp=True
     )

-    model, hparams = testing_utils.get_model()
+    model, hparams = tutils.get_model()

     with pytest.raises((MisconfigurationException, ModuleNotFoundError)):
-        testing_utils.run_gpu_model_test(trainer_options, model, hparams, on_gpu=False)
+        tutils.run_model_test(trainer_options, model, hparams, on_gpu=False)


 def test_amp_gpu_dp():
@@ -188,12 +187,12 @@ def test_amp_gpu_dp():
     Make sure DP + AMP work
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

-    if not testing_utils.can_run_gpu_test():
+    if not tutils.can_run_gpu_test():
         return

-    model, hparams = testing_utils.get_model()
+    model, hparams = tutils.get_model()
     trainer_options = dict(
         max_nb_epochs=1,
         gpus='0, 1',  # test init with gpu string
@@ -201,7 +200,7 @@ def test_amp_gpu_dp():
         use_amp=True
     )
     with pytest.raises(MisconfigurationException):
-        testing_utils.run_gpu_model_test(trainer_options, model, hparams)
+        tutils.run_model_test(trainer_options, model, hparams)


 if __name__ == '__main__':
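Note: the change in every test module is mechanical: the relative `from . import testing_utils` becomes an absolute `import tests.utils as tutils`, and the misnamed `run_gpu_model_test` helper becomes the backend-agnostic `run_model_test`. A condensed sketch of the resulting test skeleton (the test name is illustrative and the trainer options are abbreviated):

import tests.utils as tutils


def test_cpu_skeleton():
    # seed torch/numpy through the shared helper, then build a small model
    tutils.reset_seed()
    model, hparams = tutils.get_model()

    trainer_options = dict(
        max_nb_epochs=1,
        train_percent_check=0.4,
        val_percent_check=0.4,
    )

    # fit the model and run the standard accuracy checks without a GPU
    tutils.run_model_test(trainer_options, model, hparams, on_gpu=False)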
@@ -12,7 +12,7 @@ from pytorch_lightning.testing import (
     LightningTestModelBase,
     LightningTestMixin,
 )
-from . import testing_utils
+import tests.utils as tutils


 def test_early_stopping_cpu_model():
@@ -20,9 +20,9 @@ def test_early_stopping_cpu_model():
     Test each of the trainer options
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

-    stopping = EarlyStopping(monitor='val_loss')
+    stopping = EarlyStopping(monitor='val_loss', min_delta=0.1)
     trainer_options = dict(
         early_stop_callback=stopping,
         gradient_clip_val=1.0,
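Note: passing min_delta=0.1 tells the early-stopping callback to treat val_loss improvements smaller than 0.1 as no improvement, so these short CPU runs terminate earlier. A minimal wiring sketch (import path and epoch budget as used elsewhere in these tests; treat both as assumptions):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# stop training once val_loss improves by less than 0.1 between checks
stopping = EarlyStopping(monitor='val_loss', min_delta=0.1)
trainer = Trainer(early_stop_callback=stopping, max_nb_epochs=2)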
@@ -30,13 +30,13 @@ def test_early_stopping_cpu_model():
         track_grad_norm=2,
         print_nan_grads=True,
         show_progress_bar=True,
-        logger=testing_utils.get_test_tube_logger(),
+        logger=tutils.get_test_tube_logger(),
         train_percent_check=0.1,
         val_percent_check=0.1
     )

-    model, hparams = testing_utils.get_model()
-    testing_utils.run_gpu_model_test(trainer_options, model, hparams, on_gpu=False)
+    model, hparams = tutils.get_model()
+    tutils.run_model_test(trainer_options, model, hparams, on_gpu=False)

     # test freeze on cpu
     model.freeze()
@@ -48,7 +48,7 @@ def test_lbfgs_cpu_model():
     Test each of the trainer options
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

     trainer_options = dict(
         max_nb_epochs=1,
@@ -59,11 +59,11 @@ def test_lbfgs_cpu_model():
         val_percent_check=0.2
     )

-    model, hparams = testing_utils.get_model(use_test_model=True, lbfgs=True)
-    testing_utils.run_model_test_no_loggers(trainer_options,
-                                            model, hparams, on_gpu=False, min_acc=0.30)
+    model, hparams = tutils.get_model(use_test_model=True, lbfgs=True)
+    tutils.run_model_test_no_loggers(trainer_options, model, hparams,
+                                     on_gpu=False, min_acc=0.30)

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_default_logger_callbacks_cpu_model():
@@ -71,7 +71,7 @@ def test_default_logger_callbacks_cpu_model():
     Test each of the trainer options
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

     trainer_options = dict(
         max_nb_epochs=1,
@@ -83,30 +83,30 @@ def test_default_logger_callbacks_cpu_model():
         val_percent_check=0.01
     )

-    model, hparams = testing_utils.get_model()
-    testing_utils.run_model_test_no_loggers(trainer_options, model, hparams, on_gpu=False)
+    model, hparams = tutils.get_model()
+    tutils.run_model_test_no_loggers(trainer_options, model, hparams, on_gpu=False)

     # test freeze on cpu
     model.freeze()
     model.unfreeze()

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_running_test_after_fitting():
     """Verify test() on fitted model"""
-    testing_utils.reset_seed()
+    tutils.reset_seed()

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

-    save_dir = testing_utils.init_save_dir()
+    save_dir = tutils.init_save_dir()

     # logger file to get meta
-    logger = testing_utils.get_test_tube_logger(False)
+    logger = tutils.get_test_tube_logger(False)

     # logger file to get weights
-    checkpoint = testing_utils.init_checkpoint_callback(logger)
+    checkpoint = tutils.init_checkpoint_callback(logger)

     trainer_options = dict(
         show_progress_bar=False,
@@ -127,29 +127,29 @@ def test_running_test_after_fitting():
     trainer.test()

     # test we have good test accuracy
-    testing_utils.assert_ok_test_acc(trainer)
+    tutils.assert_ok_test_acc(trainer)

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_running_test_without_val():
-    testing_utils.reset_seed()
+    tutils.reset_seed()

     """Verify test() works on a model with no val_loader"""

     class CurrentTestModel(LightningTestMixin, LightningTestModelBase):
         pass

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = CurrentTestModel(hparams)

-    save_dir = testing_utils.init_save_dir()
+    save_dir = tutils.init_save_dir()

     # logger file to get meta
-    logger = testing_utils.get_test_tube_logger(False)
+    logger = tutils.get_test_tube_logger(False)

     # logger file to get weights
-    checkpoint = testing_utils.init_checkpoint_callback(logger)
+    checkpoint = tutils.init_checkpoint_callback(logger)

     trainer_options = dict(
         show_progress_bar=False,
@@ -170,15 +170,15 @@ def test_running_test_without_val():
     trainer.test()

     # test we have good test accuracy
-    testing_utils.assert_ok_test_acc(trainer)
+    tutils.assert_ok_test_acc(trainer)

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_single_gpu_batch_parse():
-    testing_utils.reset_seed()
+    tutils.reset_seed()

-    if not testing_utils.can_run_gpu_test():
+    if not tutils.can_run_gpu_test():
         return

     trainer = Trainer()
@@ -224,12 +224,12 @@ def test_simple_cpu():
     Verify continue training session on CPU
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

-    save_dir = testing_utils.init_save_dir()
+    save_dir = tutils.init_save_dir()

     # logger file to get meta
     trainer_options = dict(
@@ -245,7 +245,7 @@ def test_simple_cpu():
     # traning complete
     assert result == 1, 'amp + ddp model failed to complete'

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_cpu_model():
@@ -253,19 +253,19 @@ def test_cpu_model():
     Make sure model trains on CPU
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

     trainer_options = dict(
         show_progress_bar=False,
-        logger=testing_utils.get_test_tube_logger(),
+        logger=tutils.get_test_tube_logger(),
         max_nb_epochs=1,
         train_percent_check=0.4,
         val_percent_check=0.4
     )

-    model, hparams = testing_utils.get_model()
+    model, hparams = tutils.get_model()

-    testing_utils.run_gpu_model_test(trainer_options, model, hparams, on_gpu=False)
+    tutils.run_model_test(trainer_options, model, hparams, on_gpu=False)


 def test_all_features_cpu_model():
@@ -273,7 +273,7 @@ def test_all_features_cpu_model():
     Test each of the trainer options
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

     trainer_options = dict(
         gradient_clip_val=1.0,
@@ -281,15 +281,15 @@ def test_all_features_cpu_model():
         track_grad_norm=2,
         print_nan_grads=True,
         show_progress_bar=False,
-        logger=testing_utils.get_test_tube_logger(),
+        logger=tutils.get_test_tube_logger(),
         accumulate_grad_batches=2,
         max_nb_epochs=1,
         train_percent_check=0.4,
         val_percent_check=0.4
     )

-    model, hparams = testing_utils.get_model()
-    testing_utils.run_gpu_model_test(trainer_options, model, hparams, on_gpu=False)
+    model, hparams = tutils.get_model()
+    tutils.run_model_test(trainer_options, model, hparams, on_gpu=False)


 def test_tbptt_cpu_model():
@@ -297,9 +297,9 @@ def test_tbptt_cpu_model():
     Test truncated back propagation through time works.
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

-    save_dir = testing_utils.init_save_dir()
+    save_dir = tutils.init_save_dir()

     truncated_bptt_steps = 2
     sequence_size = 30
@@ -354,7 +354,7 @@ def test_tbptt_cpu_model():
         weights_summary=None,
     )

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     hparams.batch_size = batch_size
     hparams.in_features = truncated_bptt_steps
     hparams.hidden_dim = truncated_bptt_steps
@@ -368,7 +368,7 @@ def test_tbptt_cpu_model():

     assert result == 1, 'training failed to complete'

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_single_gpu_model():
@@ -376,13 +376,13 @@ def test_single_gpu_model():
     Make sure single GPU works (DP mode)
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

     if not torch.cuda.is_available():
         warnings.warn('test_single_gpu_model cannot run.'
                       ' Rerun on a GPU node to run this test')
         return
-    model, hparams = testing_utils.get_model()
+    model, hparams = tutils.get_model()

     trainer_options = dict(
         show_progress_bar=False,
@@ -392,7 +392,7 @@ def test_single_gpu_model():
         gpus=1
     )

-    testing_utils.run_gpu_model_test(trainer_options, model, hparams)
+    tutils.run_model_test(trainer_options, model, hparams)


 if __name__ == '__main__':
@@ -15,7 +15,7 @@ from pytorch_lightning.trainer.dp_mixin import (
     determine_root_gpu_device,
 )
 from pytorch_lightning.utilities.debugging import MisconfigurationException
-from . import testing_utils
+import tests.utils as tutils

 PRETEND_N_OF_GPUS = 16

@@ -25,13 +25,13 @@ def test_multi_gpu_model_ddp2():
     Make sure DDP2 works
     :return:
     """
-    if not testing_utils.can_run_gpu_test():
+    if not tutils.can_run_gpu_test():
         return

-    testing_utils.reset_seed()
-    testing_utils.set_random_master_port()
+    tutils.reset_seed()
+    tutils.set_random_master_port()

-    model, hparams = testing_utils.get_model()
+    model, hparams = tutils.get_model()
     trainer_options = dict(
         show_progress_bar=True,
         max_nb_epochs=1,
@@ -42,7 +42,7 @@ def test_multi_gpu_model_ddp2():
         distributed_backend='ddp2'
     )

-    testing_utils.run_gpu_model_test(trainer_options, model, hparams)
+    tutils.run_model_test(trainer_options, model, hparams)


 def test_multi_gpu_model_ddp():
@@ -50,13 +50,13 @@ def test_multi_gpu_model_ddp():
     Make sure DDP works
     :return:
     """
-    if not testing_utils.can_run_gpu_test():
+    if not tutils.can_run_gpu_test():
         return

-    testing_utils.reset_seed()
-    testing_utils.set_random_master_port()
+    tutils.reset_seed()
+    tutils.set_random_master_port()

-    model, hparams = testing_utils.get_model()
+    model, hparams = tutils.get_model()
     trainer_options = dict(
         show_progress_bar=False,
         max_nb_epochs=1,
@@ -66,14 +66,14 @@ def test_multi_gpu_model_ddp():
         distributed_backend='ddp'
     )

-    testing_utils.run_gpu_model_test(trainer_options, model, hparams)
+    tutils.run_model_test(trainer_options, model, hparams)


 def test_optimizer_return_options():
-    testing_utils.reset_seed()
+    tutils.reset_seed()

     trainer = Trainer()
-    model, hparams = testing_utils.get_model()
+    model, hparams = tutils.get_model()

     # single optimizer
     opt_a = torch.optim.Adam(model.parameters(), lr=0.002)
@@ -105,15 +105,15 @@ def test_cpu_slurm_save_load():
     Verify model save/load/checkpoint on CPU
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

-    save_dir = testing_utils.init_save_dir()
+    save_dir = tutils.init_save_dir()

     # logger file to get meta
-    logger = testing_utils.get_test_tube_logger(False)
+    logger = tutils.get_test_tube_logger(False)

     version = logger.version

@@ -149,7 +149,7 @@ def test_cpu_slurm_save_load():
     assert os.path.exists(saved_filepath)

     # new logger file to get meta
-    logger = testing_utils.get_test_tube_logger(False, version=version)
+    logger = tutils.get_test_tube_logger(False, version=version)

     trainer_options = dict(
         max_nb_epochs=1,
@@ -174,7 +174,7 @@ def test_cpu_slurm_save_load():
     # and our hook to predict using current model before any more weight updates
     trainer.fit(model)

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_multi_gpu_none_backend():
@@ -183,12 +183,12 @@ def test_multi_gpu_none_backend():
     distributed_backend = None
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

-    if not testing_utils.can_run_gpu_test():
+    if not tutils.can_run_gpu_test():
         return

-    model, hparams = testing_utils.get_model()
+    model, hparams = tutils.get_model()
     trainer_options = dict(
         show_progress_bar=False,
         max_nb_epochs=1,
@@ -198,7 +198,7 @@ def test_multi_gpu_none_backend():
     )

     with pytest.raises(MisconfigurationException):
-        testing_utils.run_gpu_model_test(trainer_options, model, hparams)
+        tutils.run_model_test(trainer_options, model, hparams)


 def test_multi_gpu_model_dp():
@@ -206,12 +206,12 @@ def test_multi_gpu_model_dp():
     Make sure DP works
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

-    if not testing_utils.can_run_gpu_test():
+    if not tutils.can_run_gpu_test():
         return

-    model, hparams = testing_utils.get_model()
+    model, hparams = tutils.get_model()
     trainer_options = dict(
         show_progress_bar=False,
         distributed_backend='dp',
@@ -221,7 +221,7 @@ def test_multi_gpu_model_dp():
         gpus='-1'
     )

-    testing_utils.run_gpu_model_test(trainer_options, model, hparams)
+    tutils.run_model_test(trainer_options, model, hparams)

     # test memory helper functions
     memory.get_memory_profile('min_max')
@@ -232,16 +232,16 @@ def test_ddp_sampler_error():
     Make sure DDP + AMP work
     :return:
     """
-    if not testing_utils.can_run_gpu_test():
+    if not tutils.can_run_gpu_test():
         return

-    testing_utils.reset_seed()
-    testing_utils.set_random_master_port()
+    tutils.reset_seed()
+    tutils.set_random_master_port()

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams, force_remove_distributed_sampler=True)

-    logger = testing_utils.get_test_tube_logger(True)
+    logger = tutils.get_test_tube_logger(True)

     trainer = Trainer(
         logger=logger,
@@ -255,7 +255,7 @@ def test_ddp_sampler_error():
     with pytest.warns(UserWarning):
         trainer.get_dataloaders(model)

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 @pytest.fixture
@@ -7,26 +7,20 @@ import torch
 from pytorch_lightning import Trainer
 from pytorch_lightning.testing import LightningTestModel
 from pytorch_lightning.logging import LightningLoggerBase, rank_zero_only
-from . import testing_utils
-
-RANDOM_FILE_PATHS = list(np.random.randint(12000, 19000, 1000))
-ROOT_SEED = 1234
-torch.manual_seed(ROOT_SEED)
-np.random.seed(ROOT_SEED)
-RANDOM_SEEDS = list(np.random.randint(0, 10000, 1000))
+import tests.utils as tutils


 def test_testtube_logger():
     """
     verify that basic functionality of test tube logger works
     """
-    reset_seed()
-    hparams = testing_utils.get_hparams()
+    tutils.reset_seed()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

-    save_dir = testing_utils.init_save_dir()
+    save_dir = tutils.init_save_dir()

-    logger = testing_utils.get_test_tube_logger(False)
+    logger = tutils.get_test_tube_logger(False)

     trainer_options = dict(
         max_nb_epochs=1,
@@ -39,21 +33,21 @@ def test_testtube_logger():

     assert result == 1, "Training failed"

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_testtube_pickle():
     """
     Verify that pickling a trainer containing a test tube logger works
     """
-    reset_seed()
+    tutils.reset_seed()

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

-    save_dir = testing_utils.init_save_dir()
+    save_dir = tutils.init_save_dir()

-    logger = testing_utils.get_test_tube_logger(False)
+    logger = tutils.get_test_tube_logger(False)
     logger.log_hyperparams(hparams)
     logger.save()

@@ -68,21 +62,21 @@ def test_testtube_pickle():
     trainer2 = pickle.loads(pkl_bytes)
     trainer2.logger.log_metrics({"acc": 1.0})

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_mlflow_logger():
     """
     verify that basic functionality of mlflow logger works
     """
-    reset_seed()
+    tutils.reset_seed()

     try:
         from pytorch_lightning.logging import MLFlowLogger
     except ModuleNotFoundError:
         return

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

     root_dir = os.path.dirname(os.path.realpath(__file__))
@@ -102,21 +96,21 @@ def test_mlflow_logger():
     print('result finished')
     assert result == 1, "Training failed"

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_mlflow_pickle():
     """
     verify that pickling trainer with mlflow logger works
     """
-    reset_seed()
+    tutils.reset_seed()

     try:
         from pytorch_lightning.logging import MLFlowLogger
     except ModuleNotFoundError:
         return

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

     root_dir = os.path.dirname(os.path.realpath(__file__))
@@ -134,21 +128,21 @@ def test_mlflow_pickle():
     trainer2 = pickle.loads(pkl_bytes)
     trainer2.logger.log_metrics({"acc": 1.0})

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_comet_logger():
     """
     verify that basic functionality of Comet.ml logger works
     """
-    reset_seed()
+    tutils.reset_seed()

     try:
         from pytorch_lightning.logging import CometLogger
     except ModuleNotFoundError:
         return

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

     root_dir = os.path.dirname(os.path.realpath(__file__))
@@ -173,21 +167,21 @@ def test_comet_logger():
     print('result finished')
     assert result == 1, "Training failed"

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_comet_pickle():
     """
     verify that pickling trainer with comet logger works
     """
-    reset_seed()
+    tutils.reset_seed()

     try:
         from pytorch_lightning.logging import CometLogger
     except ModuleNotFoundError:
         return

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

     root_dir = os.path.dirname(os.path.realpath(__file__))
@@ -210,7 +204,7 @@ def test_comet_pickle():
     trainer2 = pickle.loads(pkl_bytes)
     trainer2.logger.log_metrics({"acc": 1.0})

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_custom_logger(tmpdir):
@@ -241,7 +235,7 @@ def test_custom_logger(tmpdir):
     def version(self):
         return "1"

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

     logger = CustomLogger()
@@ -259,9 +253,3 @@ def test_custom_logger(tmpdir):
     assert logger.hparams_logged == hparams
     assert logger.metrics_logged != {}
     assert logger.finalized_status == "success"
-
-
-def reset_seed():
-    SEED = RANDOM_SEEDS.pop()
-    torch.manual_seed(SEED)
-    np.random.seed(SEED)
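Note: the module-level seed bookkeeping deleted above now lives in the shared tests/utils module, so every test file calls tutils.reset_seed() instead of a local helper. Judging from the removed code, that helper presumably looks roughly like this (a sketch, not the literal implementation):

import numpy as np
import torch

# derive a pool of per-test seeds from one fixed root seed
ROOT_SEED = 1234
np.random.seed(ROOT_SEED)
RANDOM_SEEDS = list(np.random.randint(0, 10000, 1000))


def reset_seed():
    seed = RANDOM_SEEDS.pop()
    torch.manual_seed(seed)
    np.random.seed(seed)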
@@ -7,27 +7,27 @@ import torch
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.testing import LightningTestModel
-from . import testing_utils
+import tests.utils as tutils


 def test_running_test_pretrained_model_ddp():
     """Verify test() on pretrained model"""
-    if not testing_utils.can_run_gpu_test():
+    if not tutils.can_run_gpu_test():
         return

-    testing_utils.reset_seed()
-    testing_utils.set_random_master_port()
+    tutils.reset_seed()
+    tutils.set_random_master_port()

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

-    save_dir = testing_utils.init_save_dir()
+    save_dir = tutils.init_save_dir()

     # exp file to get meta
-    logger = testing_utils.get_test_tube_logger(False)
+    logger = tutils.get_test_tube_logger(False)

     # exp file to get weights
-    checkpoint = testing_utils.init_checkpoint_callback(logger)
+    checkpoint = tutils.init_checkpoint_callback(logger)

     trainer_options = dict(
         show_progress_bar=False,
@@ -49,38 +49,38 @@ def test_running_test_pretrained_model_ddp():

     # correct result and ok accuracy
     assert result == 1, 'training failed to complete'
-    pretrained_model = testing_utils.load_model(logger.experiment,
-                                                trainer.checkpoint_callback.filepath,
-                                                module_class=LightningTestModel)
+    pretrained_model = tutils.load_model(logger.experiment,
+                                         trainer.checkpoint_callback.filepath,
+                                         module_class=LightningTestModel)

     # run test set
     new_trainer = Trainer(**trainer_options)
     new_trainer.test(pretrained_model)

     for dataloader in model.test_dataloader():
-        testing_utils.run_prediction(dataloader, pretrained_model)
+        tutils.run_prediction(dataloader, pretrained_model)

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_running_test_pretrained_model():
-    testing_utils.reset_seed()
+    tutils.reset_seed()

     """Verify test() on pretrained model"""
-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

-    save_dir = testing_utils.init_save_dir()
+    save_dir = tutils.init_save_dir()

     # logger file to get meta
-    logger = testing_utils.get_test_tube_logger(False)
+    logger = tutils.get_test_tube_logger(False)

     # logger file to get weights
-    checkpoint = testing_utils.init_checkpoint_callback(logger)
+    checkpoint = tutils.init_checkpoint_callback(logger)

     trainer_options = dict(
         show_progress_bar=False,
-        max_nb_epochs=1,
+        max_nb_epochs=4,
         train_percent_check=0.4,
         val_percent_check=0.2,
         checkpoint_callback=checkpoint,
@@ -93,7 +93,7 @@ def test_running_test_pretrained_model():

     # correct result and ok accuracy
     assert result == 1, 'training failed to complete'
-    pretrained_model = testing_utils.load_model(
+    pretrained_model = tutils.load_model(
         logger.experiment, trainer.checkpoint_callback.filepath, module_class=LightningTestModel
     )

@@ -101,18 +101,18 @@ def test_running_test_pretrained_model():
     new_trainer.test(pretrained_model)

     # test we have good test accuracy
-    testing_utils.assert_ok_test_acc(new_trainer)
-    testing_utils.clear_save_dir()
+    tutils.assert_ok_test_acc(new_trainer)
+    tutils.clear_save_dir()


 def test_load_model_from_checkpoint():
-    testing_utils.reset_seed()
+    tutils.reset_seed()

     """Verify test() on pretrained model"""
-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

-    save_dir = testing_utils.init_save_dir()
+    save_dir = tutils.init_save_dir()

     trainer_options = dict(
         show_progress_bar=False,
@@ -142,27 +142,27 @@ def test_load_model_from_checkpoint():
     new_trainer.test(pretrained_model)

     # test we have good test accuracy
-    testing_utils.assert_ok_test_acc(new_trainer)
-    testing_utils.clear_save_dir()
+    tutils.assert_ok_test_acc(new_trainer)
+    tutils.clear_save_dir()


 def test_running_test_pretrained_model_dp():
-    testing_utils.reset_seed()
+    tutils.reset_seed()

     """Verify test() on pretrained model"""
-    if not testing_utils.can_run_gpu_test():
+    if not tutils.can_run_gpu_test():
         return

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

-    save_dir = testing_utils.init_save_dir()
+    save_dir = tutils.init_save_dir()

     # logger file to get meta
-    logger = testing_utils.get_test_tube_logger(False)
+    logger = tutils.get_test_tube_logger(False)

     # logger file to get weights
-    checkpoint = testing_utils.init_checkpoint_callback(logger)
+    checkpoint = tutils.init_checkpoint_callback(logger)

     trainer_options = dict(
         show_progress_bar=True,
@@ -181,16 +181,16 @@ def test_running_test_pretrained_model_dp():

     # correct result and ok accuracy
     assert result == 1, 'training failed to complete'
-    pretrained_model = testing_utils.load_model(logger.experiment,
-                                                trainer.checkpoint_callback.filepath,
-                                                module_class=LightningTestModel)
+    pretrained_model = tutils.load_model(logger.experiment,
+                                         trainer.checkpoint_callback.filepath,
+                                         module_class=LightningTestModel)

     new_trainer = Trainer(**trainer_options)
     new_trainer.test(pretrained_model)

     # test we have good test accuracy
-    testing_utils.assert_ok_test_acc(new_trainer)
-    testing_utils.clear_save_dir()
+    tutils.assert_ok_test_acc(new_trainer)
+    tutils.clear_save_dir()


 def test_dp_resume():
@@ -198,12 +198,12 @@ def test_dp_resume():
     Make sure DP continues training correctly
     :return:
     """
-    if not testing_utils.can_run_gpu_test():
+    if not tutils.can_run_gpu_test():
         return

-    testing_utils.reset_seed()
+    tutils.reset_seed()

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

     trainer_options = dict(
@@ -213,14 +213,14 @@ def test_dp_resume():
         distributed_backend='dp',
     )

-    save_dir = testing_utils.init_save_dir()
+    save_dir = tutils.init_save_dir()

     # get logger
-    logger = testing_utils.get_test_tube_logger(debug=False)
+    logger = tutils.get_test_tube_logger(debug=False)

-    # exp file to get weights
-    checkpoint = testing_utils.init_checkpoint_callback(logger)
+    # logger file to get weights
+    checkpoint = tutils.init_checkpoint_callback(logger)

     # add these to the trainer options
     trainer_options['logger'] = logger
@@ -244,7 +244,7 @@ def test_dp_resume():
     trainer.hpc_save(save_dir, logger)

     # init new trainer
-    new_logger = testing_utils.get_test_tube_logger(version=logger.version)
+    new_logger = tutils.get_test_tube_logger(version=logger.version)
     trainer_options['logger'] = new_logger
     trainer_options['checkpoint_callback'] = ModelCheckpoint(save_dir)
     trainer_options['train_percent_check'] = 0.2
@@ -262,7 +262,7 @@ def test_dp_resume():
         dp_model.eval()

         dataloader = trainer.get_train_dataloader()
-        testing_utils.run_prediction(dataloader, dp_model, dp=True)
+        tutils.run_prediction(dataloader, dp_model, dp=True)

     # new model
     model = LightningTestModel(hparams)
@@ -275,7 +275,7 @@ def test_dp_resume():
     model.freeze()
     model.unfreeze()

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_cpu_restore_training():
@@ -283,16 +283,16 @@ def test_cpu_restore_training():
     Verify continue training session on CPU
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

-    save_dir = testing_utils.init_save_dir()
+    save_dir = tutils.init_save_dir()

     # logger file to get meta
     test_logger_version = 10
-    logger = testing_utils.get_test_tube_logger(False, version=test_logger_version)
+    logger = tutils.get_test_tube_logger(False, version=test_logger_version)

     trainer_options = dict(
         max_nb_epochs=2,
@@ -314,7 +314,7 @@ def test_cpu_restore_training():
     # wipe-out trainer and model
     # retrain with not much data... this simulates picking training back up after slurm
     # we want to see if the weights come back correctly
-    new_logger = testing_utils.get_test_tube_logger(False, version=test_logger_version)
+    new_logger = tutils.get_test_tube_logger(False, version=test_logger_version)
     trainer_options = dict(
         max_nb_epochs=2,
         val_check_interval=0.50,
@@ -335,7 +335,7 @@ def test_cpu_restore_training():
         # haven't trained with the new loaded model
         trainer.model.eval()
         for dataloader in trainer.get_val_dataloaders():
-            testing_utils.run_prediction(dataloader, trainer.model)
+            tutils.run_prediction(dataloader, trainer.model)

     model.on_sanity_check_start = assert_good_acc

@@ -343,7 +343,7 @@ def test_cpu_restore_training():
     # and our hook to predict using current model before any more weight updates
     trainer.fit(model)

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_model_saving_loading():
@@ -351,15 +351,15 @@ def test_model_saving_loading():
     Tests use case where trainer saves the model, and user loads it from tags independently
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

-    save_dir = testing_utils.init_save_dir()
+    save_dir = tutils.init_save_dir()

     # logger file to get meta
-    logger = testing_utils.get_test_tube_logger(False)
+    logger = tutils.get_test_tube_logger(False)

     trainer_options = dict(
         max_nb_epochs=1,
@@ -402,7 +402,7 @@ def test_model_saving_loading():
     new_pred = model_2(x)
     assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 if __name__ == '__main__':
@@ -16,7 +16,7 @@ from pytorch_lightning.testing import (
 )
 from pytorch_lightning.trainer import trainer_io
 from pytorch_lightning.trainer.logging_mixin import TrainerLoggingMixin
-from . import testing_utils
+import tests.utils as tutils


 def test_no_val_module():
@@ -24,19 +24,19 @@ def test_no_val_module():
     Tests use case where trainer saves the model, and user loads it from tags independently
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()

     class CurrentTestModel(LightningTestModelBase):
         pass

     model = CurrentTestModel(hparams)

-    save_dir = testing_utils.init_save_dir()
+    save_dir = tutils.init_save_dir()

     # logger file to get meta
-    logger = testing_utils.get_test_tube_logger(False)
+    logger = tutils.get_test_tube_logger(False)

     trainer_options = dict(
         max_nb_epochs=1,
@@ -63,7 +63,7 @@ def test_no_val_module():
     model_2.eval()

     # make prediction
-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_no_val_end_module():
@@ -71,18 +71,18 @@ def test_no_val_end_module():
     Tests use case where trainer saves the model, and user loads it from tags independently
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

     class CurrentTestModel(LightningValidationStepMixin, LightningTestModelBase):
         pass

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = CurrentTestModel(hparams)

-    save_dir = testing_utils.init_save_dir()
+    save_dir = tutils.init_save_dir()

     # logger file to get meta
-    logger = testing_utils.get_test_tube_logger(False)
+    logger = tutils.get_test_tube_logger(False)

     trainer_options = dict(
         max_nb_epochs=1,
@@ -109,11 +109,11 @@ def test_no_val_end_module():
     model_2.eval()

     # make prediction
-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_gradient_accumulation_scheduling():
-    testing_utils.reset_seed()
+    tutils.reset_seed()

     """
     Test grad accumulation by the freq of optimizer updates
@@ -170,7 +170,7 @@ def test_gradient_accumulation_scheduling():
         # clear gradients
         optimizer.zero_grad()

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)
     schedule = {1: 2, 3: 4}

@@ -187,13 +187,13 @@ def test_gradient_accumulation_scheduling():


 def test_loading_meta_tags():
-    testing_utils.reset_seed()
+    tutils.reset_seed()

     from argparse import Namespace
-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()

     # save tags
-    logger = testing_utils.get_test_tube_logger(False)
+    logger = tutils.get_test_tube_logger(False)
     logger.log_hyperparams(Namespace(some_str='a_str', an_int=1, a_float=2.0))
     logger.log_hyperparams(hparams)
     logger.save()
@@ -206,12 +206,12 @@ def test_loading_meta_tags():

     assert tags.batch_size == 32 and tags.hidden_dim == 1000

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_dp_output_reduce():
     mixin = TrainerLoggingMixin()
-    testing_utils.reset_seed()
+    tutils.reset_seed()

     # test identity when we have a single gpu
     out = torch.rand(3, 1)
@@ -240,11 +240,11 @@ def test_model_checkpoint_options():
     def mock_save_function(filepath):
         open(filepath, 'a').close()

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

     # simulated losses
-    save_dir = testing_utils.init_save_dir()
+    save_dir = tutils.init_save_dir()
     losses = [10, 9, 2.8, 5, 2.5]

     # -----------------
@@ -262,7 +262,7 @@ def test_model_checkpoint_options():
     for i in range(0, len(losses)):
         assert f'_ckpt_epoch_{i}.ckpt' in file_lists

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()

     # -----------------
     # CASE K=0 (none)
@@ -275,7 +275,7 @@ def test_model_checkpoint_options():

     assert len(file_lists) == 0, "Should save 0 models when save_top_k=0"

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()

     # -----------------
     # CASE K=1 (2.5, epoch 4)
@@ -289,7 +289,7 @@ def test_model_checkpoint_options():
     assert len(file_lists) == 1, "Should save 1 model when save_top_k=1"
     assert 'test_prefix_ckpt_epoch_4.ckpt' in file_lists

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()

     # -----------------
     # CASE K=2 (2.5 epoch 4, 2.8 epoch 2)
@@ -308,7 +308,7 @@ def test_model_checkpoint_options():
     assert '_ckpt_epoch_2.ckpt' in file_lists
     assert 'other_file.ckpt' in file_lists

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()

     # -----------------
     # CASE K=4 (save all 4 models)
@@ -323,7 +323,7 @@ def test_model_checkpoint_options():

     assert len(file_lists) == 4, 'Should save all 4 models when save_top_k=4 within same epoch'

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()

     # -----------------
     # CASE K=3 (save the 2nd, 3rd, 4th model)
@@ -341,13 +341,13 @@ def test_model_checkpoint_options():
     assert '_ckpt_epoch_0_v1.ckpt' in file_lists
     assert '_ckpt_epoch_0.ckpt' in file_lists

-    testing_utils.clear_save_dir()
+    tutils.clear_save_dir()


 def test_model_freeze_unfreeze():
-    testing_utils.reset_seed()
+    tutils.reset_seed()

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)

     model.freeze()
@@ -359,7 +359,7 @@ def test_multiple_val_dataloader():
     Verify multiple val_dataloader
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

     class CurrentTestModel(
         LightningValidationMultipleDataloadersMixin,
@@ -367,7 +367,7 @@ def test_multiple_val_dataloader():
     ):
         pass

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = CurrentTestModel(hparams)

     # logger file to get meta
@@ -390,7 +390,7 @@ def test_multiple_val_dataloader():

     # make sure predictions are good for each val set
     for dataloader in trainer.get_val_dataloaders():
-        testing_utils.run_prediction(dataloader, trainer.model)
+        tutils.run_prediction(dataloader, trainer.model)


 def test_multiple_test_dataloader():
@@ -398,7 +398,7 @@ def test_multiple_test_dataloader():
     Verify multiple test_dataloader
     :return:
     """
-    testing_utils.reset_seed()
+    tutils.reset_seed()

     class CurrentTestModel(
         LightningTestMultipleDataloadersMixin,
@@ -406,7 +406,7 @@ def test_multiple_test_dataloader():
     ):
         pass

-    hparams = testing_utils.get_hparams()
+    hparams = tutils.get_hparams()
     model = CurrentTestModel(hparams)

     # logger file to get meta
@@ -426,7 +426,7 @@ def test_multiple_test_dataloader():

     # make sure predictions are good for each test set
     for dataloader in trainer.get_test_dataloaders():
-        testing_utils.run_prediction(dataloader, trainer.model)
+        tutils.run_prediction(dataloader, trainer.model)

     # run the test method
     trainer.test()
@@ -52,7 +52,7 @@ def run_model_test_no_loggers(trainer_options, model, hparams, on_gpu=True, min_
     clear_save_dir()


-def run_gpu_model_test(trainer_options, model, hparams, on_gpu=True):
+def run_model_test(trainer_options, model, hparams, on_gpu=True):
     save_dir = init_save_dir()

     # logger file to get meta