Un-balanced logging properly supported (#5119)

* resolve bug

* clean code

* resolve comments

* Update tests/trainer/optimization/test_multiple_optimizers.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* resolve another bug

* add comments

* use abs to find diff

* update

* resolve flake8

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
chaton 2020-12-16 22:07:17 +01:00 committed by Jirka Borovec
parent 58a2993766
commit 13bbf4b3f2
2 changed files with 78 additions and 11 deletions

View File

@ -91,11 +91,13 @@ class HookResultStore:
random_key = list(result.keys())[-1]
return result["meta"][random_key]["dataloader_idx"] is not None
def get_latest_from_func_name(self, latest_result, func_name: str, *args, **kwargs) -> Dict:
def get_latest_from_func_name(self, latest_result_opt, func_name: str, *args, **kwargs) -> Dict:
results = {}
add_dataloader_idx = self.check_dataloader_idx(latest_result)
func = getattr(latest_result, func_name)
results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs))
for opt_idx in latest_result_opt:
latest_result = latest_result_opt[opt_idx]
add_dataloader_idx = self.check_dataloader_idx(latest_result)
func = getattr(latest_result, func_name)
results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs))
return results
def run_latest_batch_metrics_with_func_name(self, func_name, *args, **kwargs) -> List[Dict]:
@ -156,6 +158,7 @@ class HookResultStore:
assert isinstance(result, Result)
if dataloader_idx is None:
dataloader_idx = 0
if extra_info is None:
extra_info = {}
@ -166,6 +169,7 @@ class HookResultStore:
if dataloader_idx not in self._internals:
self._internals[dataloader_idx] = {}
self._internals_reduced[dataloader_idx] = defaultdict(dict)
self._latest_ref[dataloader_idx] = {}
# extract infos
opt_idx = extra_info["opt_idx"]
@ -173,7 +177,7 @@ class HookResultStore:
self._append_to_structure(self._internals[dataloader_idx], opt_idx, batch_idx, result)
self._latest_ref[dataloader_idx] = result
self._latest_ref[dataloader_idx][opt_idx] = result
# [dataloader_idx] is a list
else:
@ -181,7 +185,11 @@ class HookResultStore:
self._internals.setdefault(dataloader_idx, [])
self._internals[dataloader_idx].append(result)
self._latest_ref[dataloader_idx] = result
if dataloader_idx not in self._latest_ref:
self._latest_ref[dataloader_idx] = {}
self._latest_ref[dataloader_idx][0] = {}
self._latest_ref[dataloader_idx][0] = result
def auto_reduce_results_on_epoch_end(self) -> None:
"""
@ -206,13 +214,9 @@ class HookResultStore:
# TODO: How to start training in middle of epoch
opt_outputs = epoch_metrics[opt_idx]
num_batch_idx = len(self._internals[dl_idx][num_opt_idx]) - 1
assert num_batch_idx >= 0
batch_indexes = self._internals[dl_idx][num_opt_idx].keys()
# reduce across time first
time_reduced_outputs = []
for batch_idx in batch_indexes:
for batch_idx in opt_outputs.keys():
tbptt_outs = opt_outputs[batch_idx]
tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs)
if len(tbptt_outs) > 1:

View File

@ -0,0 +1,63 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests to ensure that the behaviours related to multiple optimizers works
"""
import torch
import pytorch_lightning as pl
from tests.base.boring_model import BoringModel
def test_unbalanced_logging_with_multiple_optimizers(tmpdir):
"""
This tests ensures reduction works in un-balanced logging settings
"""
class TestModel(BoringModel):
loss_1 = []
loss_2 = []
def training_step(self, batch, batch_idx, optimizer_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
if optimizer_idx == 0 and self.trainer.global_step > 10:
self.log("loss_1", loss, on_epoch=True, prog_bar=True)
self.loss_1.append(loss.detach().clone())
elif optimizer_idx == 1:
self.log("loss_2", loss, on_epoch=True, prog_bar=True)
self.loss_2.append(loss.detach().clone())
return {"loss": loss}
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001)
optimizer2 = torch.optim.SGD(self.layer.parameters(), lr=0.001)
return [optimizer, optimizer2]
model = TestModel()
model.training_epoch_end = None
# Initialize a trainer
trainer = pl.Trainer(
default_root_dir=tmpdir,
max_epochs=1,
)
trainer.fit(model)
assert torch.equal(trainer.callback_metrics["loss_2_step"], model.loss_2[-1])
assert torch.equal(trainer.callback_metrics["loss_1_step"], model.loss_1[-1])
# test loss are properly reduced
assert torch.abs(trainer.callback_metrics["loss_2_epoch"] - torch.FloatTensor(model.loss_2).mean()) < 1e-6
assert torch.abs(trainer.callback_metrics["loss_1_epoch"] - torch.FloatTensor(model.loss_1).mean()) < 1e-6