added load on CPU first (#221)

* added load on CPU first

* added print logs

* changed close order

William Falcon 2019-09-11 07:52:36 -04:00 committed by GitHub
parent 90353ac54e
commit 9576dd28b2
4 changed files with 207 additions and 51 deletions
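
Note: the core pattern in this commit is to always deserialize checkpoints onto the CPU and only afterwards move the weights and optimizer state to the root GPU, which avoids spiking GPU memory during restore. A minimal sketch of that pattern outside the Trainer, assuming hypothetical `model`, `optimizer`, `checkpoint_path` and `root_gpu` objects, and with the `'optimizer_states'` key assumed to match the dump format used in this diff:

    import torch

    # load on CPU first: map every storage to CPU regardless of where it was saved
    checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

    # restore weights on CPU, then move the model to the root GPU
    model.load_state_dict(checkpoint['state_dict'])
    if root_gpu is not None:
        model.cuda(root_gpu)

    # restore optimizer state, then move it to the GPU one tensor at a time to avoid OOM
    optimizer.load_state_dict(checkpoint['optimizer_states'][0])
    if root_gpu is not None:
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda(root_gpu)
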

View File

@@ -170,6 +170,7 @@ class Trainer(TrainerIO):
# allow int, string and gpu list
self.data_parallel_device_ids = self.__parse_gpu_ids(gpus)
self.root_gpu = self.__set_root_gpu(self.data_parallel_device_ids)
# distributed backend choice
self.use_ddp = False
@@ -270,6 +271,17 @@ class Trainer(TrainerIO):
return gpus
def __set_root_gpu(self, gpus):
if gpus is None:
return None
# set root gpu
root_gpu = 0
if type(gpus) is list:
root_gpu = gpus[0]
return root_gpu
@property
def num_gpus(self):
gpus = self.data_parallel_device_ids
@@ -701,10 +713,7 @@ class Trainer(TrainerIO):
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
root_gpu = 0
if type(self.data_parallel_device_ids) is list:
root_gpu = self.data_parallel_device_ids[0]
model.cuda(root_gpu)
model.cuda(self.root_gpu)
if self.use_amp:
# An example
@@ -721,10 +730,7 @@ class Trainer(TrainerIO):
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
root_gpu = 0
if type(self.data_parallel_device_ids) is list:
root_gpu = self.data_parallel_device_ids[0]
model.cuda(root_gpu)
model.cuda(self.root_gpu)
# check for this bug (amp + dp + !01 doesn't work)
# https://github.com/NVIDIA/apex/issues/227
@@ -736,7 +742,12 @@ class Trainer(TrainerIO):
"""
raise MisconfigurationException(m)
model = LightningDataParallel(model, device_ids=self.data_parallel_device_ids)
# create list of device ids
device_ids = self.data_parallel_device_ids
if type(device_ids) is int:
device_ids = list(range(device_ids))
model = LightningDataParallel(model, device_ids=device_ids)
self.__run_pretrain_routine(model)
@@ -787,6 +798,9 @@ class Trainer(TrainerIO):
torch.cuda.set_device(gpu_nb)
model.cuda(gpu_nb)
# override root GPU
self.root_gpu = gpu_nb
# AMP
# run through amp wrapper before going to distributed DP
if self.use_amp:

View File

@@ -1,6 +1,7 @@
import os
import re
import signal
import pdb
from subprocess import call
import torch
@@ -78,7 +79,7 @@ class TrainerIO(object):
except Exception as e:
pass
if on_slurm and self.proc_rank == 0:
if on_slurm:
print('set slurm handle signals')
signal.signal(signal.SIGUSR1, self.sig_handler)
signal.signal(signal.SIGTERM, self.term_handler)
@@ -103,6 +104,9 @@ class TrainerIO(object):
else:
print('requeue failed...')
# close experiment to avoid issues
self.experiment.close()
def term_handler(self, signum, frame):
# save
print("bypassing sigterm")
@@ -118,19 +122,22 @@
def restore(self, checkpoint_path, on_gpu):
if on_gpu:
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
# load training state (affects trainer only)
self.restore_training_state(checkpoint)
# if on_gpu:
# checkpoint = torch.load(checkpoint_path)
# else:
# load on CPU first
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
# load model state
model = self.__get_model()
# load the state_dict on the model automatically
model.load_state_dict(checkpoint['state_dict'])
if on_gpu:
model.cuda(self.root_gpu)
# load training state (affects trainer only)
self.restore_training_state(checkpoint)
def dump_checkpoint(self):
@@ -210,6 +217,14 @@ class TrainerIO(object):
for optimizer, opt_state in zip(self.optimizers, optimizer_states):
optimizer.load_state_dict(opt_state)
# move optimizer to GPU 1 weight at a time
# avoids OOM
if self.root_gpu is not None:
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.cuda(self.root_gpu)
# restore the lr schedulers
lr_schedulers = checkpoint['lr_schedulers']
for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
@@ -225,9 +240,6 @@ class TrainerIO(object):
# save exp to make sure we get all the metrics
experiment.save()
# close experiment to avoid issues
experiment.close()
ckpt_number = self.max_ckpt_in_folder(folderpath) + 1
if not os.path.exists(folderpath):
@@ -248,13 +260,8 @@
def hpc_load(self, folderpath, on_gpu):
filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, self.max_ckpt_in_folder(folderpath))
if on_gpu:
checkpoint = torch.load(filepath)
else:
checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)
# load training state (affects trainer only)
self.restore_training_state(checkpoint)
# load on CPU first
checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)
# load model state
model = self.__get_model()
@@ -262,9 +269,17 @@
# load the state_dict on the model automatically
model.load_state_dict(checkpoint['state_dict'])
if self.root_gpu is not None:
model.cuda(self.root_gpu)
# load training state (affects trainer only)
self.restore_training_state(checkpoint)
# call model hook
model.on_hpc_load(checkpoint)
print(f'restored hpc model from: {filepath}')
def max_ckpt_in_folder(self, path, name_key='ckpt_'):
files = os.listdir(path)
files = [x for x in files if name_key in x]

View File

@@ -214,10 +214,20 @@ def get_hparams(continue_training=False, hpc_exp_number=0):
def main():
"""Verify test() on fitted model"""
"""
Make sure DDP + AMP continue training correctly
:return:
"""
hparams = get_hparams()
model = LightningTestModel(hparams)
trainer_options = dict(
show_progress_bar=True,
max_nb_epochs=4,
gpus=2,
distributed_backend='dp',
)
save_dir = init_save_dir()
# exp file to get meta
@@ -228,31 +238,59 @@ def main():
# exp file to get weights
checkpoint = ModelCheckpoint(save_dir)
trainer_options = dict(
show_progress_bar=False,
max_nb_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
checkpoint_callback=checkpoint,
experiment=exp,
gpus=[0, 1],
distributed_backend='ddp'
)
# add these to the trainer options
trainer_options['experiment'] = exp
trainer_options['checkpoint_callback'] = checkpoint
# fit model
trainer = Trainer(**trainer_options)
trainer.is_slurm_managing_tasks = True
result = trainer.fit(model)
# track epoch before saving
real_global_epoch = trainer.current_epoch
# correct result and ok accuracy
assert result == 1, 'training failed to complete'
pretrained_model = load_model(exp, save_dir, on_gpu=True, module_class=LightningTestModel)
assert result == 1, 'amp + dp model failed to complete'
# ---------------------------
# HPC LOAD/SAVE
# ---------------------------
# save
trainer.hpc_save(save_dir, exp)
# init new trainer
new_exp = get_exp(False, version=exp.version)
trainer_options['experiment'] = new_exp
trainer_options['checkpoint_callback'] = ModelCheckpoint(save_dir)
trainer_options['train_percent_check'] = 0.2
trainer_options['val_percent_check'] = 0.2
trainer_options['max_nb_epochs'] = 1
new_trainer = Trainer(**trainer_options)
new_trainer.test(pretrained_model)
# test we have good test accuracy
assert_ok_test_acc(new_trainer)
# clear_save_dir()
# set the epoch start hook so we can predict before the model does the full training
def assert_good_acc():
assert trainer.current_epoch == real_global_epoch and trainer.current_epoch > 0
# if model and state loaded correctly, predictions will be good even though we
# haven't trained with the new loaded model
dp_model = new_trainer.model
dp_model.eval()
_ = [run_prediction(dataloader, dp_model, dp=True) for dataloader in trainer.val_dataloader]
# new model
model = LightningTestModel(hparams)
model.on_sanity_check_start = assert_good_acc
# fit new model which should load hpc weights
new_trainer.fit(model)
# test freeze on gpu
model.freeze()
model.unfreeze()
clear_save_dir()
if __name__ == '__main__':

View File

@@ -39,6 +39,89 @@ np.random.seed(SEED)
# ------------------------------------------------------------------------
# TESTS
# ------------------------------------------------------------------------
def test_dp_resume():
"""
Make sure DP continues training correctly
:return:
"""
if not can_run_gpu_test():
return
hparams = get_hparams()
model = LightningTestModel(hparams)
trainer_options = dict(
show_progress_bar=True,
max_nb_epochs=2,
gpus=2,
distributed_backend='dp',
)
save_dir = init_save_dir()
# exp file to get meta
exp = get_exp(False)
exp.argparse(hparams)
exp.save()
# exp file to get weights
checkpoint = ModelCheckpoint(save_dir)
# add these to the trainer options
trainer_options['experiment'] = exp
trainer_options['checkpoint_callback'] = checkpoint
# fit model
trainer = Trainer(**trainer_options)
trainer.is_slurm_managing_tasks = True
result = trainer.fit(model)
# track epoch before saving
real_global_epoch = trainer.current_epoch
# correct result and ok accuracy
assert result == 1, 'amp + dp model failed to complete'
# ---------------------------
# HPC LOAD/SAVE
# ---------------------------
# save
trainer.hpc_save(save_dir, exp)
# init new trainer
new_exp = get_exp(False, version=exp.version)
trainer_options['experiment'] = new_exp
trainer_options['checkpoint_callback'] = ModelCheckpoint(save_dir)
trainer_options['train_percent_check'] = 0.2
trainer_options['val_percent_check'] = 0.2
trainer_options['max_nb_epochs'] = 1
new_trainer = Trainer(**trainer_options)
# set the epoch start hook so we can predict before the model does the full training
def assert_good_acc():
assert new_trainer.current_epoch == real_global_epoch and new_trainer.current_epoch > 0
# if model and state loaded correctly, predictions will be good even though we
# haven't trained with the new loaded model
dp_model = new_trainer.model
dp_model.eval()
_ = [run_prediction(dataloader, dp_model, dp=True) for dataloader in trainer.val_dataloader]
# new model
model = LightningTestModel(hparams)
model.on_sanity_check_start = assert_good_acc
# fit new model which should load hpc weights
new_trainer.fit(model)
# test freeze on gpu
model.freeze()
model.unfreeze()
clear_save_dir()
def test_running_test_pretrained_model_ddp():
"""Verify test() on pretrained model"""
if not can_run_gpu_test():
@@ -1342,7 +1425,7 @@ def load_model(exp, save_dir, on_gpu, map_location=None, module_class=LightningT
return trained_model
def run_prediction(dataloader, trained_model):
def run_prediction(dataloader, trained_model, dp=False):
# run prediction on 1 batch
for batch in dataloader:
break
@@ -1350,13 +1433,19 @@ def run_prediction(dataloader, trained_model):
x, y = batch
x = x.view(x.size(0), -1)
y_hat = trained_model(x)
if dp:
output = trained_model(batch, 0)
acc = output['val_acc']
acc = torch.mean(acc).item()
# acc
labels_hat = torch.argmax(y_hat, dim=1)
acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
acc = torch.tensor(acc)
acc = acc.item()
else:
y_hat = trained_model(x)
# acc
labels_hat = torch.argmax(y_hat, dim=1)
acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
acc = torch.tensor(acc)
acc = acc.item()
assert acc > 0.50, f'this model is expected to get > 0.50 in test set (it got {acc})'