Fix tests for new tensorboard >= 2.6 (#8789)
This commit is contained in:
parent
e7440d68c3
commit
346cef2c3c
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
import operator
|
||||||
import os
|
import os
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
@ -20,14 +21,19 @@ import pytest
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
|
|
||||||
|
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
from pytorch_lightning.loggers import TensorBoardLogger
|
from pytorch_lightning.loggers import TensorBoardLogger
|
||||||
|
from pytorch_lightning.utilities.imports import _compare_version
|
||||||
from tests.helpers import BoringModel
|
from tests.helpers import BoringModel
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
_compare_version("tensorboard", operator.ge, "2.6.0"), reason="cannot import EventAccumulator in >= 2.6.0"
|
||||||
|
)
|
||||||
def test_tensorboard_hparams_reload(tmpdir):
|
def test_tensorboard_hparams_reload(tmpdir):
|
||||||
|
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
|
||||||
|
|
||||||
class CustomModel(BoringModel):
|
class CustomModel(BoringModel):
|
||||||
def __init__(self, b1=0.5, b2=0.999):
|
def __init__(self, b1=0.5, b2=0.999):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue