diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index b3cdbf7569..4d8cb4d64d 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -70,12 +70,14 @@ class TensorBoardLogger(LightningLoggerBase): """ NAME_HPARAMS_FILE = 'hparams.yaml' - def __init__(self, - save_dir: str, - name: Optional[str] = "default", - version: Optional[Union[int, str]] = None, - log_graph: bool = True, - **kwargs): + def __init__( + self, + save_dir: str, + name: Optional[str] = "default", + version: Optional[Union[int, str]] = None, + log_graph: bool = False, + **kwargs + ): super().__init__() self._save_dir = save_dir self._name = name or '' @@ -187,11 +189,8 @@ class TensorBoardLogger(LightningLoggerBase): input_array = model.example_input_array if input_array is not None: - self.experiment.add_graph( - model, - model.transfer_batch_to_device( - model.example_input_array, model.device) - ) + input_array = model.transfer_batch_to_device(input_array, model.device) + self.experiment.add_graph(model, input_array) else: rank_zero_warn('Could not log computational graph since the' ' `model.example_input_array` attribute is not set' diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py index 36928c8dbf..a1adacf654 100644 --- a/pytorch_lightning/loggers/test_tube.py +++ b/pytorch_lightning/loggers/test_tube.py @@ -73,15 +73,16 @@ class TestTubeLogger(LightningLoggerBase): __test__ = False - def __init__(self, - save_dir: str, - name: str = "default", - description: Optional[str] = None, - debug: bool = False, - version: Optional[int] = None, - create_git_tag: bool = False, - log_graph=True): - + def __init__( + self, + save_dir: str, + name: str = "default", + description: Optional[str] = None, + debug: bool = False, + version: Optional[int] = None, + create_git_tag: bool = False, + log_graph: bool = False + ): if not _TEST_TUBE_AVAILABLE: raise ImportError('You want to use `test_tube` logger which is not installed yet,' ' install it with `pip install test-tube`.') diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index 90a5c8df5b..7967e0c095 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -164,9 +164,9 @@ def test_tensorboard_log_graph(tmpdir, example_input_array): if array is passed externaly """ model = EvalModelTemplate() - if example_input_array is None: + if example_input_array is not None: model.example_input_array = None - logger = TensorBoardLogger(tmpdir) + logger = TensorBoardLogger(tmpdir, log_graph=True) logger.log_graph(model, example_input_array) @@ -174,7 +174,7 @@ def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir): """ test that log graph throws warning if model.example_input_array is None """ model = EvalModelTemplate() model.example_input_array = None - logger = TensorBoardLogger(tmpdir) + logger = TensorBoardLogger(tmpdir, log_graph=True) with pytest.warns( UserWarning, match='Could not log computational graph since the `model.example_input_array`'