update usage of deprecated profiler (#5010)
* drop deprecated profiler * lut Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
This commit is contained in:
parent
cdbddbe99f
commit
77fb425dd4
|
@ -237,7 +237,7 @@ def run_cli():
|
|||
help='seed for initializing training.')
|
||||
parser = ImageNetLightningModel.add_model_specific_args(parent_parser)
|
||||
parser.set_defaults(
|
||||
profiler=True,
|
||||
profiler="simple",
|
||||
deterministic=True,
|
||||
max_epochs=90,
|
||||
)
|
||||
|
|
|
@ -18,6 +18,11 @@ from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler, Simple
|
|||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
PROFILERS = {
|
||||
"simple": SimpleProfiler,
|
||||
"advanced": AdvancedProfiler,
|
||||
}
|
||||
|
||||
|
||||
class ProfilerConnector:
|
||||
|
||||
|
@ -28,9 +33,9 @@ class ProfilerConnector:
|
|||
|
||||
if profiler and not isinstance(profiler, (bool, str, BaseProfiler)):
|
||||
# TODO: Update exception on removal of bool
|
||||
raise MisconfigurationException("Only None, bool, str and subclasses of `BaseProfiler` "
|
||||
"are valid values for `Trainer`'s `profiler` parameter. "
|
||||
f"Received {profiler} which is of type {type(profiler)}.")
|
||||
raise MisconfigurationException("Only None, bool, str and subclasses of `BaseProfiler`"
|
||||
" are valid values for `Trainer`'s `profiler` parameter."
|
||||
f" Received {profiler} which is of type {type(profiler)}.")
|
||||
|
||||
if isinstance(profiler, bool):
|
||||
rank_zero_warn("Passing a bool value as a `profiler` argument to `Trainer` is deprecated"
|
||||
|
@ -39,11 +44,9 @@ class ProfilerConnector:
|
|||
if profiler:
|
||||
profiler = SimpleProfiler()
|
||||
elif isinstance(profiler, str):
|
||||
profiler = profiler.lower()
|
||||
if profiler == "simple":
|
||||
profiler = SimpleProfiler()
|
||||
elif profiler == "advanced":
|
||||
profiler = AdvancedProfiler()
|
||||
if profiler.lower() in PROFILERS:
|
||||
profiler_class = PROFILERS[profiler.lower()]
|
||||
profiler = profiler_class()
|
||||
else:
|
||||
raise ValueError("When passing string value for the `profiler` parameter of"
|
||||
" `Trainer`, it can only be 'simple' or 'advanced'")
|
||||
|
|
|
@ -1476,6 +1476,6 @@ def test_trainer_profiler_incorrect_str_arg():
|
|||
))
|
||||
def test_trainer_profiler_incorrect_arg_type(profiler):
|
||||
with pytest.raises(MisconfigurationException,
|
||||
match=r"Only None, bool, str and subclasses of `BaseProfiler` "
|
||||
r"are valid values for `Trainer`'s `profiler` parameter. *"):
|
||||
match=r"Only None, bool, str and subclasses of `BaseProfiler`"
|
||||
r" are valid values for `Trainer`'s `profiler` parameter. *"):
|
||||
Trainer(profiler=profiler)
|
||||
|
|
Loading…
Reference in New Issue