diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 012e36ab31..7d207502f9 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -170,6 +170,7 @@ class Trainer(TrainerIO):
 
         # allow int, string and gpu list
         self.data_parallel_device_ids = self.__parse_gpu_ids(gpus)
+        self.root_gpu = self.__set_root_gpu(self.data_parallel_device_ids)
 
         # distributed backend choice
         self.use_ddp = False
@@ -270,6 +271,17 @@ class Trainer(TrainerIO):
 
         return gpus
 
+    def __set_root_gpu(self, gpus):
+        if gpus is None:
+            return None
+
+        # set root gpu
+        root_gpu = 0
+        if type(gpus) is list:
+            root_gpu = gpus[0]
+
+        return root_gpu
+
     @property
     def num_gpus(self):
         gpus = self.data_parallel_device_ids
@@ -701,10 +713,7 @@ class Trainer(TrainerIO):
         # allow for lr schedulers as well
         self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
 
-        root_gpu = 0
-        if type(self.data_parallel_device_ids) is list:
-            root_gpu = self.data_parallel_device_ids[0]
-        model.cuda(root_gpu)
+        model.cuda(self.root_gpu)
 
         if self.use_amp:
             # An example
@@ -721,10 +730,7 @@ class Trainer(TrainerIO):
         # allow for lr schedulers as well
         self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
 
-        root_gpu = 0
-        if type(self.data_parallel_device_ids) is list:
-            root_gpu = self.data_parallel_device_ids[0]
-        model.cuda(root_gpu)
+        model.cuda(self.root_gpu)
 
         # check for this bug (amp + dp + !01 doesn't work)
         # https://github.com/NVIDIA/apex/issues/227
@@ -736,7 +742,12 @@ class Trainer(TrainerIO):
             """
             raise MisconfigurationException(m)
 
-        model = LightningDataParallel(model, device_ids=self.data_parallel_device_ids)
+        # create list of device ids
+        device_ids = self.data_parallel_device_ids
+        if type(device_ids) is int:
+            device_ids = list(range(device_ids))
+
+        model = LightningDataParallel(model, device_ids=device_ids)
 
         self.__run_pretrain_routine(model)
 
@@ -787,6 +798,9 @@ class Trainer(TrainerIO):
         torch.cuda.set_device(gpu_nb)
         model.cuda(gpu_nb)
 
+        # override root GPU
+        self.root_gpu = gpu_nb
+
         # AMP
         # run through amp wrapper before going to distributed DP
         if self.use_amp:
diff --git a/pytorch_lightning/trainer/trainer_io.py b/pytorch_lightning/trainer/trainer_io.py
index f5a869a883..71b677cc41 100644
--- a/pytorch_lightning/trainer/trainer_io.py
+++ b/pytorch_lightning/trainer/trainer_io.py
@@ -1,6 +1,7 @@
 import os
 import re
 import signal
+import pdb
 from subprocess import call
 import torch
 
@@ -78,7 +79,7 @@ class TrainerIO(object):
         except Exception as e:
             pass
 
-        if on_slurm and self.proc_rank == 0:
+        if on_slurm:
             print('set slurm handle signals')
             signal.signal(signal.SIGUSR1, self.sig_handler)
             signal.signal(signal.SIGTERM, self.term_handler)
@@ -103,6 +104,9 @@ class TrainerIO(object):
         else:
             print('requeue failed...')
 
+        # close experiment to avoid issues
+        self.experiment.close()
+
     def term_handler(self, signum, frame):
         # save
         print("bypassing sigterm")
@@ -118,19 +122,22 @@ class TrainerIO(object):
 
     def restore(self, checkpoint_path, on_gpu):
 
-        if on_gpu:
-            checkpoint = torch.load(checkpoint_path)
-        else:
-            checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
-
-        # load training state (affects trainer only)
-        self.restore_training_state(checkpoint)
+        # if on_gpu:
+        #     checkpoint = torch.load(checkpoint_path)
+        # else:
+        # load on CPU first
+        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
 
         # load model state
         model = self.__get_model()
 
         # load the state_dict on the model automatically
         model.load_state_dict(checkpoint['state_dict'])
 
+        if on_gpu:
+            model.cuda(self.root_gpu)
+
+        # load training state (affects trainer only)
+        self.restore_training_state(checkpoint)
 
     def dump_checkpoint(self):
@@ -210,6 +217,14 @@ class TrainerIO(object):
         for optimizer, opt_state in zip(self.optimizers, optimizer_states):
             optimizer.load_state_dict(opt_state)
 
+            # move optimizer to GPU 1 weight at a time
+            # avoids OOM
+            if self.root_gpu is not None:
+                for state in optimizer.state.values():
+                    for k, v in state.items():
+                        if isinstance(v, torch.Tensor):
+                            state[k] = v.cuda(self.root_gpu)
+
         # restore the lr schedulers
         lr_schedulers = checkpoint['lr_schedulers']
         for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
@@ -225,9 +240,6 @@ class TrainerIO(object):
         # save exp to make sure we get all the metrics
         experiment.save()
 
-        # close experiment to avoid issues
-        experiment.close()
-
         ckpt_number = self.max_ckpt_in_folder(folderpath) + 1
 
         if not os.path.exists(folderpath):
@@ -248,13 +260,8 @@ class TrainerIO(object):
     def hpc_load(self, folderpath, on_gpu):
         filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, self.max_ckpt_in_folder(folderpath))
 
-        if on_gpu:
-            checkpoint = torch.load(filepath)
-        else:
-            checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)
-
-        # load training state (affects trainer only)
-        self.restore_training_state(checkpoint)
+        # load on CPU first
+        checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)
 
         # load model state
         model = self.__get_model()
@@ -262,9 +269,17 @@ class TrainerIO(object):
         # load the state_dict on the model automatically
         model.load_state_dict(checkpoint['state_dict'])
 
+        if self.root_gpu is not None:
+            model.cuda(self.root_gpu)
+
+        # load training state (affects trainer only)
+        self.restore_training_state(checkpoint)
+
         # call model hook
         model.on_hpc_load(checkpoint)
 
+        print(f'restored hpc model from: {filepath}')
+
     def max_ckpt_in_folder(self, path, name_key='ckpt_'):
         files = os.listdir(path)
         files = [x for x in files if name_key in x]
diff --git a/tests/debug.py b/tests/debug.py
index 5efa25da62..6016dee3ba 100644
--- a/tests/debug.py
+++ b/tests/debug.py
@@ -214,10 +214,20 @@ def get_hparams(continue_training=False, hpc_exp_number=0):
 
 
 def main():
-    """Verify test() on fitted model"""
+    """
+    Make sure DDP + AMP continue training correctly
+    :return:
+    """
     hparams = get_hparams()
     model = LightningTestModel(hparams)
 
+    trainer_options = dict(
+        show_progress_bar=True,
+        max_nb_epochs=4,
+        gpus=2,
+        distributed_backend='dp',
+    )
+
     save_dir = init_save_dir()
 
     # exp file to get meta
@@ -228,31 +238,59 @@ def main():
     # exp file to get weights
     checkpoint = ModelCheckpoint(save_dir)
 
-    trainer_options = dict(
-        show_progress_bar=False,
-        max_nb_epochs=1,
-        train_percent_check=0.4,
-        val_percent_check=0.2,
-        checkpoint_callback=checkpoint,
-        experiment=exp,
-        gpus=[0, 1],
-        distributed_backend='ddp'
-    )
+    # add these to the trainer options
+    trainer_options['experiment'] = exp
+    trainer_options['checkpoint_callback'] = checkpoint
 
     # fit model
     trainer = Trainer(**trainer_options)
+    trainer.is_slurm_managing_tasks = True
     result = trainer.fit(model)
 
+    # track epoch before saving
+    real_global_epoch = trainer.current_epoch
+
     # correct result and ok accuracy
-    assert result == 1, 'training failed to complete'
-    pretrained_model = load_model(exp, save_dir, on_gpu=True, module_class=LightningTestModel)
+    assert result == 1, 'amp + dp model failed to complete'
 
+    # ---------------------------
+    # HPC LOAD/SAVE
+    # ---------------------------
+    # save
+    trainer.hpc_save(save_dir, exp)
+
+    # init new trainer
+    new_exp = get_exp(False, version=exp.version)
+    trainer_options['experiment'] = new_exp
+    trainer_options['checkpoint_callback'] = ModelCheckpoint(save_dir)
+    trainer_options['train_percent_check'] = 0.2
+    trainer_options['val_percent_check'] = 0.2
+    trainer_options['max_nb_epochs'] = 1
     new_trainer = Trainer(**trainer_options)
-    new_trainer.test(pretrained_model)
 
-    # test we have good test accuracy
-    assert_ok_test_acc(new_trainer)
-    # clear_save_dir()
+    # set the epoch start hook so we can predict before the model does the full training
+    def assert_good_acc():
+        assert trainer.current_epoch == real_global_epoch and trainer.current_epoch > 0
+
+        # if model and state loaded correctly, predictions will be good even though we
+        # haven't trained with the new loaded model
+        dp_model = new_trainer.model
+        dp_model.eval()
+
+        _ = [run_prediction(dataloader, dp_model, dp=True) for dataloader in trainer.val_dataloader]
+
+    # new model
+    model = LightningTestModel(hparams)
+    model.on_sanity_check_start = assert_good_acc
+
+    # fit new model which should load hpc weights
+    new_trainer.fit(model)
+
+    # test freeze on gpu
+    model.freeze()
+    model.unfreeze()
+
+    clear_save_dir()
 
 
 if __name__ == '__main__':
diff --git a/tests/test_models.py b/tests/test_models.py
index d0214ea321..732aa5b95c 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -39,6 +39,89 @@ np.random.seed(SEED)
 # ------------------------------------------------------------------------
 # TESTS
 # ------------------------------------------------------------------------
+def test_dp_resume():
+    """
+    Make sure DP continues training correctly
+    :return:
+    """
+    if not can_run_gpu_test():
+        return
+
+    hparams = get_hparams()
+    model = LightningTestModel(hparams)
+
+    trainer_options = dict(
+        show_progress_bar=True,
+        max_nb_epochs=2,
+        gpus=2,
+        distributed_backend='dp',
+    )
+
+    save_dir = init_save_dir()
+
+    # exp file to get meta
+    exp = get_exp(False)
+    exp.argparse(hparams)
+    exp.save()
+
+    # exp file to get weights
+    checkpoint = ModelCheckpoint(save_dir)
+
+    # add these to the trainer options
+    trainer_options['experiment'] = exp
+    trainer_options['checkpoint_callback'] = checkpoint
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    trainer.is_slurm_managing_tasks = True
+    result = trainer.fit(model)
+
+    # track epoch before saving
+    real_global_epoch = trainer.current_epoch
+
+    # correct result and ok accuracy
+    assert result == 1, 'amp + dp model failed to complete'
+
+    # ---------------------------
+    # HPC LOAD/SAVE
+    # ---------------------------
+    # save
+    trainer.hpc_save(save_dir, exp)
+
+    # init new trainer
+    new_exp = get_exp(False, version=exp.version)
+    trainer_options['experiment'] = new_exp
+    trainer_options['checkpoint_callback'] = ModelCheckpoint(save_dir)
+    trainer_options['train_percent_check'] = 0.2
+    trainer_options['val_percent_check'] = 0.2
+    trainer_options['max_nb_epochs'] = 1
+    new_trainer = Trainer(**trainer_options)
+
+    # set the epoch start hook so we can predict before the model does the full training
+    def assert_good_acc():
+        assert new_trainer.current_epoch == real_global_epoch and new_trainer.current_epoch > 0
+
+        # if model and state loaded correctly, predictions will be good even though we
+        # haven't trained with the new loaded model
+        dp_model = new_trainer.model
+        dp_model.eval()
+
+        _ = [run_prediction(dataloader, dp_model, dp=True) for dataloader in trainer.val_dataloader]
+
+    # new model
+    model = LightningTestModel(hparams)
+    model.on_sanity_check_start = assert_good_acc
+
+    # fit new model which should load hpc weights
+    new_trainer.fit(model)
+
+    # test freeze on gpu
+    model.freeze()
+    model.unfreeze()
+
+    clear_save_dir()
+
+
 def test_running_test_pretrained_model_ddp():
     """Verify test() on pretrained model"""
     if not can_run_gpu_test():
         return
@@ -1342,7 +1425,7 @@ def load_model(exp, save_dir, on_gpu, map_location=None, module_class=LightningT
 
     return trained_model
 
-def run_prediction(dataloader, trained_model):
+def run_prediction(dataloader, trained_model, dp=False):
     # run prediction on 1 batch
     for batch in dataloader:
         break
@@ -1350,13 +1433,19 @@
     x, y = batch
     x = x.view(x.size(0), -1)
 
-    y_hat = trained_model(x)
+    if dp:
+        output = trained_model(batch, 0)
+        acc = output['val_acc']
+        acc = torch.mean(acc).item()
 
-    # acc
-    labels_hat = torch.argmax(y_hat, dim=1)
-    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
-    acc = torch.tensor(acc)
-    acc = acc.item()
+    else:
+        y_hat = trained_model(x)
+
+        # acc
+        labels_hat = torch.argmax(y_hat, dim=1)
+        acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
+        acc = torch.tensor(acc)
+        acc = acc.item()
 
     assert acc > 0.50, f'this model is expected to get > 0.50 in test set (it got {acc})'
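
Note on the restore pattern above (not part of the patch): the new code loads checkpoints on CPU
first via map_location, moves the model to self.root_gpu, and only then walks the optimizer state
one tensor at a time. A minimal plain-PyTorch sketch of the same idea follows; the function name
restore_to_gpu and the 'optimizer'/'epoch' checkpoint keys are illustrative assumptions, not
Lightning's actual checkpoint layout (Lightning keeps a list under 'optimizer_states').

import torch


def restore_to_gpu(checkpoint_path, model, optimizer, root_gpu=0):
    # hypothetical helper; the keys 'state_dict', 'optimizer', 'epoch' are assumed
    # load on CPU first so the full checkpoint never has to fit on the GPU
    checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

    # restore weights, then move the model to the root GPU
    model.load_state_dict(checkpoint['state_dict'])
    model.cuda(root_gpu)

    # restore optimizer state, then move it to the GPU one tensor at a time
    # (mirrors the OOM-avoidance loop added in restore_training_state above)
    optimizer.load_state_dict(checkpoint['optimizer'])
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.cuda(root_gpu)

    return checkpoint.get('epoch', 0)

Loading on CPU first matters when the checkpoint was written from a different GPU ordinal than the
one used for resuming, and moving optimizer tensors individually avoids briefly holding two full
copies of the optimizer state on the GPU.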