Don't import torch_xla.debug for torch-xla<1.8 (#10836)

Author: Kaushik B, 2021-12-06 12:01:38 +05:30 (committed by GitHub)
parent 3d59a2faff
commit 6599ced17d
3 changed files with 14 additions and 4 deletions

CHANGELOG.md

@@ -225,6 +225,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Disabled batch_size extraction for torchmetric instances because they accumulate the metrics internally ([#10815](https://github.com/PyTorchLightning/pytorch-lightning/pull/10815))
+- Fixed importing `torch_xla.debug` for `torch-xla<1.8` ([#10836](https://github.com/PyTorchLightning/pytorch-lightning/pull/10836))
 - Fixed an issue to return the results for each dataloader separately instead of duplicating them for each ([#10810](https://github.com/PyTorchLightning/pytorch-lightning/pull/10810))
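For readers of the changelog entry above, a condensed sketch of the guard this fix introduces (taken from the diffs below; the flag names come from `pytorch_lightning.utilities`): `torch_xla.debug.profiler` does not exist in `torch-xla<1.8`, so the import must be gated on both TPU availability and the torch version.

```python
# Condensed sketch of the fix described above; see the actual file diffs below.
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, _TPU_AVAILABLE

# `torch_xla.debug.profiler` only ships with torch-xla >= 1.8, so guard the import.
if _TPU_AVAILABLE and _TORCH_GREATER_EQUAL_1_8:
    import torch_xla.debug.profiler as xp
```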

pytorch_lightning/profiler/xla.py

@@ -42,9 +42,10 @@ import logging
 from typing import Dict

 from pytorch_lightning.profiler.base import BaseProfiler
-from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, _TPU_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

-if _TPU_AVAILABLE:
+if _TPU_AVAILABLE and _TORCH_GREATER_EQUAL_1_8:
     import torch_xla.debug.profiler as xp

 log = logging.getLogger(__name__)
@@ -65,6 +66,10 @@ class XLAProfiler(BaseProfiler):
     def __init__(self, port: int = 9012) -> None:
         """This Profiler will help you debug and optimize training workload performance for your models using Cloud
         TPU performance tools."""
+        if not _TPU_AVAILABLE:
+            raise MisconfigurationException("`XLAProfiler` is only supported on TPUs")
+        if not _TORCH_GREATER_EQUAL_1_8:
+            raise MisconfigurationException("`XLAProfiler` is only supported with `torch-xla >= 1.8`")
         super().__init__(dirpath=None, filename=None)
         self.port = port
         self._recording_map: Dict = {}
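For context, a minimal usage sketch of the profiler touched by this diff (not part of the commit): after this change, constructing `XLAProfiler` on a host without TPUs, or with `torch-xla<1.8`, raises `MisconfigurationException` up front instead of failing at import time. The Trainer arguments and `MyLightningModule` below are illustrative placeholders.

```python
# Hedged usage sketch; imports match the names shown in the diffs above.
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import XLAProfiler

# Raises MisconfigurationException unless running on a TPU host with torch-xla >= 1.8.
profiler = XLAProfiler(port=9012)

# Illustrative Trainer configuration; `MyLightningModule` is a placeholder model.
trainer = Trainer(tpu_cores=8, profiler=profiler)
# trainer.fit(MyLightningModule())
```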

tests/profiler/test_xla_profiler.py

@@ -18,14 +18,16 @@ import pytest

 from pytorch_lightning import Trainer
 from pytorch_lightning.profiler import XLAProfiler
-from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, _TPU_AVAILABLE
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf

 if _TPU_AVAILABLE:
-    import torch_xla.debug.profiler as xp
     import torch_xla.utils.utils as xu

+    if _TORCH_GREATER_EQUAL_1_8:
+        import torch_xla.debug.profiler as xp
+

 @RunIf(tpu=True)
 def test_xla_profiler_instance(tmpdir):