Remove dead code in eval loop output tracking (#8625)

This commit is contained in:
Adrian Wälchli 2021-07-30 14:04:51 +02:00 committed by GitHub
parent bb4887368c
commit 1bc052c290
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 9 deletions

View File

@ -240,16 +240,10 @@ class EvaluationEpochLoop(Loop):
return step_kwargs return step_kwargs
def _track_output_for_epoch_end( def _track_output_for_epoch_end(
self, self, outputs: List[STEP_OUTPUT], output: Optional[STEP_OUTPUT]
outputs: List[Union[ResultCollection, Dict, Tensor]], ) -> List[STEP_OUTPUT]:
output: Optional[Union[ResultCollection, Dict, Tensor]],
) -> List[Union[ResultCollection, Dict, Tensor]]:
if output is not None: if output is not None:
if isinstance(output, ResultCollection): if isinstance(output, dict):
output = output.detach()
if self.trainer.move_metrics_to_cpu:
output = output.cpu()
elif isinstance(output, dict):
output = recursive_detach(output, to_cpu=self.trainer.move_metrics_to_cpu) output = recursive_detach(output, to_cpu=self.trainer.move_metrics_to_cpu)
elif isinstance(output, Tensor) and output.is_cuda and self.trainer.move_metrics_to_cpu: elif isinstance(output, Tensor) and output.is_cuda and self.trainer.move_metrics_to_cpu:
output = output.cpu() output = output.cpu()