Deprecate `AbstractProfiler` in favor of `BaseProfiler` (#12106)

This commit is contained in:
Akash Kwatra 2022-03-04 18:35:57 -08:00 committed by GitHub
parent 0b682b807a
commit eff67d7a02
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 28 additions and 18 deletions

View File

@ -454,6 +454,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated passing `weights_save_path` to the `Trainer` constructor in favor of adding the `ModelCheckpoint` callback with `dirpath` directly to the list of callbacks ([#12084](https://github.com/PyTorchLightning/pytorch-lightning/pull/12084)) - Deprecated passing `weights_save_path` to the `Trainer` constructor in favor of adding the `ModelCheckpoint` callback with `dirpath` directly to the list of callbacks ([#12084](https://github.com/PyTorchLightning/pytorch-lightning/pull/12084))
- Deprecated `pytorch_lightning.profiler.AbstractProfiler` in favor of `pytorch_lightning.profiler.BaseProfiler` ([#12106](https://github.com/PyTorchLightning/pytorch-lightning/pull/12106))
- Deprecated `BaseProfiler.profile_iterable` ([#12102](https://github.com/PyTorchLightning/pytorch-lightning/pull/12102)) - Deprecated `BaseProfiler.profile_iterable` ([#12102](https://github.com/PyTorchLightning/pytorch-lightning/pull/12102))

View File

@ -26,7 +26,14 @@ log = logging.getLogger(__name__)
class AbstractProfiler(ABC): class AbstractProfiler(ABC):
"""Specification of a profiler.""" """Specification of a profiler.
See deprecation warning below
.. deprecated:: v1.6
`AbstractProfiler` was deprecated in v1.6 and will be removed in v1.8.
Please use `BaseProfiler` instead
"""
@abstractmethod @abstractmethod
def start(self, action_name: str) -> None: def start(self, action_name: str) -> None:
@ -49,7 +56,7 @@ class AbstractProfiler(ABC):
"""Execute arbitrary post-profiling tear-down steps as defined by subclass.""" """Execute arbitrary post-profiling tear-down steps as defined by subclass."""
class BaseProfiler(AbstractProfiler): class BaseProfiler(ABC):
"""If you wish to write a custom profiler, you should inherit from this class.""" """If you wish to write a custom profiler, you should inherit from this class."""
def __init__( def __init__(
@ -65,6 +72,17 @@ class BaseProfiler(AbstractProfiler):
self._local_rank: Optional[int] = None self._local_rank: Optional[int] = None
self._stage: Optional[str] = None self._stage: Optional[str] = None
@abstractmethod
def start(self, action_name: str) -> None:
"""Defines how to start recording an action."""
@abstractmethod
def stop(self, action_name: str) -> None:
"""Defines how to record the duration once an action is complete."""
def summary(self) -> str:
return ""
@contextmanager @contextmanager
def profile(self, action_name: str) -> Generator: def profile(self, action_name: str) -> Generator:
"""Yields a context manager to encapsulate the scope of a profiled action. """Yields a context manager to encapsulate the scope of a profiled action.
@ -180,15 +198,6 @@ class BaseProfiler(AbstractProfiler):
def __del__(self) -> None: def __del__(self) -> None:
self.teardown(stage=self._stage) self.teardown(stage=self._stage)
def start(self, action_name: str) -> None:
raise NotImplementedError
def stop(self, action_name: str) -> None:
raise NotImplementedError
def summary(self) -> str:
raise NotImplementedError
@property @property
def local_rank(self) -> int: def local_rank(self) -> int:
return 0 if self._local_rank is None else self._local_rank return 0 if self._local_rank is None else self._local_rank
@ -205,6 +214,3 @@ class PassThroughProfiler(BaseProfiler):
def stop(self, action_name: str) -> None: def stop(self, action_name: str) -> None:
pass pass
def summary(self) -> str:
return ""

View File

@ -78,6 +78,3 @@ class XLAProfiler(BaseProfiler):
else: else:
self._step_recoding_map[action_name] += 1 self._step_recoding_map[action_name] += 1
return self._step_recoding_map[action_name] return self._step_recoding_map[action_name]
def summary(self) -> str:
return ""

View File

@ -35,7 +35,7 @@ from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnSharde
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler from pytorch_lightning.profiler import AbstractProfiler, AdvancedProfiler, SimpleProfiler
from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.enums import DeviceType, DistributedType from pytorch_lightning.utilities.enums import DeviceType, DistributedType
@ -705,3 +705,7 @@ def test_v1_8_0_precision_plugin_checkpoint_hooks(tmpdir):
" v1.6 and will be removed in v1.8. Use `load_state_dict` instead." " v1.6 and will be removed in v1.8. Use `load_state_dict` instead."
): ):
trainer.fit(model) trainer.fit(model)
def test_v1_8_0_abstract_profiler():
assert "`AbstractProfiler` was deprecated in v1.6" in AbstractProfiler.__doc__