commit 56d41eaa8c
@@ -86,7 +86,7 @@ class TrainerIO(object):
     # --------------------
     # HPC IO
     # --------------------
-    def enable_auto_hpc_walltime_manager(self): # pragma: no cover
+    def enable_auto_hpc_walltime_manager(self):
         if self.cluster is None:
             return
 
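Dropping "# pragma: no cover" is consistent with the rest of this commit: the new CPU SLURM test below now exercises enable_auto_hpc_walltime_manager, so the method should count toward coverage. For orientation, a minimal sketch of the signal-based walltime pattern such a manager typically implements; the helper name and the choice of SIGUSR1 are assumptions for illustration, not the library's actual code:

import os
import signal
import subprocess

def register_walltime_handler(save_fn):
    # Hypothetical sketch: SLURM can deliver SIGUSR1 shortly before the
    # walltime limit (e.g. sbatch --signal=USR1@90). On that signal,
    # checkpoint and requeue the job so training resumes where it left off.
    def handler(signum, frame):
        save_fn()  # write an hpc checkpoint before the job is killed
        job_id = os.environ.get('SLURM_JOB_ID')
        if job_id is not None:
            subprocess.call(['scontrol', 'requeue', job_id])

    signal.signal(signal.SIGUSR1, handler)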
@@ -157,6 +157,8 @@ class TrainerIO(object):
         # do the actual save
         torch.save(checkpoint_dict, filepath)
+
+        return filepath
 
     def hpc_load(self, folderpath, on_gpu):
         filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, self.max_ckpt_in_folder(folderpath))
 
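Returning the path from hpc_save gives callers something to assert on; the new test below checks os.path.exists on the returned value. hpc_load then resolves the newest checkpoint through max_ckpt_in_folder. A plausible implementation under the hpc_ckpt_<N>.ckpt naming convention visible above (an illustrative sketch; the real helper may differ):

import os
import re

def max_ckpt_in_folder(folderpath):
    # Highest <N> among files named hpc_ckpt_<N>.ckpt, so repeated saves
    # never clobber each other and loads pick the most recent checkpoint.
    nums = []
    for name in os.listdir(folderpath):
        match = re.match(r'hpc_ckpt_(\d+)\.ckpt$', name)
        if match:
            nums.append(int(match.group(1)))
    return max(nums) if nums else 0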
@@ -3,16 +3,18 @@ from pytorch_lightning import Trainer
 from pytorch_lightning.examples.new_project_templates.lightning_module_template import LightningTemplateModel
 from pytorch_lightning.testing_models.lm_test_module import LightningTestModel
 from argparse import Namespace
-from test_tube import Experiment
+from test_tube import Experiment, SlurmCluster
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.utils.debugging import MisconfigurationException
 from pytorch_lightning.root_module import memory
 from pytorch_lightning.models.trainer import reduce_distributed_output
+from pytorch_lightning.root_module import model_saving
 import numpy as np
 import warnings
 import torch
 import os
 import shutil
+import pdb
 
 SEED = 2334
 torch.manual_seed(SEED)
@@ -22,6 +24,25 @@ np.random.seed(SEED)
 # ------------------------------------------------------------------------
 # TESTS
 # ------------------------------------------------------------------------
+def test_loading_meta_tags():
+    hparams = get_hparams()
+
+    save_dir = init_save_dir()
+
+    # save tags
+    exp = get_exp(False)
+    exp.tag({'some_str': 'a_str', 'an_int': 1, 'a_float': 2.0})
+    exp.argparse(hparams)
+    exp.save()
+
+    # load tags
+    tags_path = exp.get_data_path(exp.name, exp.version) + '/meta_tags.csv'
+    tags = model_saving.load_hparams_from_tags_csv(tags_path)
+
+    assert tags.batch_size == 32 and tags.hidden_dim == 1000
+
+    clear_save_dir()
+
 def test_dp_output_reduce():
 
     # test identity when we have a single gpu
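test_loading_meta_tags round-trips hyperparameters through test_tube's meta_tags.csv: tag and argparse write them out, load_hparams_from_tags_csv reads them back as a Namespace. Assuming a two-column key/value CSV layout (that layout and the type handling are assumptions here), a minimal reader in the same spirit:

import csv
from argparse import Namespace

def load_hparams_from_tags_csv_sketch(tags_csv):
    # Read key/value rows back into a Namespace. The assertion above
    # (tags.batch_size == 32) implies the real loader also coerces
    # numeric strings back to int/float.
    tags = {}
    with open(tags_csv) as f:
        for row in csv.reader(f):
            if len(row) >= 2 and row[0] != 'key':
                tags[row[0]] = row[1]
    return Namespace(**tags)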
@@ -43,6 +64,68 @@ def test_dp_output_reduce():
     assert reduced['b']['c'] == out['b']['c']
 
 
+def test_cpu_slurm_saving_loading():
+    """
+    Verify model save/load/checkpoint on CPU
+    :return:
+    """
+    hparams = get_hparams()
+    model = LightningTestModel(hparams)
+
+    save_dir = init_save_dir()
+
+    # exp file to get meta
+    exp = get_exp(False)
+    exp.argparse(hparams)
+    exp.save()
+
+    trainer_options = dict(
+        max_nb_epochs=1,
+        cluster=SlurmCluster(),
+        experiment=exp,
+        checkpoint_callback=ModelCheckpoint(save_dir)
+    )
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+    real_global_step = trainer.global_step
+
+    # training complete
+    assert result == 1, 'cpu model failed to complete'
+
+    # test saving checkpoint
+    ckpt_test = os.path.join(save_dir, 'test.ckpt')
+    trainer.save_checkpoint(ckpt_test)
+
+    # test registering a save function
+    trainer.enable_auto_hpc_walltime_manager()
+
+    # test model loading with a map_location
+    pretrained_model = load_model(exp, save_dir, True)
+
+    # test model preds
+    run_prediction(model.test_dataloader, pretrained_model)
+
+    trainer.model = pretrained_model
+    trainer.optimizers = pretrained_model.configure_optimizers()
+
+    # test HPC saving
+    saved_filepath = trainer.hpc_save(save_dir, exp)
+    assert os.path.exists(saved_filepath)
+
+    # test HPC loading
+    trainer.global_step = 20000000
+    trainer.hpc_load(save_dir, on_gpu=False)
+    assert trainer.global_step == real_global_step and trainer.global_step != 20000000
+
+    # test freeze/unfreeze
+    model.freeze()
+    model.unfreeze()
+
+    clear_save_dir()
+
+
 def test_amp_gpu_ddp_slurm_managed():
     """
     Make sure DDP + AMP work
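The pair of HPC assertions at the end pin down the contract this commit establishes: hpc_save must report where it wrote (the new return filepath), and hpc_load must overwrite live trainer state, global_step included, with the saved values. A stripped-down sketch of that round trip, with the checkpoint key as an assumption:

import torch

def hpc_round_trip_sketch(trainer, folderpath, ckpt_number):
    # save: persist what the trainer needs to resume, and report the path
    filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, ckpt_number)
    torch.save({'global_step': trainer.global_step}, filepath)

    # load: the saved value wins over whatever the trainer holds now,
    # which is exactly what the global_step assertion verifies
    checkpoint = torch.load(filepath, map_location='cpu')
    trainer.global_step = checkpoint['global_step']
    return filepath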