clean up unused attributes in LightningModule (#8259)

This commit is contained in:
Adrian Wälchli 2021-07-06 10:13:09 +02:00 committed by GitHub
parent a7e21bd5ad
commit 9eda520bee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 9 deletions

View File

@ -91,8 +91,6 @@ class LightningModule(
# torch/nn/modules/module.py#L227)
torch._C._log_api_usage_once(f"lightning.module.{self.__class__.__name__}")
self._loaded_optimizer_states_dict = {}
# pointer to the trainer object
self.trainer = None
@ -109,13 +107,15 @@ class LightningModule(
self._example_input_array = None
self._datamodule = None
self._current_fx_name: Optional[str] = None
self._running_manual_backward: bool = False
self._current_dataloader_idx: Optional[int] = None
self._automatic_optimization: bool = True
self._truncated_bptt_steps: int = 0
self._param_requires_grad_state = dict()
self._metric_attributes: Optional[Dict[int, str]] = None
# deprecated, will be removed in 1.6
self._loaded_optimizer_states_dict = {}
def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
"""
Returns the optimizer(s) that are being used during training. Useful for manual optimization.
@ -1450,9 +1450,7 @@ class LightningModule(
self._verify_is_manual_optimization('manual_backward')
# backward
self._running_manual_backward = True
self.trainer.fit_loop.epoch_loop.batch_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs)
self._running_manual_backward = False
def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
"""
@ -1470,8 +1468,7 @@ class LightningModule(
def backward(self, loss, optimizer, optimizer_idx):
loss.backward()
"""
if self.automatic_optimization or self._running_manual_backward:
loss.backward(*args, **kwargs)
loss.backward(*args, **kwargs)
def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
"""

View File

@ -24,7 +24,5 @@ class ModelConnector:
for m in [model, ref_model]:
m.trainer = proxy(self.trainer)
m._device_type = str(self.trainer._device_type)
m._distrib_type = str(self.trainer._distrib_type)
m.use_amp = self.trainer.amp_backend is not None
m.precision = self.trainer.precision