testing slurm ddp
This commit is contained in:
parent
ae0349d449
commit
01b9502847
|
@ -209,8 +209,6 @@ class LightningTemplateModel(LightningModule):
|
||||||
# parser.set_defaults(gradient_clip=5.0)
|
# parser.set_defaults(gradient_clip=5.0)
|
||||||
|
|
||||||
# network params
|
# network params
|
||||||
|
|
||||||
parser.add_argument('--nb_gpu_nodes', type=int, default=1)
|
|
||||||
parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False)
|
parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False)
|
||||||
parser.add_argument('--in_features', default=28*28)
|
parser.add_argument('--in_features', default=28*28)
|
||||||
parser.add_argument('--out_features', default=10)
|
parser.add_argument('--out_features', default=10)
|
||||||
|
|
|
@ -111,12 +111,6 @@ def get_default_parser(strategy, root_dir):
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def get_model_name(args):
|
|
||||||
for i, arg in enumerate(args):
|
|
||||||
if 'model_name' in arg:
|
|
||||||
return args[i+1]
|
|
||||||
|
|
||||||
|
|
||||||
def optimize_on_cluster(hyperparams):
|
def optimize_on_cluster(hyperparams):
|
||||||
# enable cluster training
|
# enable cluster training
|
||||||
cluster = SlurmCluster(
|
cluster = SlurmCluster(
|
||||||
|
@ -155,57 +149,22 @@ def optimize_on_cluster(hyperparams):
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
model_name = get_model_name(sys.argv)
|
|
||||||
if model_name is None:
|
|
||||||
model_name = 'model_template'
|
|
||||||
|
|
||||||
# use default args
|
# use default args
|
||||||
root_dir = os.path.dirname(os.path.realpath(__file__))
|
root_dir = os.path.dirname(os.path.realpath(__file__))
|
||||||
parent_parser = get_default_parser(strategy='random_search', root_dir=root_dir)
|
parent_parser = get_default_parser(strategy='random_search', root_dir=root_dir)
|
||||||
|
|
||||||
|
# cluster args not defined inside the model
|
||||||
parent_parser.add_argument('-gpu_partition', type=str)
|
parent_parser.add_argument('-gpu_partition', type=str)
|
||||||
parent_parser.add_argument('-per_experiment_nb_gpus', type=int)
|
parent_parser.add_argument('-per_experiment_nb_gpus', type=int)
|
||||||
|
parent_parser.add_argument('--nb_gpu_nodes', type=int, default=1)
|
||||||
|
|
||||||
# allow model to overwrite or extend args
|
# allow model to overwrite or extend args
|
||||||
TRAINING_MODEL = AVAILABLE_MODELS[model_name]
|
parser = LightningTemplateModel.add_model_specific_args(parent_parser, root_dir)
|
||||||
parser = TRAINING_MODEL.add_model_specific_args(parent_parser, root_dir)
|
|
||||||
hyperparams = parser.parse_args()
|
hyperparams = parser.parse_args()
|
||||||
|
|
||||||
# ---------------------
|
# ---------------------
|
||||||
# RUN TRAINING
|
# RUN TRAINING
|
||||||
# ---------------------
|
# ---------------------
|
||||||
|
# run on HPC cluster
|
||||||
# RUN ON CLUSTER
|
print('RUNNING ON SLURM CLUSTER')
|
||||||
if hyperparams.on_cluster:
|
optimize_on_cluster(hyperparams)
|
||||||
# run on HPC cluster
|
|
||||||
print('RUNNING ON SLURM CLUSTER')
|
|
||||||
optimize_on_cluster(hyperparams)
|
|
||||||
|
|
||||||
# RUN ON GPUS
|
|
||||||
elif hyperparams.gpus is not None:
|
|
||||||
# -1 means use all gpus
|
|
||||||
# otherwise use the visible ones
|
|
||||||
if hyperparams.gpus == '-1':
|
|
||||||
gpu_ids = list(range(0, torch.cuda.device_count()))
|
|
||||||
else:
|
|
||||||
gpu_ids = hyperparams.gpus.split(',')
|
|
||||||
|
|
||||||
if hyperparams.interactive:
|
|
||||||
print(f'RUNNING INTERACTIVE MODE ON GPUS. gpu ids: {gpu_ids}')
|
|
||||||
main(hyperparams, None, None)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# multiple GPUs on same machine
|
|
||||||
print(f'RUNNING MULTI GPU. GPU ids: {gpu_ids}')
|
|
||||||
hyperparams.optimize_parallel_gpu(
|
|
||||||
main_local,
|
|
||||||
gpu_ids=gpu_ids,
|
|
||||||
nb_trials=hyperparams.nb_hopt_trials,
|
|
||||||
nb_workers=len(gpu_ids)
|
|
||||||
)
|
|
||||||
|
|
||||||
# RUN ON CPU
|
|
||||||
else:
|
|
||||||
# run on cpu
|
|
||||||
print('RUNNING ON CPU')
|
|
||||||
main(hyperparams, None, None)
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue