lightning/examples/new_project_templates/multi_node_cluster_template.py

173 lines
5.4 KiB
Python
Raw Normal View History

2019-06-27 18:29:44 +00:00
import os
import sys
import numpy as np
from time import sleep
import torch
from test_tube import HyperOptArgumentParser, Experiment, SlurmCluster
from pytorch_lightning.models.trainer import Trainer
from pytorch_lightning.utils.arg_parse import add_default_args
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
SEED = 2334
# Seed numpy's global RNG and torch's default generator so that random draws
# from these generators are reproducible from run to run.
np.random.seed(SEED)
torch.manual_seed(SEED)
# ---------------------
# DEFINE MODEL HERE
# ---------------------
2019-06-29 21:33:10 +00:00
from lightning_module_template import LightningTemplateModel
2019-06-27 18:29:44 +00:00
# ---------------------
"""
Allows training to be configured via command line arguments.
Run by:
# TYPE YOUR RUN COMMAND HERE
"""
def main_local(hparams):
    """
    Run the training routine on the current machine, without a SLURM cluster.

    Delegates to ``main`` with no cluster handle and no results dict.

    :param hparams: parsed hyperparameters for this run.
    :return: None
    """
    main(hparams, cluster=None, results_dict=None)
def main(hparams, cluster, results_dict):
    """
    Main training routine specific for this project.

    Builds the Lightning model, creates a test tube Experiment, wires up
    early-stopping and checkpoint callbacks, and fits the model with a Trainer.

    :param hparams: parsed hyperparameters for this trial; every setting used
        below (experiment name, save paths, gpus, nb_gpu_nodes, ...) is read
        from this object.
    :param cluster: cluster handle forwarded to the Trainer, or None when
        training locally (``main_local`` passes None).
    :param results_dict: not used by this routine; kept so the three-argument
        call signature stays unchanged (``main_local`` passes None).
    :return: None
    """
    # ------------------------
    # 1 INIT LIGHTNING MODEL
    # ------------------------
    print('loading model...')
    model = LightningTemplateModel(hparams)
    print('model built')

    # ------------------------
    # 2 INIT TEST TUBE EXP
    # ------------------------

    # when using grid search, it's possible for all models to start at once
    # and use the same test tube experiment version, so stagger start-up by the
    # SLURM node id. SLURM_NODEID is absent outside a SLURM job (e.g. when run
    # through main_local), so fall back to node 0 instead of raising KeyError.
    relative_node_id = int(os.environ.get('SLURM_NODEID', 0))
    sleep(relative_node_id + 1)

    # init experiment
    # Read settings from this trial's `hparams`, not the module-level
    # `hyperparams`, which only exists when this file runs as a script and
    # would ignore the per-trial values.
    exp = Experiment(
        name=hparams.experiment_name,
        save_dir=hparams.test_tube_save_path,
        autosave=False,
        description='test demo'
    )
    exp.argparse(hparams)
    exp.save()

    # ------------------------
    # 3 DEFINE CALLBACKS
    # ------------------------
    model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version)

    early_stop = EarlyStopping(
        monitor='val_acc',
        patience=3,
        verbose=True,
        mode='max'
    )

    checkpoint = ModelCheckpoint(
        filepath=model_save_path,
        save_best_only=True,
        verbose=True,
        monitor='val_loss',
        mode='min'
    )

    # ------------------------
    # 4 INIT TRAINER
    # ------------------------
    trainer = Trainer(
        experiment=exp,
        cluster=cluster,
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stop,
        gpus=hparams.gpus,
        nb_gpu_nodes=hparams.nb_gpu_nodes
    )

    # ------------------------
    # 5 START TRAINING
    # ------------------------
    trainer.fit(model)
2019-07-08 15:21:41 +00:00
2019-06-27 18:29:44 +00:00
def optimize_on_cluster(hyperparams):
    """
    Configure a test tube SlurmCluster and submit the hyperparameter search.

    Creates the cluster with ``hyperparams`` as the optimizer, sets per-job
    resources (gpus, nodes, wall time, gpu type, memory), adds environment
    setup and SBATCH options, then calls ``optimize_parallel_cluster_gpu``
    with ``main`` as the training function.

    :param hyperparams: parsed arguments (HyperOptArgumentParser namespace);
        reads slurm_log_path, per_experiment_nb_gpus, nb_gpu_nodes,
        gpu_partition, nb_hopt_trials and experiment_name.
    :return: None
    """
    # enable cluster training
    # log all scripts to the test tube folder
    cluster = SlurmCluster(
        hyperparam_optimizer=hyperparams,
        log_path=hyperparams.slurm_log_path,
    )

    # email for cluster coms
    cluster.notify_job_status(email='add_email_here', on_done=True, on_fail=True)

    # configure cluster
    # NOTE(review): --per_experiment_nb_gpus has no CLI default, so this may be
    # None when the flag is omitted — confirm SlurmCluster accepts that.
    cluster.per_experiment_nb_gpus = hyperparams.per_experiment_nb_gpus
    cluster.per_experiment_nb_nodes = hyperparams.nb_gpu_nodes
    cluster.job_time = '2:00:00'
    cluster.gpu_type = 'volta'
    cluster.memory_mb_per_node = 0

    # any modules for code to run in env
    cluster.add_command('source activate lightning')

    # run only on 32GB voltas
    cluster.add_slurm_cmd(cmd='constraint', value='volta32gb', comment='use 32gb gpus')

    # --gpu_partition has no CLI default; only request a partition when one was
    # given, instead of passing None as the partition name (which presumably
    # yields an invalid partition directive in the generated script).
    if hyperparams.gpu_partition:
        cluster.add_slurm_cmd(
            cmd='partition',
            value=hyperparams.gpu_partition,
            comment='partition to submit jobs to'
        )

    # run hopt
    # creates and submits jobs to slurm
    cluster.optimize_parallel_cluster_gpu(
        main,
        nb_trials=hyperparams.nb_hopt_trials,
        job_name=hyperparams.experiment_name
    )
if __name__ == '__main__':
    # ---------------------
    # DEFAULT PATHS
    # ---------------------
    # every default output location lives under <this file's dir>/pt_lightning_demo_logs
    root_dir = os.path.dirname(os.path.realpath(__file__))
    demo_log_dir = os.path.join(root_dir, 'pt_lightning_demo_logs')
    checkpoint_dir, test_tube_dir, slurm_out_dir = [
        os.path.join(demo_log_dir, leaf)
        for leaf in ('model_weights', 'test_tube_data', 'slurm_scripts')
    ]

    # ---------------------
    # ARGUMENTS
    # ---------------------
    parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)

    # cluster args not defined inside the model, registered in this order.
    # Entries without a 'default' key fall back to argparse's own default.
    # TODO: make 1 param (presumably --per_experiment_nb_gpus and --gpus)
    cluster_arg_specs = (
        ('--gpu_partition', dict(type=str, help='consult your cluster manual')),
        ('--per_experiment_nb_gpus', dict(type=int, help='how many gpus to use in a node')),
        ('--gpus', dict(type=str, default='-1', help='how many gpus to use in the node')),
        ('--nb_gpu_nodes', dict(type=int, default=1, help='how many nodes to use in a cluster')),
        ('--test_tube_save_path', dict(type=str, default=test_tube_dir, help='where to save logs')),
        ('--slurm_log_path', dict(type=str, default=slurm_out_dir, help='where to save slurm meta')),
        ('--model_save_path', dict(type=str, default=checkpoint_dir, help='where to save model')),
        ('--experiment_name', dict(type=str, default='pt_lightning_exp_a', help='test tube exp name')),
        ('--nb_hopt_trials', dict(type=int, default=1, help='how many grid search trials to run')),
    )
    for flag, options in cluster_arg_specs:
        parent_parser.add_argument(flag, **options)

    # allow model to overwrite or extend args
    parser = LightningTemplateModel.add_model_specific_args(parent_parser, root_dir)
    hyperparams = parser.parse_args()

    # ---------------------
    # RUN TRAINING
    # ---------------------
    # run on HPC cluster
    print('RUNNING ON SLURM CLUSTER')
    optimize_on_cluster(hyperparams)