Deprecate `BaseProfiler.profile_iterable` (#12102)
This commit is contained in:
parent
61dd5e4d5e
commit
f5304897ce
|
@ -430,6 +430,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Deprecated `LightningLoggerBase.agg_and_log_metrics` in favor of `LightningLoggerBase.log_metrics` ([#11832](https://github.com/PyTorchLightning/pytorch-lightning/pull/11832))
|
||||
|
||||
|
||||
- Deprecated `BaseProfiler.profile_iterable` ([#12102](https://github.com/PyTorchLightning/pytorch-lightning/pull/12102))
|
||||
|
||||
|
||||
### Removed
|
||||
|
||||
- Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))
|
||||
|
|
|
@ -20,6 +20,7 @@ from pathlib import Path
|
|||
from typing import Any, Callable, Dict, Generator, Iterable, Optional, TextIO, Union
|
||||
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -83,6 +84,14 @@ class BaseProfiler(AbstractProfiler):
|
|||
self.stop(action_name)
|
||||
|
||||
def profile_iterable(self, iterable: Iterable, action_name: str) -> Generator:
|
||||
"""Profiles over each value of an iterable.
|
||||
|
||||
See deprecation message below.
|
||||
|
||||
.. deprecated:: v1.6
|
||||
`BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8.
|
||||
"""
|
||||
rank_zero_deprecation("`BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8.")
|
||||
iterator = iter(iterable)
|
||||
while True:
|
||||
try:
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Test deprecated functionality which will be removed in v1.8.0."""
|
||||
import time
|
||||
from unittest.mock import Mock
|
||||
|
||||
import numpy as np
|
||||
|
@ -33,6 +34,7 @@ from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnSharde
|
|||
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
|
||||
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
|
||||
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
|
||||
from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler
|
||||
from pytorch_lightning.trainer.states import RunningStage
|
||||
from pytorch_lightning.utilities.apply_func import move_data_to_device
|
||||
from pytorch_lightning.utilities.enums import DeviceType, DistributedType
|
||||
|
@ -608,3 +610,45 @@ def test_v1_8_0_callback_on_pretrain_routine_start_end(tmpdir):
|
|||
match="The `Callback.on_pretrain_routine_end` hook has been deprecated in v1.6" " and will be removed in v1.8"
|
||||
):
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
@pytest.mark.flaky(reruns=3)
|
||||
@pytest.mark.parametrize(["action", "expected"], [("a", [3, 1]), ("b", [2]), ("c", [1])])
|
||||
def test_simple_profiler_iterable_durations(tmpdir, action: str, expected: list):
|
||||
"""Ensure the reported durations are reasonably accurate."""
|
||||
|
||||
def _sleep_generator(durations):
|
||||
"""the profile_iterable method needs an iterable in which we can ensure that we're properly timing how long
|
||||
it takes to call __next__"""
|
||||
for duration in durations:
|
||||
time.sleep(duration)
|
||||
yield duration
|
||||
|
||||
def _get_python_cprofile_total_duration(profile):
|
||||
return sum(x.inlinetime for x in profile.getstats())
|
||||
|
||||
simple_profiler = SimpleProfiler()
|
||||
iterable = _sleep_generator(expected)
|
||||
|
||||
with pytest.deprecated_call(
|
||||
match="`BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8."
|
||||
):
|
||||
for _ in simple_profiler.profile_iterable(iterable, action):
|
||||
pass
|
||||
|
||||
# we exclude the last item in the recorded durations since that's when StopIteration is raised
|
||||
np.testing.assert_allclose(simple_profiler.recorded_durations[action][:-1], expected, rtol=0.2)
|
||||
|
||||
advanced_profiler = AdvancedProfiler(dirpath=tmpdir, filename="profiler")
|
||||
|
||||
iterable = _sleep_generator(expected)
|
||||
|
||||
with pytest.deprecated_call(
|
||||
match="`BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8."
|
||||
):
|
||||
for _ in advanced_profiler.profile_iterable(iterable, action):
|
||||
pass
|
||||
|
||||
recorded_total_duration = _get_python_cprofile_total_duration(advanced_profiler.profiled_actions[action])
|
||||
expected_total_duration = np.sum(expected)
|
||||
np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2)
|
||||
|
|
|
@ -67,19 +67,6 @@ def test_simple_profiler_durations(simple_profiler, action: str, expected: list)
|
|||
np.testing.assert_allclose(simple_profiler.recorded_durations[action], expected, rtol=0.2)
|
||||
|
||||
|
||||
@pytest.mark.flaky(reruns=3)
|
||||
@pytest.mark.parametrize(["action", "expected"], [("a", [3, 1]), ("b", [2]), ("c", [1])])
|
||||
def test_simple_profiler_iterable_durations(simple_profiler, action: str, expected: list):
|
||||
"""Ensure the reported durations are reasonably accurate."""
|
||||
iterable = _sleep_generator(expected)
|
||||
|
||||
for _ in simple_profiler.profile_iterable(iterable, action):
|
||||
pass
|
||||
|
||||
# we exclude the last item in the recorded durations since that's when StopIteration is raised
|
||||
np.testing.assert_allclose(simple_profiler.recorded_durations[action][:-1], expected, rtol=0.2)
|
||||
|
||||
|
||||
def test_simple_profiler_overhead(simple_profiler, n_iter=5):
|
||||
"""Ensure that the profiler doesn't introduce too much overhead during training."""
|
||||
for _ in range(n_iter):
|
||||
|
@ -289,20 +276,6 @@ def test_advanced_profiler_durations(advanced_profiler, action: str, expected: l
|
|||
np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2)
|
||||
|
||||
|
||||
@pytest.mark.flaky(reruns=3)
|
||||
@pytest.mark.parametrize(["action", "expected"], [("a", [3, 1]), ("b", [2]), ("c", [1])])
|
||||
def test_advanced_profiler_iterable_durations(advanced_profiler, action: str, expected: list):
|
||||
"""Ensure the reported durations are reasonably accurate."""
|
||||
iterable = _sleep_generator(expected)
|
||||
|
||||
for _ in advanced_profiler.profile_iterable(iterable, action):
|
||||
pass
|
||||
|
||||
recorded_total_duration = _get_python_cprofile_total_duration(advanced_profiler.profiled_actions[action])
|
||||
expected_total_duration = np.sum(expected)
|
||||
np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2)
|
||||
|
||||
|
||||
@pytest.mark.flaky(reruns=3)
|
||||
def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
|
||||
"""ensure that the profiler doesn't introduce too much overhead during training."""
|
||||
|
|
Loading…
Reference in New Issue