updated args
This commit is contained in:
parent
d4ca295762
commit
684dfd0a38
|
@ -92,13 +92,18 @@ def main(hparams, cluster, results_dict):
|
|||
mode=hparams.model_save_monitor_mode
|
||||
)
|
||||
|
||||
# gpus are ; separated for inside a node and , within nodes
|
||||
gpu_list = None
|
||||
if hparams.gpus is not None:
|
||||
gpu_list = map(int, hparams.gpus.split(';'))
|
||||
|
||||
# configure trainer
|
||||
trainer = Trainer(
|
||||
experiment=exp,
|
||||
cluster=cluster,
|
||||
checkpoint_callback=checkpoint,
|
||||
early_stop_callback=early_stop,
|
||||
gpus=hparams.gpus
|
||||
gpus=gpu_list
|
||||
)
|
||||
|
||||
# train model
|
||||
|
|
Loading…
Reference in New Issue