Makes automatic optimization a model attribute (#4602)
* Makes automatic optimization a model attribute * Update trainer.py * remove setting property in model * Update pytorch_lightning/core/lightning.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/trainer/trainer.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update trainer.py Co-authored-by: Sean Naren <sean.narenthiran@gmail.com> Co-authored-by: Roger Shieh <sh.rog@protonmail.ch> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Jeff Yang <ydcjeff@outlook.com>
This commit is contained in:
parent
144a5c9913
commit
e04e7c9ecc
|
@ -164,6 +164,13 @@ class LightningModule(
|
|||
"""
|
||||
return self.device.type == "cuda"
|
||||
|
||||
@property
|
||||
def automatic_optimization(self) -> bool:
|
||||
"""
|
||||
If False you are responsible for calling .backward, .step, zero_grad.
|
||||
"""
|
||||
return True
|
||||
|
||||
def print(self, *args, **kwargs) -> None:
|
||||
r"""
|
||||
Prints only from process 0. Use this in any distributed mode to log only once.
|
||||
|
|
|
@ -35,6 +35,9 @@ class ModelConnector:
|
|||
else:
|
||||
ref_model = model
|
||||
|
||||
automatic_optimization = ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization
|
||||
self.trainer.train_loop.automatic_optimization = automatic_optimization
|
||||
|
||||
for m in [model, ref_model]:
|
||||
m.trainer = self.trainer
|
||||
m.logger = self.trainer.logger
|
||||
|
|
|
@ -137,7 +137,7 @@ class Trainer(
|
|||
amp_backend: str = 'native',
|
||||
amp_level: str = 'O2',
|
||||
distributed_backend: Optional[str] = None,
|
||||
automatic_optimization: bool = True,
|
||||
automatic_optimization: Optional[bool] = None,
|
||||
move_metrics_to_cpu: bool = False,
|
||||
):
|
||||
r"""
|
||||
|
@ -212,7 +212,9 @@ class Trainer(
|
|||
log_every_n_steps: How often to log within steps (defaults to every 50 steps).
|
||||
|
||||
automatic_optimization: If False you are responsible for calling .backward, .step, zero_grad.
|
||||
Meant to be used with multiple optimizers by advanced users.
|
||||
If False you are responsible for calling .backward, .step, zero_grad in LightningModule.
|
||||
This argument has been moved to LightningModule. It is deprecated here in v1.1 and
|
||||
will be removed in v1.3.
|
||||
|
||||
prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
|
||||
Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data
|
||||
|
@ -355,6 +357,14 @@ class Trainer(
|
|||
)
|
||||
|
||||
# init train loop related flags
|
||||
# TODO: deprecate in 1.2.0
|
||||
if automatic_optimization is None:
|
||||
automatic_optimization = True
|
||||
else:
|
||||
rank_zero_warn(
|
||||
"Disable automatic optimization with the trainer flag is deprecated and will be removed in v1.3.0!"
|
||||
"Please use the property on the LightningModule for disabling automatic optimization"
|
||||
)
|
||||
self.train_loop.on_trainer_init(
|
||||
max_epochs,
|
||||
min_epochs,
|
||||
|
|
Loading…
Reference in New Issue