Fix log_graph in TensorBoardLogger (#3092)
This commit is contained in:
parent
478abd6b0f
commit
34c88d127b
|
@ -70,12 +70,14 @@ class TensorBoardLogger(LightningLoggerBase):
|
|||
"""
|
||||
NAME_HPARAMS_FILE = 'hparams.yaml'
|
||||
|
||||
def __init__(self,
|
||||
save_dir: str,
|
||||
name: Optional[str] = "default",
|
||||
version: Optional[Union[int, str]] = None,
|
||||
log_graph: bool = True,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: str,
|
||||
name: Optional[str] = "default",
|
||||
version: Optional[Union[int, str]] = None,
|
||||
log_graph: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self._save_dir = save_dir
|
||||
self._name = name or ''
|
||||
|
@ -187,11 +189,8 @@ class TensorBoardLogger(LightningLoggerBase):
|
|||
input_array = model.example_input_array
|
||||
|
||||
if input_array is not None:
|
||||
self.experiment.add_graph(
|
||||
model,
|
||||
model.transfer_batch_to_device(
|
||||
model.example_input_array, model.device)
|
||||
)
|
||||
input_array = model.transfer_batch_to_device(input_array, model.device)
|
||||
self.experiment.add_graph(model, input_array)
|
||||
else:
|
||||
rank_zero_warn('Could not log computational graph since the'
|
||||
' `model.example_input_array` attribute is not set'
|
||||
|
|
|
@ -73,15 +73,16 @@ class TestTubeLogger(LightningLoggerBase):
|
|||
|
||||
__test__ = False
|
||||
|
||||
def __init__(self,
|
||||
save_dir: str,
|
||||
name: str = "default",
|
||||
description: Optional[str] = None,
|
||||
debug: bool = False,
|
||||
version: Optional[int] = None,
|
||||
create_git_tag: bool = False,
|
||||
log_graph=True):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: str,
|
||||
name: str = "default",
|
||||
description: Optional[str] = None,
|
||||
debug: bool = False,
|
||||
version: Optional[int] = None,
|
||||
create_git_tag: bool = False,
|
||||
log_graph: bool = False
|
||||
):
|
||||
if not _TEST_TUBE_AVAILABLE:
|
||||
raise ImportError('You want to use `test_tube` logger which is not installed yet,'
|
||||
' install it with `pip install test-tube`.')
|
||||
|
|
|
@ -164,9 +164,9 @@ def test_tensorboard_log_graph(tmpdir, example_input_array):
|
|||
if array is passed externaly
|
||||
"""
|
||||
model = EvalModelTemplate()
|
||||
if example_input_array is None:
|
||||
if example_input_array is not None:
|
||||
model.example_input_array = None
|
||||
logger = TensorBoardLogger(tmpdir)
|
||||
logger = TensorBoardLogger(tmpdir, log_graph=True)
|
||||
logger.log_graph(model, example_input_array)
|
||||
|
||||
|
||||
|
@ -174,7 +174,7 @@ def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir):
|
|||
""" test that log graph throws warning if model.example_input_array is None """
|
||||
model = EvalModelTemplate()
|
||||
model.example_input_array = None
|
||||
logger = TensorBoardLogger(tmpdir)
|
||||
logger = TensorBoardLogger(tmpdir, log_graph=True)
|
||||
with pytest.warns(
|
||||
UserWarning,
|
||||
match='Could not log computational graph since the `model.example_input_array`'
|
||||
|
|
Loading…
Reference in New Issue