added option to change default tensor

This commit is contained in:
William Falcon 2019-05-13 19:30:06 -04:00
parent 5a7ad19403
commit e3425ec6a0
2 changed files with 2 additions and 2 deletions

View File

@ -40,8 +40,7 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
if self.on_gpu: if self.on_gpu:
print('running on gpu...') print('running on gpu...')
self.dtype = torch.cuda.FloatTensor torch.set_default_tensor_type(hparams.default_tensor_type)
torch.set_default_tensor_type('torch.cuda.FloatTensor')
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
""" """

View File

@ -49,6 +49,7 @@ def add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None
parser.add_argument('--gpus', default='0', type=str) parser.add_argument('--gpus', default='0', type=str)
parser.add_argument('--single_run_gpu', dest='single_run_gpu', action='store_true') parser.add_argument('--single_run_gpu', dest='single_run_gpu', action='store_true')
parser.add_argument('--disable_cuda', dest='disable_cuda', action='store_true') parser.add_argument('--disable_cuda', dest='disable_cuda', action='store_true')
parser.add_argument('--default_tensor_type', default='torch.cuda.FloatTensor', type=str)
# run on hpc # run on hpc
parser.add_argument('--on_cluster', dest='on_cluster', action='store_true') parser.add_argument('--on_cluster', dest='on_cluster', action='store_true')