feature: Allow str arguments in Trainer.profiler (#3656)
* allow trainer's profiler param to have a str value
* add tests
* update docs
* update exception message
* Update CHANGELOG
* fix pep8 issues
* cleanup test code

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Add deprecation warning if using bool for profiler
* Add deprecation tests and move deprecated tests
* Remove bool option to profiler from docs
* Deprecate bool args to profiler in CHANGELOG
* fixup! Add deprecation warning if using bool for profiler
* fixup! Add deprecation tests and move deprecated tests
* Apply suggestions from code review

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Implement suggestions, remove whitespace
* fixup! Implement suggestions, remove whitespace
* Allow bool, str (case insensitive), BaseProfiler
* Add info about bool deprecation to trainer
* fixup! Add info about bool deprecation to trainer
* Move deprecate todo to test_deprecated
* Test wrong profiler type, improve error message
* fixup! Test wrong profiler type, improve error message
* Update pytorch_lightning/trainer/connectors/profiler_connector.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Apply suggestions from code review
* Readd bool to profiler types, test cli profiler arg
* Remove extra whitespace in doc

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update deprecation versions

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
48b6de0c40
commit
c50c225f05
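In short, after this commit `Trainer`'s `profiler` parameter accepts a string alongside the existing `BaseProfiler` instance, and the bool form is deprecated. A minimal sketch of the resulting API, mirroring the diff below:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import SimpleProfiler

# new in this commit: case-insensitive string values
trainer = Trainer(profiler="simple")    # resolves to a SimpleProfiler
trainer = Trainer(profiler="advanced")  # resolves to an AdvancedProfiler

# still supported: a BaseProfiler instance, or None for no profiling
trainer = Trainer(profiler=SimpleProfiler())

# deprecated by this commit, slated for removal in v1.3
trainer = Trainer(profiler=True)  # emits a DeprecationWarning
```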
@@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added autogenerated helptext to `Trainer.add_argparse_args`. ([#4344](https://github.com/PyTorchLightning/pytorch-lightning/pull/4344))

+- Added support for string values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656))
+
 ### Changed

@@ -48,6 +51,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Deprecated `reorder` parameter of the `auc` metric ([#4237](https://github.com/PyTorchLightning/pytorch-lightning/pull/4237))

+- Deprecated bool values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656))
+
 ### Removed

@@ -22,12 +22,12 @@ PyTorch Lightning supports profiling standard actions in the training loop out o
 Enable simple profiling
 -----------------------

-If you only wish to profile the standard actions, you can set `profiler=True` when constructing
-your `Trainer` object.
+If you only wish to profile the standard actions, you can set `profiler="simple"`
+when constructing your `Trainer` object.

 .. code-block:: python

-    trainer = Trainer(..., profiler=True)
+    trainer = Trainer(..., profiler="simple")

 The profiler's results will be printed at the completion of a training `fit()`.
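For context, a runnable sketch of the documented behaviour; `model` here is an assumed, user-defined `LightningModule`, not something provided by this diff:

```python
from pytorch_lightning import Trainer

# minimal sketch: the profiler summary is printed once fit() completes
trainer = Trainer(max_epochs=1, profiler="simple")
trainer.fit(model)  # `model` is an assumed LightningModule defined elsewhere
```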
@@ -59,6 +59,10 @@ This option uses Python's cProfiler_ to provide a report of time spent on *each*

 .. code-block:: python

+    trainer = Trainer(..., profiler="advanced")
+
+or
+
     profiler = AdvancedProfiler()
     trainer = Trainer(..., profiler=profiler)
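The instance form remains useful when the profiler needs configuration. A sketch under the PL 1.x-era API, where `output_filename` and `line_count_restriction` are assumed `AdvancedProfiler` parameters rather than anything shown in this diff:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import AdvancedProfiler

# write the cProfile report to a file and keep only half of the report lines;
# both parameter names are assumptions about the contemporary API
profiler = AdvancedProfiler(output_filename="advanced_profile.txt", line_count_restriction=0.5)
trainer = Trainer(profiler=profiler)
```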
@@ -1199,14 +1199,11 @@ See the :ref:`profiler documentation <profiler>`. for more details.

     # default used by the Trainer
     trainer = Trainer(profiler=None)

-    # to profile standard training events
-    trainer = Trainer(profiler=True)
+    # to profile standard training events, equivalent to `profiler=SimpleProfiler()`
+    trainer = Trainer(profiler="simple")

-    # equivalent to profiler=True
-    trainer = Trainer(profiler=SimpleProfiler())
-
-    # advanced profiler for function-level stats
-    trainer = Trainer(profiler=AdvancedProfiler())
+    # advanced profiler for function-level stats, equivalent to `profiler=AdvancedProfiler()`
+    trainer = Trainer(profiler="advanced")

 progress_bar_refresh_rate
 ^^^^^^^^^^^^^^^^^^^^^^^^^
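One case the example list above leaves implicit: the default `profiler=None` resolves to a `PassThroughProfiler` (see the connector change below). A quick sketch:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import PassThroughProfiler

# with no profiler requested, the connector falls back to a no-op profiler
trainer = Trainer()  # equivalent to Trainer(profiler=None)
assert isinstance(trainer.profiler, PassThroughProfiler)
```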
@@ -11,7 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License
-from pytorch_lightning.profiler import PassThroughProfiler, SimpleProfiler
+
+from typing import Union
+
+from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler, SimpleProfiler, AdvancedProfiler
+from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException


 class ProfilerConnector:
@@ -19,8 +24,27 @@ class ProfilerConnector:

     def __init__(self, trainer):
         self.trainer = trainer

-    def on_trainer_init(self, profiler):
-        # configure profiler
-        if profiler is True:
-            profiler = SimpleProfiler()
+    def on_trainer_init(self, profiler: Union[BaseProfiler, bool, str]):
+
+        if profiler and not isinstance(profiler, (bool, str, BaseProfiler)):
+            # TODO: Update exception on removal of bool
+            raise MisconfigurationException("Only None, bool, str and subclasses of `BaseProfiler` "
+                                            "are valid values for `Trainer`'s `profiler` parameter. "
+                                            f"Received {profiler} which is of type {type(profiler)}.")
+
+        if isinstance(profiler, bool):
+            rank_zero_warn("Passing a bool value as a `profiler` argument to `Trainer` is deprecated"
+                           " and will be removed in v1.3. Use str ('simple' or 'advanced') instead.",
+                           DeprecationWarning)
+            if profiler:
+                profiler = SimpleProfiler()
+        elif isinstance(profiler, str):
+            profiler = profiler.lower()
+            if profiler == "simple":
+                profiler = SimpleProfiler()
+            elif profiler == "advanced":
+                profiler = AdvancedProfiler()
+            else:
+                raise ValueError("When passing string value for the `profiler` parameter of"
+                                 " `Trainer`, it can only be 'simple' or 'advanced'")
         self.trainer.profiler = profiler or PassThroughProfiler()
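Taken together, the connector resolves the argument as sketched below; the behaviour is read straight off the diff above:

```python
import warnings

from pytorch_lightning import Trainer
from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler

# strings are lowercased first, so matching is case-insensitive
assert isinstance(Trainer(profiler="Simple").profiler, SimpleProfiler)
assert isinstance(Trainer(profiler="ADVANCED").profiler, AdvancedProfiler)

# bool still works but warns about the v1.3 removal
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    Trainer(profiler=True)
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# an unrecognised string fails fast
try:
    Trainer(profiler="fast")
except ValueError as err:
    print(err)  # "... it can only be 'simple' or 'advanced'"
```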
@@ -120,7 +120,7 @@ class Trainer(
         num_sanity_val_steps: int = 2,
         truncated_bptt_steps: Optional[int] = None,
         resume_from_checkpoint: Optional[str] = None,
-        profiler: Optional[Union[BaseProfiler, bool]] = None,
+        profiler: Optional[Union[BaseProfiler, bool, str]] = None,
         benchmark: bool = False,
         deterministic: bool = False,
         reload_dataloaders_every_epoch: bool = False,
@@ -212,7 +212,8 @@ class Trainer(
         progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
             Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`.

-        profiler: To profile individual steps during training and assist in identifying bottlenecks.
+        profiler: To profile individual steps during training and assist in identifying bottlenecks. Passing bool
+            value is deprecated in v1.1 and will be removed in v1.3.

         overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0
@@ -174,8 +174,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
         # if the only arg type is bool
         if len(arg_types) == 1:
             use_type = parsing.str_to_bool
-        # if only two args (str, bool)
-        elif len(arg_types) == 2 and set(arg_types) == {str, bool}:
+        elif str in arg_types:
             use_type = parsing.str_to_bool_or_str
         else:
             # filter out the bool as we need to use more general
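The branch above routes any argument whose accepted types include `str` through `parsing.str_to_bool_or_str`, so `--profiler True` still parses as a bool while a value like `simple` would come through as a string. The helper itself is not shown in this diff; a sketch of the behaviour it needs:

```python
from typing import Union


def str_to_bool_or_str(val: str) -> Union[bool, str]:
    """A sketch of the helper assumed by the diff above: map common truthy/falsy
    spellings to bool and pass any other string through unchanged."""
    lower = val.lower()
    if lower in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if lower in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    return val


assert str_to_bool_or_str("True") is True
assert str_to_bool_or_str("0") is False
assert str_to_bool_or_str("simple") == "simple"
```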
@@ -1,13 +1,17 @@
 """Test deprecated functionality which will be removed in vX.Y.Z"""
+from argparse import ArgumentParser
 import pytest
 import sys
+from unittest import mock

 import torch

 from tests.base import EvalModelTemplate
 from pytorch_lightning.metrics.functional.classification import auc

 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -22,6 +26,37 @@ def test_tbd_remove_in_v1_2_0():
         checkpoint_cb = ModelCheckpoint(filepath='.', dirpath='.')


+# TODO: remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py
+@pytest.mark.parametrize(['profiler', 'expected'], [
+    (True, SimpleProfiler),
+    (False, PassThroughProfiler),
+])
+def test_trainer_profiler_remove_in_v1_3_0(profiler, expected):
+    with pytest.deprecated_call(match='will be removed in v1.3'):
+        trainer = Trainer(profiler=profiler)
+        assert isinstance(trainer.profiler, expected)
+
+
+@pytest.mark.parametrize(
+    ['cli_args', 'expected_parsed_arg', 'expected_profiler'],
+    [
+        ('--profiler', True, SimpleProfiler),
+        ('--profiler True', True, SimpleProfiler),
+        ('--profiler False', False, PassThroughProfiler),
+    ],
+)
+def test_trainer_cli_profiler_remove_in_v1_3_0(cli_args, expected_parsed_arg, expected_profiler):
+    cli_args = cli_args.split(' ')
+    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
+        parser = ArgumentParser(add_help=False)
+        parser = Trainer.add_argparse_args(parent_parser=parser)
+        args = Trainer.parse_argparser(parser)
+
+    assert getattr(args, "profiler") == expected_parsed_arg
+    trainer = Trainer.from_argparse_args(args)
+    assert isinstance(trainer.profiler, expected_profiler)
+
+
 def _soft_unimport_module(str_module):
     # once the module is imported e.g with parsing with pytest it lives in memory
     if str_module in sys.modules:
@@ -32,6 +32,7 @@ from pytorch_lightning import Callback, LightningModule, Trainer
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
 from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.profiler.profilers import AdvancedProfiler, PassThroughProfiler, SimpleProfiler
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -1408,3 +1409,32 @@ def test_log_every_n_steps(log_metrics_mock, tmpdir, train_batches, max_steps, l
     trainer.fit(model)
     expected_calls = [call(metrics=ANY, step=s) for s in range(log_interval - 1, max_steps, log_interval)]
     log_metrics_mock.assert_has_calls(expected_calls)
+
+
+@pytest.mark.parametrize(['profiler', 'expected'], [
+    (None, PassThroughProfiler),
+    (SimpleProfiler(), SimpleProfiler),
+    (AdvancedProfiler(), AdvancedProfiler),
+    ('simple', SimpleProfiler),
+    ('Simple', SimpleProfiler),
+    ('advanced', AdvancedProfiler),
+])
+def test_trainer_profiler_correct_args(profiler, expected):
+    kwargs = {'profiler': profiler} if profiler is not None else {}
+    trainer = Trainer(**kwargs)
+    assert isinstance(trainer.profiler, expected)
+
+
+def test_trainer_profiler_incorrect_str_arg():
+    with pytest.raises(ValueError, match=r".*can only be 'simple' or 'advanced'"):
+        Trainer(profiler="unknown_profiler")
+
+
+@pytest.mark.parametrize('profiler', (
+    42, [42], {"a": 42}, torch.tensor(42), Trainer(),
+))
+def test_trainer_profiler_incorrect_arg_type(profiler):
+    with pytest.raises(MisconfigurationException,
+                       match=r"Only None, bool, str and subclasses of `BaseProfiler` "
+                             r"are valid values for `Trainer`'s `profiler` parameter. *"):
+        Trainer(profiler=profiler)