fix collecting training_step outputs (#8613)
commit 529c42f848
parent 5789e9f5e4
@@ -70,19 +70,22 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed horovod auto-detection when horovod is not installed and the launcher is `mpirun` ([#8610](https://github.com/PyTorchLightning/pytorch-lightning/pull/8610))

### Fixed

- Fixed `trainer.fit_loop.split_idx` always returning `None` ([#8601](https://github.com/PyTorchLightning/pytorch-lightning/pull/8601))

- Fixed references for `ResultCollection.extra` ([#8622](https://github.com/PyTorchLightning/pytorch-lightning/pull/8622))

- Fixed reference issues during epoch end result collection ([#8621](https://github.com/PyTorchLightning/pytorch-lightning/pull/8621))

- Fixed `trainer.fit_loop.split_idx` always returning `None` ([#8601](https://github.com/PyTorchLightning/pytorch-lightning/pull/8601))

- Fixed horovod auto-detection when horovod is not installed and the launcher is `mpirun` ([#8610](https://github.com/PyTorchLightning/pytorch-lightning/pull/8610))

-

- Fixed an issue with `training_step` outputs not getting collected correctly for `training_epoch_end` ([#8613](https://github.com/PyTorchLightning/pytorch-lightning/pull/8613))

-
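The new entry above is the user-facing summary: each call to `training_step` should contribute its own entry to the list that `training_epoch_end` receives. A minimal sketch of that contract, with an illustrative module (the class name, layer shapes, and data wiring are not from this commit, and a `Trainer`/dataloader setup is omitted):

import torch
from pytorch_lightning import LightningModule


class DictOutputModel(LightningModule):
    """Illustrative module: returns a dict per batch and reads the dicts back at epoch end."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # every batch returns its own dict; the loop collects one entry per call
        return {"loss": loss, "batch_idx": batch_idx}

    def training_epoch_end(self, outputs):
        # with #8613, `outputs` holds a distinct dict for each batch of the epoch
        assert [out["batch_idx"] for out in outputs] == list(range(len(outputs)))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)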
@@ -14,6 +14,7 @@
 from collections import OrderedDict
 from contextlib import contextmanager
+from copy import copy
 from functools import partial, update_wrapper
 from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple

@@ -144,12 +145,12 @@ class TrainingBatchLoop(Loop):
                 result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer)
                 if result:
-                    self.batch_outputs[opt_idx].append(result.training_step_output)
+                    self.batch_outputs[opt_idx].append(copy(result.training_step_output))
         else:
             # in manual optimization, there is no looping over optimizers
             result = self._run_optimization(batch_idx, split_batch)
             if result:
-                self.batch_outputs[0].append(result.training_step_output)
+                self.batch_outputs[0].append(copy(result.training_step_output))

     def teardown(self) -> None:
         # release memory

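The two `copy(...)` calls above are the core of the fix: each step's output is shallow-copied before being appended, so the collected list cannot end up holding several references to one dict if that dict is reused or mutated later in the loop. A standalone sketch of the aliasing problem the copy guards against (plain Python, no Lightning internals):

from copy import copy

# without copying, every append stores a reference to the same dict
collected = []
step_output = {"loss": 0.0, "batch_idx": 0}
collected.append(step_output)
step_output["batch_idx"] = 1  # the dict is reused for the next step
collected.append(step_output)
assert collected[0] is collected[1]                    # one object, two references
assert [o["batch_idx"] for o in collected] == [1, 1]   # the first batch's value is lost

# appending a shallow copy keeps one distinct entry per step
collected = []
step_output = {"loss": 0.0, "batch_idx": 0}
collected.append(copy(step_output))
step_output["batch_idx"] = 1
collected.append(copy(step_output))
assert [o["batch_idx"] for o in collected] == [0, 1]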
@@ -108,7 +108,7 @@ def test__training_step__epoch_end__flow_dict(tmpdir):
             acc = acc + batch_idx

             self.training_step_called = True
-            out = {"loss": acc, "random_things": [1, "a", torch.tensor(2)]}
+            out = {"loss": acc, "random_things": [1, "a", torch.tensor(2)], "batch_idx": batch_idx}
             return out

         def training_epoch_end(self, outputs):

@@ -116,11 +116,13 @@

             # verify we saw the current num of batches
             assert len(outputs) == 2
+            assert len({id(output) for output in outputs}) == 2
+            assert [output["batch_idx"] for output in outputs] == [0, 1]

             for b in outputs:
                 assert isinstance(b, dict)
                 assert self.count_num_graphs(b) == 0
-                assert {"random_things", "loss"} == set(b.keys())
+                assert {"random_things", "loss", "batch_idx"} == set(b.keys())

         def backward(self, loss, optimizer, optimizer_idx):
             return LightningModule.backward(self, loss, optimizer, optimizer_idx)

@@ -155,7 +157,7 @@ def test__training_step__step_end__epoch_end__flow_dict(tmpdir):
             acc = acc + batch_idx

             self.training_step_called = True
-            self.out = {"loss": acc, "random_things": [1, "a", torch.tensor(2)]}
+            self.out = {"loss": acc, "random_things": [1, "a", torch.tensor(2)], "batch_idx": batch_idx}
             return self.out

         def training_step_end(self, tr_step_output):

@@ -169,11 +171,13 @@

             # verify we saw the current num of batches
             assert len(outputs) == 2
+            assert len({id(output) for output in outputs}) == 2
+            assert [output["batch_idx"] for output in outputs] == [0, 1]

             for b in outputs:
                 assert isinstance(b, dict)
                 assert self.count_num_graphs(b) == 0
-                assert {"random_things", "loss"} == set(b.keys())
+                assert {"random_things", "loss", "batch_idx"} == set(b.keys())

         def backward(self, loss, optimizer, optimizer_idx):
             return LightningModule.backward(self, loss, optimizer, optimizer_idx)

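Both tests are hardened in the same way: each output is stamped with its `batch_idx`, and the number of distinct `id()` values must equal the number of batches, which fails immediately if the collected list only holds repeated references to one dict. The check in isolation (values here are illustrative):

# distinct objects: the id-set has one entry per output
outputs = [{"batch_idx": 0}, {"batch_idx": 1}]
assert len({id(output) for output in outputs}) == 2
assert [output["batch_idx"] for output in outputs] == [0, 1]

# aliased objects: the same checks detect the duplication
shared = {"batch_idx": 1}
aliased = [shared, shared]
assert len({id(output) for output in aliased}) == 1
assert [output["batch_idx"] for output in aliased] == [1, 1]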
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import logging
 import math
 import os

@@ -1885,13 +1886,22 @@ def test_multiple_trainer_constant_memory_allocated(tmpdir):
     assert list(trainer.optimizers[0].state.values())[0]["exp_avg_sq"].device == torch.device("cpu")
     assert trainer.callback_metrics["train_loss"].device == torch.device("cpu")

+    # before measuring the memory force release any leftover allocations, including CUDA tensors
+    gc.collect()
     memory_1 = torch.cuda.memory_allocated(0)
+    assert memory_1 == initial
+
     deepcopy(trainer)
+
+    # before measuring the memory force release any leftover allocations, including CUDA tensors
+    gc.collect()
     memory_2 = torch.cuda.memory_allocated(0)
-    assert memory_1 == memory_2 == initial
+    assert memory_2 == initial

     trainer_2 = Trainer(**trainer_kwargs)
     trainer_2.fit(model)
-    memory_3 = torch.cuda.memory_allocated(0)

-    assert initial == memory_1 == memory_3
+    # before measuring the memory force release any leftover allocations, including CUDA tensors
+    gc.collect()
+    memory_3 = torch.cuda.memory_allocated(0)
+    assert memory_3 == initial
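The repeated `gc.collect()` before each `torch.cuda.memory_allocated` read is a general measurement pattern: a CUDA tensor that is still reachable from a not-yet-collected Python object keeps its allocation alive, so the counter is only comparable across points in the test after garbage has been collected. A small standalone sketch of that pattern, assuming a CUDA device is available (the helper name and tensor size are arbitrary):

import gc

import torch


def cuda_allocated_after_gc(device: int = 0) -> int:
    # force release of any unreferenced tensors before reading the allocator counter
    gc.collect()
    return torch.cuda.memory_allocated(device)


if torch.cuda.is_available():
    baseline = cuda_allocated_after_gc()
    x = torch.randn(1024, 1024, device="cuda")     # allocates roughly 4 MB on the device
    assert cuda_allocated_after_gc() > baseline
    del x                                          # drop the only reference
    assert cuda_allocated_after_gc() == baseline   # the allocation is released again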