diff --git a/src/lightning/pytorch/strategies/launchers/multiprocessing.py b/src/lightning/pytorch/strategies/launchers/multiprocessing.py index 0d5e47d72e..3e9a2d622f 100644 --- a/src/lightning/pytorch/strategies/launchers/multiprocessing.py +++ b/src/lightning/pytorch/strategies/launchers/multiprocessing.py @@ -236,24 +236,11 @@ class _MultiProcessingLauncher(_Launcher): process this output. """ - - def tensor_to_bytes(tensor: Tensor) -> bytes: - buffer = io.BytesIO() - torch.save(tensor.cpu().clone(), buffer) - return buffer.getvalue() - + callback_metrics = apply_to_collection(trainer.callback_metrics, Tensor, lambda t: t.cpu()) # send tensors as bytes to avoid issues with memory sharing - # print("callback metrics trainer", trainer.callback_metrics) - callback_metrics = {'foo': torch.tensor(3.), 'foo_2': torch.tensor(4.), 'foo_3': torch.tensor(2.), 'foo_4': torch.tensor(1.5000), 'foo_5': torch.tensor(3.), - 'foo_11': torch.tensor(1.5000), 'foo_11_step': torch.tensor(2.5000), 'bar': torch.tensor(6.), 'bar_2': torch.tensor(1.), - 'bar_3': torch.tensor(3.), 'foo_6': torch.tensor(9.), 'foo_7': torch.tensor(12.), 'foo_8': torch.tensor(2.), 'foo_9': torch.tensor(1.5000), - 'foo_10': torch.tensor(2.), 'foo_11_epoch': torch.tensor(1.5000)} - buffer = io.BytesIO() torch.save(callback_metrics, buffer) - # callback_metrics = apply_to_collection(callback_metrics, Tensor, tensor_to_bytes) - # print("callback metrics bytes", callback_metrics) - return {"callback_metrics": buffer.getvalue()} + return {"callback_metrics_bytes": buffer.getvalue()} def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, Any]) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we @@ -265,14 +252,10 @@ class _MultiProcessingLauncher(_Launcher): on the current trainer. """ - - def bytes_to_tensor(tensor_bytes: bytes) -> Tensor: - return torch.load(io.BytesIO(tensor_bytes)) - # NOTE: `get_extra_results` needs to be called before - callback_metrics = extra["callback_metrics"] - print("received callback metrics bytes", callback_metrics) - trainer.callback_metrics.update(torch.load(io.BytesIO(callback_metrics))) # apply_to_collection(callback_metrics, bytes, bytes_to_tensor)) + callback_metrics_bytes = extra["callback_metrics_bytes"] + callback_metrics = torch.load(io.BytesIO(callback_metrics_bytes)) + trainer.callback_metrics.update(callback_metrics) @override def kill(self, signum: _SIGNUM) -> None: