diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py
index 564efdf1a0..f56612b91a 100644
--- a/pytorch_lightning/root_module/root_module.py
+++ b/pytorch_lightning/root_module/root_module.py
@@ -40,8 +40,7 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
 
         if self.on_gpu:
             print('running on gpu...')
-            self.dtype = torch.cuda.FloatTensor
-            torch.set_default_tensor_type('torch.cuda.FloatTensor')
+            torch.set_default_tensor_type(hparams.default_tensor_type)
 
     def forward(self, *args, **kwargs):
         """
diff --git a/pytorch_lightning/utils/arg_parse.py b/pytorch_lightning/utils/arg_parse.py
index 5e41656522..5cbeb4f126 100644
--- a/pytorch_lightning/utils/arg_parse.py
+++ b/pytorch_lightning/utils/arg_parse.py
@@ -49,6 +49,7 @@ def add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None
     parser.add_argument('--gpus', default='0', type=str)
     parser.add_argument('--single_run_gpu', dest='single_run_gpu', action='store_true')
     parser.add_argument('--disable_cuda', dest='disable_cuda', action='store_true')
+    parser.add_argument('--default_tensor_type', default='torch.cuda.FloatTensor', type=str)
 
     # run on hpc
     parser.add_argument('--on_cluster', dest='on_cluster', action='store_true')