diff --git a/examples/new_project_templates/single_gpu_node_16bit_template.py b/examples/new_project_templates/single_gpu_node_16bit_template.py index d0b28d1cc4..cf51b2c7f2 100644 --- a/examples/new_project_templates/single_gpu_node_16bit_template.py +++ b/examples/new_project_templates/single_gpu_node_16bit_template.py @@ -25,36 +25,19 @@ def main(hparams): # ------------------------ # 1 INIT LIGHTNING MODEL # ------------------------ - print('loading model...') model = LightningTemplateModel(hparams) - print('model built') # ------------------------ - # 2 INIT TEST TUBE EXP - # ------------------------ - - # init experiment - exp = Experiment( - name=hyperparams.experiment_name, - save_dir=hyperparams.test_tube_save_path, - autosave=False, - description='test demo' - ) - - exp.argparse(hparams) - exp.save() - - # ------------------------ - # 3 INIT TRAINER + # 2 INIT TRAINER # ------------------------ trainer = Trainer( - experiment=exp, gpus=hparams.gpus, - use_amp=True + use_amp=True, + distributed_backend='dp' ) # ------------------------ - # 4 START TRAINING + # 3 START TRAINING # ------------------------ trainer.fit(model) @@ -63,9 +46,6 @@ if __name__ == '__main__': # dirs root_dir = os.path.dirname(os.path.realpath(__file__)) - demo_log_dir = os.path.join(root_dir, 'pt_lightning_demo_logs') - checkpoint_dir = os.path.join(demo_log_dir, 'model_weights') - test_tube_dir = os.path.join(demo_log_dir, 'test_tube_data') # although we user hyperOptParser, we are using it only as argparse right now parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False) @@ -74,12 +54,6 @@ if __name__ == '__main__': parent_parser.add_argument('--gpus', type=str, default='-1', help='how many gpus to use in the node.' 'value -1 uses all the gpus on the node') - parent_parser.add_argument('--test_tube_save_path', type=str, default=test_tube_dir, - help='where to save logs') - parent_parser.add_argument('--model_save_path', type=str, default=checkpoint_dir, - help='where to save model') - parent_parser.add_argument('--experiment_name', type=str, default='pt_lightning_exp_a', - help='test tube exp name') # allow model to overwrite or extend args parser = LightningTemplateModel.add_model_specific_args(parent_parser, root_dir) @@ -88,6 +62,4 @@ if __name__ == '__main__': # --------------------- # RUN TRAINING # --------------------- - # run on HPC cluster - print(f'RUNNING INTERACTIVE MODE ON GPUS. gpu ids: {hyperparams.gpus}') main(hyperparams)