diff --git a/examples/new_project_templates/single_gpu_node_template.py b/examples/new_project_templates/single_gpu_node_template.py new file mode 100644 index 0000000000..a03be32f17 --- /dev/null +++ b/examples/new_project_templates/single_gpu_node_template.py @@ -0,0 +1,112 @@ +""" +Runs a model on a single node across N-gpus. +""" +import os +import sys +import numpy as np +from time import sleep +import torch + +from test_tube import HyperOptArgumentParser, Experiment, SlurmCluster +from pytorch_lightning.models.trainer import Trainer +from pytorch_lightning.utils.arg_parse import add_default_args + +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint + +SEED = 2334 +torch.manual_seed(SEED) +np.random.seed(SEED) + +from lightning_module_template import LightningTemplateModel + + +def main(hparams): + """ + Main training routine specific for this project + :param hparams: + :return: + """ + # ------------------------ + # 1 INIT LIGHTNING MODEL + # ------------------------ + print('loading model...') + model = LightningTemplateModel(hparams) + print('model built') + + # ------------------------ + # 2 INIT TEST TUBE EXP + # ------------------------ + + # init experiment + exp = Experiment( + name=hyperparams.experiment_name, + save_dir=hyperparams.test_tube_save_path, + autosave=False, + description='test demo' + ) + + exp.argparse(hparams) + exp.save() + + # ------------------------ + # 3 DEFINE CALLBACKS + # ------------------------ + model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version) + early_stop = EarlyStopping( + monitor='val_acc', + patience=3, + verbose=True, + mode='max' + ) + + checkpoint = ModelCheckpoint( + filepath=model_save_path, + save_best_only=True, + verbose=True, + monitor='val_loss', + mode='min' + ) + + # ------------------------ + # 4 INIT TRAINER + # ------------------------ + trainer = Trainer( + experiment=exp, + checkpoint_callback=checkpoint, + early_stop_callback=early_stop, + gpus=hparams.gpus, + ) + + # ------------------------ + # 5 START TRAINING + # ------------------------ + trainer.fit(model) + + +if __name__ == '__main__': + + # dirs + root_dir = os.path.dirname(os.path.realpath(__file__)) + demo_log_dir = os.path.join(root_dir, 'pt_lightning_demo_logs') + checkpoint_dir = os.path.join(demo_log_dir, 'model_weights') + test_tube_dir = os.path.join(demo_log_dir, 'test_tube_data') + + # although we user hyperOptParser, we are using it only as argparse right now + parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False) + + # gpu args + parent_parser.add_argument('--gpus', type=str, default='-1', help='how many gpus to use in the node') + parent_parser.add_argument('--test_tube_save_path', type=str, default=test_tube_dir, help='where to save logs') + parent_parser.add_argument('--model_save_path', type=str, default=checkpoint_dir, help='where to save model') + parent_parser.add_argument('--experiment_name', type=str, default='pt_lightning_exp_a', help='test tube exp name') + + # allow model to overwrite or extend args + parser = LightningTemplateModel.add_model_specific_args(parent_parser, root_dir) + hyperparams = parser.parse_args() + + # --------------------- + # RUN TRAINING + # --------------------- + # run on HPC cluster + print(f'RUNNING INTERACTIVE MODE ON GPUS. gpu ids: {hyperparams.gpus}') + main(hyperparams) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index fe448972ae..f2252186cf 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -302,8 +302,7 @@ class Trainer(TrainerIO): world_size = self.nb_gpu_nodes * len(self.data_parallel_device_ids) # set up server using proc 0's ip address - ip_tables_dir = os.path.join(self.cluster.log_path, 'ip_tables') - ip = self.__get_root_node_ip(self.proc_rank, self.nb_gpu_nodes, ip_tables_dir) + ip = self.__get_root_node_ip(self.proc_rank, self.nb_gpu_nodes) dist.init_process_group("nccl", init_method=f'tcp://{ip}:12001', rank=self.proc_rank, world_size=world_size) print(f"GPU: {gpu_nb} - Rank: {self.proc_rank}") @@ -315,7 +314,7 @@ class Trainer(TrainerIO): # continue training routine self.__run_pretrain_routine(model) - def __get_root_node_ip(self, world_gpu_nb, nb_gpu_nodes, ip_file_dir): + def __get_root_node_ip(self, world_gpu_nb, nb_gpu_nodes): """ Resolves the ip address of proc 0. Proc 0 writes address to a file. Every other process waits until the ip is available before it starts @@ -329,6 +328,9 @@ class Trainer(TrainerIO): if nb_gpu_nodes == 1: return '127.0.0.1' + # where to store ip_table + ip_file_dir = os.path.join(self.cluster.log_path, 'ip_tables') + # the first gpu in the world becomes the host # this is based on its global rank # it communicates its ip by saving an ip_table to the slurm cluster logging dir