Un-balanced logging properly supported (#5119)
* resolve bug
* clean code
* resolve comments
* Update tests/trainer/optimization/test_multiple_optimizers.py
  Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
* resolve another bug
* add comments
* use abs to find diff
* update
* resolve flake8

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
parent 58a2993766
commit 13bbf4b3f2
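In this context, "un-balanced logging" means that with multiple optimizers `self.log` is not called for every metric on every batch, so each optimizer accumulates a different number of logged values. A minimal, framework-free sketch of that pattern (plain Python rather than a LightningModule; the step-10 threshold simply mirrors the new test added at the end of this diff):

# Minimal illustration of "un-balanced" logging across two optimizers:
# loss_1 is only logged after global step 10, loss_2 on every step, so the
# two histories end up with different lengths and different batch indices.
logged = {"loss_1": {}, "loss_2": {}}

for global_step in range(14):
    for optimizer_idx in (0, 1):
        loss = 1.0 / (global_step + 1)            # dummy value
        if optimizer_idx == 0 and global_step > 10:
            logged["loss_1"][global_step] = loss  # sparse history
        elif optimizer_idx == 1:
            logged["loss_2"][global_step] = loss  # dense history

print(len(logged["loss_1"]), len(logged["loss_2"]))  # 3 vs 14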
@@ -91,11 +91,13 @@ class HookResultStore:
         random_key = list(result.keys())[-1]
         return result["meta"][random_key]["dataloader_idx"] is not None

-    def get_latest_from_func_name(self, latest_result, func_name: str, *args, **kwargs) -> Dict:
+    def get_latest_from_func_name(self, latest_result_opt, func_name: str, *args, **kwargs) -> Dict:
         results = {}
-        add_dataloader_idx = self.check_dataloader_idx(latest_result)
-        func = getattr(latest_result, func_name)
-        results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs))
+        for opt_idx in latest_result_opt:
+            latest_result = latest_result_opt[opt_idx]
+            add_dataloader_idx = self.check_dataloader_idx(latest_result)
+            func = getattr(latest_result, func_name)
+            results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs))
         return results

     def run_latest_batch_metrics_with_func_name(self, func_name, *args, **kwargs) -> List[Dict]:
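Note on the hunk above: `get_latest_from_func_name` now receives a mapping of optimizer index to latest `Result` and merges every optimizer's metrics, instead of reading a single `Result`. A rough, self-contained sketch of that merge (`FakeResult` and `get_latest_metrics` are stand-ins, not Lightning's API; the real call also forwards `add_dataloader_idx`, omitted here):

class FakeResult:
    """Stand-in for the Result object holding one optimizer's latest metrics."""

    def __init__(self, metrics):
        self._metrics = metrics

    def get_latest_metrics(self, *args, **kwargs):
        return dict(self._metrics)


def get_latest_from_func_name(latest_result_opt, func_name, *args, **kwargs):
    # New behaviour: iterate every optimizer's latest Result and merge, so a
    # metric logged only under one optimizer still reaches the logger/progress bar.
    results = {}
    for opt_idx in latest_result_opt:
        latest_result = latest_result_opt[opt_idx]
        func = getattr(latest_result, func_name)
        results.update(func(*args, **kwargs))
    return results


latest = {0: FakeResult({"loss_1": 0.3}), 1: FakeResult({"loss_2": 0.7})}
print(get_latest_from_func_name(latest, "get_latest_metrics"))
# -> {'loss_1': 0.3, 'loss_2': 0.7}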
@@ -156,6 +158,7 @@ class HookResultStore:
         assert isinstance(result, Result)
         if dataloader_idx is None:
             dataloader_idx = 0

         if extra_info is None:
             extra_info = {}
@@ -166,6 +169,7 @@ class HookResultStore:
             if dataloader_idx not in self._internals:
                 self._internals[dataloader_idx] = {}
                 self._internals_reduced[dataloader_idx] = defaultdict(dict)
+                self._latest_ref[dataloader_idx] = {}

             # extract infos
             opt_idx = extra_info["opt_idx"]
@@ -173,7 +177,7 @@ class HookResultStore:
             self._append_to_structure(self._internals[dataloader_idx], opt_idx, batch_idx, result)

-            self._latest_ref[dataloader_idx] = result
+            self._latest_ref[dataloader_idx][opt_idx] = result

         # [dataloader_idx] is a list
         else:
@@ -181,7 +185,11 @@ class HookResultStore:
             self._internals.setdefault(dataloader_idx, [])
             self._internals[dataloader_idx].append(result)

-            self._latest_ref[dataloader_idx] = result
+            if dataloader_idx not in self._latest_ref:
+                self._latest_ref[dataloader_idx] = {}
+                self._latest_ref[dataloader_idx][0] = {}
+
+            self._latest_ref[dataloader_idx][0] = result

     def auto_reduce_results_on_epoch_end(self) -> None:
         """
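Note on the hunks above: `_latest_ref` changes from `{dataloader_idx: Result}` to `{dataloader_idx: {opt_idx: Result}}`, so one optimizer's latest result no longer overwrites another's. A small illustration with plain dicts (strings stand in for `Result` objects):

# Pre-fix layout: dataloader_idx -> latest result (whichever optimizer wrote last)
flat_ref = {}
flat_ref[0] = "result from opt 0"
flat_ref[0] = "result from opt 1"       # clobbers optimizer 0's entry
assert flat_ref[0] == "result from opt 1"

# Post-fix layout: dataloader_idx -> {opt_idx: latest result}
nested_ref = {}
nested_ref.setdefault(0, {})
nested_ref[0][0] = "result from opt 0"
nested_ref[0][1] = "result from opt 1"  # both survive
assert set(nested_ref[0]) == {0, 1}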
@@ -206,13 +214,9 @@ class HookResultStore:
                         # TODO: How to start training in middle of epoch
                         opt_outputs = epoch_metrics[opt_idx]

-                        num_batch_idx = len(self._internals[dl_idx][num_opt_idx]) - 1
-                        assert num_batch_idx >= 0
-                        batch_indexes = self._internals[dl_idx][num_opt_idx].keys()
-
                         # reduce across time first
                         time_reduced_outputs = []
-                        for batch_idx in batch_indexes:
+                        for batch_idx in opt_outputs.keys():
                             tbptt_outs = opt_outputs[batch_idx]
                             tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs)
                             if len(tbptt_outs) > 1:
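Note on the hunk above: the epoch-end reduction now iterates `opt_outputs.keys()`, i.e. the batches each optimizer actually logged, instead of batch indices borrowed from another optimizer's storage. A simplified, runnable sketch of the difference (nested dicts stand in for `self._internals`):

# Sketch of the auto_reduce_results_on_epoch_end change.
internals = {
    0: {11: [0.5], 12: [0.4], 13: [0.3]},   # optimizer 0 logged on 3 batches
    1: {b: [1.0] for b in range(14)},       # optimizer 1 logged on every batch
}

# Pre-fix style: borrow one optimizer's batch indices for both optimizers.
borrowed_indexes = internals[1].keys()
try:
    _ = [internals[0][b] for b in borrowed_indexes]
except KeyError as err:
    print(f"pre-fix lookup fails for optimizer 0 at batch {err}")

# Post-fix style: each optimizer is reduced over the batches it actually has.
for opt_idx, opt_outputs in internals.items():
    time_reduced = [sum(v) / len(v) for _, v in sorted(opt_outputs.items())]
    epoch_value = sum(time_reduced) / len(time_reduced)
    print(f"opt {opt_idx}: epoch value {epoch_value:.4f}")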
tests/trainer/optimization/test_multiple_optimizers.py (new file)
@@ -0,0 +1,63 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tests to ensure that the behaviours related to multiple optimizers works
+"""
+import torch
+
+import pytorch_lightning as pl
+from tests.base.boring_model import BoringModel
+
+
+def test_unbalanced_logging_with_multiple_optimizers(tmpdir):
+    """
+    This tests ensures reduction works in un-balanced logging settings
+    """
+    class TestModel(BoringModel):
+
+        loss_1 = []
+        loss_2 = []
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            if optimizer_idx == 0 and self.trainer.global_step > 10:
+                self.log("loss_1", loss, on_epoch=True, prog_bar=True)
+                self.loss_1.append(loss.detach().clone())
+            elif optimizer_idx == 1:
+                self.log("loss_2", loss, on_epoch=True, prog_bar=True)
+                self.loss_2.append(loss.detach().clone())
+            return {"loss": loss}
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001)
+            optimizer2 = torch.optim.SGD(self.layer.parameters(), lr=0.001)
+            return [optimizer, optimizer2]
+
+    model = TestModel()
+    model.training_epoch_end = None
+
+    # Initialize a trainer
+    trainer = pl.Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+    )
+
+    trainer.fit(model)
+
+    assert torch.equal(trainer.callback_metrics["loss_2_step"], model.loss_2[-1])
+    assert torch.equal(trainer.callback_metrics["loss_1_step"], model.loss_1[-1])
+    # test loss are properly reduced
+    assert torch.abs(trainer.callback_metrics["loss_2_epoch"] - torch.FloatTensor(model.loss_2).mean()) < 1e-6
+    assert torch.abs(trainer.callback_metrics["loss_1_epoch"] - torch.FloatTensor(model.loss_1).mean()) < 1e-6
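The "use abs to find diff" bullet in the commit message refers to the last two assertions above: the epoch value is compared against a recomputed mean within an absolute tolerance rather than with `torch.equal`, because both means are accumulated in floating point and can differ by rounding. A standalone illustration (values made up):

import torch

# Why the test uses an absolute tolerance instead of exact equality: the
# epoch value is a mean accumulated in floating point, so it may differ from
# a recomputed mean by a few ulps.
losses = [torch.tensor(v) for v in (0.31, 0.27, 0.4321, 0.118)]

running = torch.tensor(0.0)
for loss in losses:
    running = running + loss
epoch_value = running / len(losses)      # accumulated mean
recomputed = torch.stack(losses).mean()  # one-shot mean

# Exact equality may or may not hold; the tolerance check is what matters.
assert torch.abs(epoch_value - recomputed) < 1e-6

The new test can be run on its own with, e.g., `python -m pytest tests/trainer/optimization/test_multiple_optimizers.py`.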