remove default tensor
This commit is contained in:
parent
2eba85d02c
commit
cd11b7de98
|
@ -9,7 +9,6 @@ from pytorch_lightning.root_module.optimization import OptimizerConfig
|
||||||
from pytorch_lightning.root_module.hooks import ModelHooks
|
from pytorch_lightning.root_module.hooks import ModelHooks
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class LightningModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
|
class LightningModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
|
||||||
|
|
||||||
def __init__(self, hparams):
|
def __init__(self, hparams):
|
||||||
|
@ -40,10 +39,6 @@ class LightningModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
|
||||||
self._val_dataloader = None
|
self._val_dataloader = None
|
||||||
self._test_dataloader = None
|
self._test_dataloader = None
|
||||||
|
|
||||||
if self.on_gpu:
|
|
||||||
print('running on gpu...')
|
|
||||||
torch.set_default_tensor_type(hparams.default_tensor_type)
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Expand model in into whatever you need.
|
Expand model in into whatever you need.
|
||||||
|
|
Loading…
Reference in New Issue