From d4ca295762e08ecbaecdf6d56fa535c7a3cc452f Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Tue, 25 Jun 2019 18:51:41 -0400
Subject: [PATCH] updated args

---
 docs/source/examples/fully_featured_trainer.py | 1 +
 pytorch_lightning/models/trainer.py            | 8 ++++----
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/docs/source/examples/fully_featured_trainer.py b/docs/source/examples/fully_featured_trainer.py
index 383e44eed6..ce213aa201 100644
--- a/docs/source/examples/fully_featured_trainer.py
+++ b/docs/source/examples/fully_featured_trainer.py
@@ -98,6 +98,7 @@ def main(hparams, cluster, results_dict):
         cluster=cluster,
         checkpoint_callback=checkpoint,
         early_stop_callback=early_stop,
+        gpus=hparams.gpus
     )
 
     # train model
diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 56b5f43adf..6a622699a7 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -22,7 +22,7 @@ class Trainer(TrainerIO):
                  cluster=None,
                  process_position=0,
                  current_gpu_name=0,
-                 on_gpu=False,
+                 gpus=None,
                  enable_tqdm=True,
                  overfit_pct=0.0,
                  track_grad_norm=-1,
@@ -43,7 +43,7 @@ class Trainer(TrainerIO):
         self.enable_early_stop = enable_early_stop
         self.track_grad_norm = track_grad_norm
         self.fast_dev_run = fast_dev_run
-        self.on_gpu = on_gpu
+        self.on_gpu = gpus is not None and torch.cuda.is_available()
         self.enable_tqdm = enable_tqdm
         self.experiment = experiment
         self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
@@ -63,8 +63,8 @@ class Trainer(TrainerIO):
         self.lr_schedulers = []
         self.amp_level = amp_level
         self.check_grad_nans = check_grad_nans
-        self.data_parallel_device_ids = [0]
-        self.data_parallel = False
+        self.data_parallel_device_ids = gpus
+        self.data_parallel = gpus is not None and len(gpus) > 0
 
         # training state
         self.optimizers = None
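
Usage sketch (an assumed illustration, not part of the commit): the patch replaces the boolean on_gpu flag with an explicit gpus list of CUDA device ids, from which on_gpu and data_parallel are now derived. A minimal before/after call, assuming the import path from this tree, an already-configured experiment object (a test-tube Experiment in this revision), and a defined LightningModule instance named model:

    from pytorch_lightning.models.trainer import Trainer

    # Before this patch: a boolean flag that implicitly pinned training to device 0.
    # trainer = Trainer(experiment=experiment, on_gpu=True)

    # After this patch: pass the CUDA device ids to train on, or leave gpus=None for CPU.
    # self.on_gpu becomes True only when gpus is given and torch.cuda.is_available(),
    # and self.data_parallel is enabled whenever the gpus list is non-empty.
    trainer = Trainer(experiment=experiment, gpus=[0, 1])
    trainer.fit(model)

Note that the example script forwards hparams.gpus unchanged, so its argparse option must produce a list of ints (e.g. nargs='+', type=int) for the len(gpus) check in the constructor to work.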