Makes automatic optimization a model attribute (#4602)

* Makes automatic optimization a model attribute

* Update trainer.py

* remove setting property in model

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update pytorch_lightning/trainer/trainer.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update trainer.py

Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Jeff Yang <ydcjeff@outlook.com>
Authored by Justus Schock on 2020-11-14 05:43:42 +01:00, committed by GitHub
parent 144a5c9913
commit e04e7c9ecc
3 changed files with 22 additions and 2 deletions

pytorch_lightning/core/lightning.py

@@ -164,6 +164,13 @@ class LightningModule(
        """
        return self.device.type == "cuda"

    @property
    def automatic_optimization(self) -> bool:
        """
        If False you are responsible for calling .backward, .step, zero_grad.
        """
        return True

    def print(self, *args, **kwargs) -> None:
        r"""
        Prints only from process 0. Use this in any distributed mode to log only once.

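For illustration only (not part of the commit): with this property in place, a user module can opt out of automatic optimization by overriding it, roughly as sketched below. The model, its layers, and the use of `self.optimizers()` / `self.manual_backward()` are assumptions based on the manual-optimization helpers available around the 1.1 release.

```python
import torch
import pytorch_lightning as pl


class ManualOptimModel(pl.LightningModule):
    """Illustrative module that overrides the new property to disable automatic optimization."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    @property
    def automatic_optimization(self) -> bool:
        # Override the property instead of passing the (now deprecated) Trainer flag.
        return False

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)

        # With automatic optimization disabled, backward/step/zero_grad are our job.
        opt = self.optimizers()
        self.manual_backward(loss, opt)
        opt.step()
        opt.zero_grad()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```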
pytorch_lightning/trainer/connectors/model_connector.py

@@ -35,6 +35,9 @@ class ModelConnector:
        else:
            ref_model = model

        automatic_optimization = ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization
        self.trainer.train_loop.automatic_optimization = automatic_optimization

        for m in [model, ref_model]:
            m.trainer = self.trainer
            m.logger = self.trainer.logger

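The connector ANDs the two settings, so automatic optimization stays enabled only if neither the model property nor the (deprecated) Trainer flag turns it off. A standalone sketch of that resolution rule, with hypothetical names:

```python
def resolve_automatic_optimization(model_property: bool, trainer_flag: bool) -> bool:
    # Mirrors `ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization`:
    # either side returning False disables automatic optimization for the run.
    return model_property and trainer_flag


assert resolve_automatic_optimization(True, True) is True     # default: automatic
assert resolve_automatic_optimization(False, True) is False   # model opts out
assert resolve_automatic_optimization(True, False) is False   # legacy Trainer flag opts out
```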
pytorch_lightning/trainer/trainer.py

@@ -137,7 +137,7 @@
        amp_backend: str = 'native',
        amp_level: str = 'O2',
        distributed_backend: Optional[str] = None,
        automatic_optimization: bool = True,
        automatic_optimization: Optional[bool] = None,
        move_metrics_to_cpu: bool = False,
    ):
        r"""
@@ -212,7 +212,9 @@
            log_every_n_steps: How often to log within steps (defaults to every 50 steps).

            automatic_optimization: If False you are responsible for calling .backward, .step, zero_grad.
                Meant to be used with multiple optimizers by advanced users.
                If False you are responsible for calling .backward, .step, zero_grad in LightningModule.
                This argument has been moved to LightningModule. It is deprecated here in v1.1 and
                will be removed in v1.3.

            prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
                Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data
@@ -355,6 +357,14 @@
        )

        # init train loop related flags
        # TODO: deprecate in 1.2.0
        if automatic_optimization is None:
            automatic_optimization = True
        else:
            rank_zero_warn(
                "Disabling automatic optimization with the trainer flag is deprecated and will be removed in v1.3.0! "
                "Please use the property on the LightningModule for disabling automatic optimization."
            )
        self.train_loop.on_trainer_init(
            max_epochs,
            min_epochs,
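
Taken together, the two ways of disabling automatic optimization look roughly as follows during the deprecation window; this is a usage sketch and assumes nothing beyond the flag and property shown in this diff:

```python
import pytorch_lightning as pl

# Deprecated in v1.1 and slated for removal in v1.3: the Trainer flag still
# works, but now triggers the rank_zero_warn deprecation message shown above.
trainer = pl.Trainer(automatic_optimization=False)

# Recommended after this commit: leave the Trainer argument at its default
# (None) and override the `automatic_optimization` property on the
# LightningModule instead.
trainer = pl.Trainer()
```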