Fix log_graph in TensorBoardLogger (#3092)

This commit is contained in:
Rohit Gupta 2020-08-22 16:05:09 +05:30 committed by GitHub
parent 478abd6b0f
commit 34c88d127b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 23 deletions

View File

@ -70,12 +70,14 @@ class TensorBoardLogger(LightningLoggerBase):
"""
NAME_HPARAMS_FILE = 'hparams.yaml'
def __init__(self,
save_dir: str,
name: Optional[str] = "default",
version: Optional[Union[int, str]] = None,
log_graph: bool = True,
**kwargs):
def __init__(
self,
save_dir: str,
name: Optional[str] = "default",
version: Optional[Union[int, str]] = None,
log_graph: bool = False,
**kwargs
):
super().__init__()
self._save_dir = save_dir
self._name = name or ''
@ -187,11 +189,8 @@ class TensorBoardLogger(LightningLoggerBase):
input_array = model.example_input_array
if input_array is not None:
self.experiment.add_graph(
model,
model.transfer_batch_to_device(
model.example_input_array, model.device)
)
input_array = model.transfer_batch_to_device(input_array, model.device)
self.experiment.add_graph(model, input_array)
else:
rank_zero_warn('Could not log computational graph since the'
' `model.example_input_array` attribute is not set'

View File

@ -73,15 +73,16 @@ class TestTubeLogger(LightningLoggerBase):
__test__ = False
def __init__(self,
save_dir: str,
name: str = "default",
description: Optional[str] = None,
debug: bool = False,
version: Optional[int] = None,
create_git_tag: bool = False,
log_graph=True):
def __init__(
self,
save_dir: str,
name: str = "default",
description: Optional[str] = None,
debug: bool = False,
version: Optional[int] = None,
create_git_tag: bool = False,
log_graph: bool = False
):
if not _TEST_TUBE_AVAILABLE:
raise ImportError('You want to use `test_tube` logger which is not installed yet,'
' install it with `pip install test-tube`.')

View File

@ -164,9 +164,9 @@ def test_tensorboard_log_graph(tmpdir, example_input_array):
if array is passed externaly
"""
model = EvalModelTemplate()
if example_input_array is None:
if example_input_array is not None:
model.example_input_array = None
logger = TensorBoardLogger(tmpdir)
logger = TensorBoardLogger(tmpdir, log_graph=True)
logger.log_graph(model, example_input_array)
@ -174,7 +174,7 @@ def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir):
""" test that log graph throws warning if model.example_input_array is None """
model = EvalModelTemplate()
model.example_input_array = None
logger = TensorBoardLogger(tmpdir)
logger = TensorBoardLogger(tmpdir, log_graph=True)
with pytest.warns(
UserWarning,
match='Could not log computational graph since the `model.example_input_array`'