Deprecate `AbstractProfiler` in favor of `BaseProfiler` (#12106)
parent 0b682b807a
commit eff67d7a02
@@ -454,6 +454,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated passing `weights_save_path` to the `Trainer` constructor in favor of adding the `ModelCheckpoint` callback with `dirpath` directly to the list of callbacks ([#12084](https://github.com/PyTorchLightning/pytorch-lightning/pull/12084))
 
 
+- Deprecated `pytorch_lightning.profiler.AbstractProfiler` in favor of `pytorch_lightning.profiler.BaseProfiler` ([#12106](https://github.com/PyTorchLightning/pytorch-lightning/pull/12106))
+
+
 - Deprecated `BaseProfiler.profile_iterable` ([#12102](https://github.com/PyTorchLightning/pytorch-lightning/pull/12102))
 
 
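As a migration aid for the entry above: custom profilers should now subclass `pytorch_lightning.profiler.BaseProfiler` rather than the deprecated `AbstractProfiler`. A minimal sketch, assuming `BaseProfiler` can be constructed with its default arguments; the `SimpleTimerProfiler` name and its timing logic are illustrative only and not part of this commit:

```python
import time
from typing import Dict

from pytorch_lightning.profiler import BaseProfiler


class SimpleTimerProfiler(BaseProfiler):
    """Hypothetical custom profiler that records wall-clock time per action."""

    def __init__(self) -> None:
        super().__init__()
        self._starts: Dict[str, float] = {}
        self._totals: Dict[str, float] = {}

    def start(self, action_name: str) -> None:
        # Required override: remember when this action began.
        self._starts[action_name] = time.monotonic()

    def stop(self, action_name: str) -> None:
        # Required override: accumulate the elapsed time for this action.
        begin = self._starts.pop(action_name, None)
        if begin is not None:
            self._totals[action_name] = self._totals.get(action_name, 0.0) + time.monotonic() - begin

    def summary(self) -> str:
        # Optional: BaseProfiler now provides a default summary() returning "".
        return "\n".join(f"{name}: {seconds:.3f}s" for name, seconds in self._totals.items())
```

Passed to the trainer as `Trainer(profiler=SimpleTimerProfiler())`, such a subclass behaves like the built-in profilers; overriding `summary()` is no longer mandatory.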
@@ -26,7 +26,14 @@ log = logging.getLogger(__name__)
 
 
 class AbstractProfiler(ABC):
-    """Specification of a profiler."""
+    """Specification of a profiler.
+
+    See deprecation warning below
+
+    .. deprecated:: v1.6
+        `AbstractProfiler` was deprecated in v1.6 and will be removed in v1.8.
+        Please use `BaseProfiler` instead
+    """
 
     @abstractmethod
     def start(self, action_name: str) -> None:
@@ -49,7 +56,7 @@ class AbstractProfiler(ABC):
         """Execute arbitrary post-profiling tear-down steps as defined by subclass."""
 
 
-class BaseProfiler(AbstractProfiler):
+class BaseProfiler(ABC):
     """If you wish to write a custom profiler, you should inherit from this class."""
 
     def __init__(
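A practical consequence of the hunk above: `BaseProfiler` now derives from `ABC` directly, so the deprecated class is no longer an ancestor of the built-in profilers, and user code that type-checks against `AbstractProfiler` should switch to `BaseProfiler`. A small illustrative check, assuming `SimpleProfiler` can be constructed with defaults:

```python
from pytorch_lightning.profiler import AbstractProfiler, BaseProfiler, SimpleProfiler

profiler = SimpleProfiler()
assert isinstance(profiler, BaseProfiler)
# AbstractProfiler has been removed from the inheritance chain of BaseProfiler,
# so isinstance checks against the deprecated class no longer match.
assert not isinstance(profiler, AbstractProfiler)
```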
@@ -65,6 +72,17 @@ class BaseProfiler(AbstractProfiler):
         self._local_rank: Optional[int] = None
         self._stage: Optional[str] = None
 
+    @abstractmethod
+    def start(self, action_name: str) -> None:
+        """Defines how to start recording an action."""
+
+    @abstractmethod
+    def stop(self, action_name: str) -> None:
+        """Defines how to record the duration once an action is complete."""
+
+    def summary(self) -> str:
+        return ""
+
     @contextmanager
     def profile(self, action_name: str) -> Generator:
         """Yields a context manager to encapsulate the scope of a profiled action.
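The abstract `start`/`stop` pair introduced above is what the inherited `profile` context manager drives: it calls `start(action_name)` on entry and `stop(action_name)` on exit. A short usage sketch with the built-in `SimpleProfiler`; the action name is arbitrary:

```python
from pytorch_lightning.profiler import SimpleProfiler

profiler = SimpleProfiler()
with profiler.profile("expensive_sum"):  # start("expensive_sum") runs on entry
    total = sum(i * i for i in range(1_000_000))
# stop("expensive_sum") has run on exiting the block
print(profiler.summary())  # SimpleProfiler reports the recorded durations
```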
@@ -180,15 +198,6 @@ class BaseProfiler(AbstractProfiler):
     def __del__(self) -> None:
         self.teardown(stage=self._stage)
 
-    def start(self, action_name: str) -> None:
-        raise NotImplementedError
-
-    def stop(self, action_name: str) -> None:
-        raise NotImplementedError
-
-    def summary(self) -> str:
-        raise NotImplementedError
-
     @property
     def local_rank(self) -> int:
         return 0 if self._local_rank is None else self._local_rank
@@ -205,6 +214,3 @@ class PassThroughProfiler(BaseProfiler):
 
     def stop(self, action_name: str) -> None:
         pass
-
-    def summary(self) -> str:
-        return ""
@@ -78,6 +78,3 @@ class XLAProfiler(BaseProfiler):
         else:
             self._step_recoding_map[action_name] += 1
         return self._step_recoding_map[action_name]
-
-    def summary(self) -> str:
-        return ""
@@ -35,7 +35,7 @@ from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnSharde
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
-from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler
+from pytorch_lightning.profiler import AbstractProfiler, AdvancedProfiler, SimpleProfiler
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.enums import DeviceType, DistributedType
@@ -705,3 +705,7 @@ def test_v1_8_0_precision_plugin_checkpoint_hooks(tmpdir):
         " v1.6 and will be removed in v1.8. Use `load_state_dict` instead."
     ):
         trainer.fit(model)
+
+
+def test_v1_8_0_abstract_profiler():
+    assert "`AbstractProfiler` was deprecated in v1.6" in AbstractProfiler.__doc__
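For reference, the assertion in the new test can be reproduced in a plain interpreter; `AbstractProfiler` remains importable until its scheduled removal in v1.8. A minimal check, assuming a PyTorch Lightning build that contains this commit:

```python
from pytorch_lightning.profiler import AbstractProfiler, BaseProfiler

# The deprecation notice is carried in the docstring added by this commit.
assert "`AbstractProfiler` was deprecated in v1.6" in AbstractProfiler.__doc__
# New code should target BaseProfiler, which no longer inherits from AbstractProfiler.
assert not issubclass(BaseProfiler, AbstractProfiler)
```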