[bugfix] Check LightningOptimizer doesn't delete optimizer hooks (#6305)
* update
* resolve bug
This commit is contained in:
parent 39231aee1a
commit 7acbd65bcb
@@ -38,7 +38,7 @@ class LightningOptimizer:
 
     def __init__(self, optimizer: Optimizer):
 
-        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k != 'step'}
+        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ('step', "__del__")}
 
         # For Horovod
         if hasattr(optimizer, "skip_synchronize"):
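For context, a minimal sketch of the dict-copy pattern the changed line uses (illustrative only, not the Lightning source; `WrappedOptimizer` is a hypothetical stand-in): the wrapper mirrors the wrapped optimizer's instance attributes but excludes 'step', and with this fix also '__del__', so attributes copied from the wrapped optimizer cannot shadow the wrapper's own stepping or teardown behaviour.

import torch
from torch.optim import SGD


class WrappedOptimizer:
    """Hypothetical stand-in for a LightningOptimizer-style wrapper (not the real class)."""

    def __init__(self, optimizer):
        # Mirror the wrapped optimizer's attributes, but never copy 'step' or '__del__',
        # matching the exclusion in the changed line above.
        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ('step', '__del__')}
        self._optimizer = optimizer

    def step(self, closure=None):
        # Delegate the actual parameter update to the wrapped optimizer.
        return self._optimizer.step(closure=closure)


param = torch.nn.Parameter(torch.zeros(2, 2))
wrapped = WrappedOptimizer(SGD([param], lr=0.1))
print(sorted(wrapped.__dict__))  # e.g. ['_optimizer', 'defaults', 'param_groups', 'state', ...]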
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
+from typing import Any
 from unittest.mock import DEFAULT, patch
 
 import torch
@@ -303,3 +305,78 @@ def test_lightning_optimizer_automatic_optimization_lbfgs_zero_grad(tmpdir):
     lbfgs = model.optimizers()
     max_iter = lbfgs.param_groups[0]["max_iter"]
     assert zero_grad.call_count == max_iter
+
+
+class OptimizerWithHooks(Optimizer):
+
+    def __init__(self, model):
+        self._fwd_handles = []
+        self._bwd_handles = []
+        self.params = []
+        for _, mod in model.named_modules():
+            mod_class = mod.__class__.__name__
+            if mod_class != 'Linear':
+                continue
+
+            handle = mod.register_forward_pre_hook(self._save_input)  # save the inputs
+            self._fwd_handles.append(handle)  # collect forward-save-input hooks in list
+            handle = mod.register_backward_hook(self._save_grad_output)  # save the gradients
+            self._bwd_handles.append(handle)  # collect backward-save-grad hooks in list
+
+            # save the parameters
+            params = [mod.weight]
+            if mod.bias is not None:
+                params.append(mod.bias)
+
+            # save a param_group for each module
+            d = {'params': params, 'mod': mod, 'layer_type': mod_class}
+            self.params.append(d)
+
+        super(OptimizerWithHooks, self).__init__(self.params, {"lr": 0.01})
+
+    def _save_input(self, mod, i):
+        """Saves input of layer"""
+        if mod.training:
+            self.state[mod]['x'] = i[0]
+
+    def _save_grad_output(self, mod, _, grad_output):
+        """
+        Saves grad on output of layer; grad is scaled with batch_size
+        since the gradient is spread over the samples in the mini-batch
+        """
+        batch_size = grad_output[0].shape[0]
+        if mod.training:
+            self.state[mod]['grad'] = grad_output[0] * batch_size
+
+    def step(self, closure=None):
+        closure()
+        for group in self.param_groups:
+            _ = self.state[group['mod']]['x']
+            _ = self.state[group['mod']]['grad']
+        return True
+
+
+def test_lightning_optimizer_keeps_hooks(tmpdir):
+
+    class TestModel(BoringModel):
+        count_on_train_batch_start = 0
+        count_on_train_batch_end = 0
+
+        def configure_optimizers(self):
+            return OptimizerWithHooks(self)
+
+        def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+            self.count_on_train_batch_start += 1
+            optimizer = self.optimizers(use_pl_optimizer=False)
+            assert len(optimizer._fwd_handles) == 1
+
+        def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+            self.count_on_train_batch_end += 1
+            del self.trainer._lightning_optimizers
+            gc.collect()  # not necessary, just in case
+
+    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=4, limit_val_batches=1, max_epochs=1)
+    model = TestModel()
+    trainer.fit(model)
+    assert model.count_on_train_batch_start == 4
+    assert model.count_on_train_batch_end == 4
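The new test drives this through a full Trainer run; as a quick aside, here is a hedged standalone sketch (not part of the commit, names illustrative) of the property it protects: hooks registered on a module stay attached until their handles are explicitly removed, even after an object that merely copied the handle-holding attributes is discarded.

import gc

import torch.nn as nn

# Hook handles are owned by the module's hook registry, not by whoever stores them,
# so dropping a dict that merely references the handle must not detach the hook.
layer = nn.Linear(4, 2)
handle = layer.register_forward_pre_hook(lambda mod, inp: None)

wrapper_state = {'_fwd_handles': [handle]}  # what a dict-copying wrapper would hold
del wrapper_state
gc.collect()

assert len(layer._forward_pre_hooks) == 1  # hook is still attached
handle.remove()
assert len(layer._forward_pre_hooks) == 0  # only an explicit remove() detaches it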