From e3425ec6a0ef8241f4e9ebb4a701db40bf0222fb Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 13 May 2019 19:30:06 -0400 Subject: [PATCH] added option to change default tensor --- pytorch_lightning/root_module/root_module.py | 3 +-- pytorch_lightning/utils/arg_parse.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py index 564efdf1a0..f56612b91a 100644 --- a/pytorch_lightning/root_module/root_module.py +++ b/pytorch_lightning/root_module/root_module.py @@ -40,8 +40,7 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks): if self.on_gpu: print('running on gpu...') - self.dtype = torch.cuda.FloatTensor - torch.set_default_tensor_type('torch.cuda.FloatTensor') + torch.set_default_tensor_type(hparams.default_tensor_type) def forward(self, *args, **kwargs): """ diff --git a/pytorch_lightning/utils/arg_parse.py b/pytorch_lightning/utils/arg_parse.py index 5e41656522..5cbeb4f126 100644 --- a/pytorch_lightning/utils/arg_parse.py +++ b/pytorch_lightning/utils/arg_parse.py @@ -49,6 +49,7 @@ def add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None parser.add_argument('--gpus', default='0', type=str) parser.add_argument('--single_run_gpu', dest='single_run_gpu', action='store_true') parser.add_argument('--disable_cuda', dest='disable_cuda', action='store_true') + parser.add_argument('--default_tensor_type', default='torch.cuda.FloatTensor', type=str) # run on hpc parser.add_argument('--on_cluster', dest='on_cluster', action='store_true')