This commit is contained in:
awaelchli 2024-06-23 18:35:29 +02:00
parent 2db8e089dd
commit dc18138032
1 changed files with 5 additions and 22 deletions

View File

@ -236,24 +236,11 @@ class _MultiProcessingLauncher(_Launcher):
process this output.
"""
def tensor_to_bytes(tensor: Tensor) -> bytes:
buffer = io.BytesIO()
torch.save(tensor.cpu().clone(), buffer)
return buffer.getvalue()
callback_metrics = apply_to_collection(trainer.callback_metrics, Tensor, lambda t: t.cpu())
# send tensors as bytes to avoid issues with memory sharing
# print("callback metrics trainer", trainer.callback_metrics)
callback_metrics = {'foo': torch.tensor(3.), 'foo_2': torch.tensor(4.), 'foo_3': torch.tensor(2.), 'foo_4': torch.tensor(1.5000), 'foo_5': torch.tensor(3.),
'foo_11': torch.tensor(1.5000), 'foo_11_step': torch.tensor(2.5000), 'bar': torch.tensor(6.), 'bar_2': torch.tensor(1.),
'bar_3': torch.tensor(3.), 'foo_6': torch.tensor(9.), 'foo_7': torch.tensor(12.), 'foo_8': torch.tensor(2.), 'foo_9': torch.tensor(1.5000),
'foo_10': torch.tensor(2.), 'foo_11_epoch': torch.tensor(1.5000)}
buffer = io.BytesIO()
torch.save(callback_metrics, buffer)
# callback_metrics = apply_to_collection(callback_metrics, Tensor, tensor_to_bytes)
# print("callback metrics bytes", callback_metrics)
return {"callback_metrics": buffer.getvalue()}
return {"callback_metrics_bytes": buffer.getvalue()}
def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, Any]) -> None:
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we
@ -265,14 +252,10 @@ class _MultiProcessingLauncher(_Launcher):
on the current trainer.
"""
def bytes_to_tensor(tensor_bytes: bytes) -> Tensor:
return torch.load(io.BytesIO(tensor_bytes))
# NOTE: `get_extra_results` needs to be called before
callback_metrics = extra["callback_metrics"]
print("received callback metrics bytes", callback_metrics)
trainer.callback_metrics.update(torch.load(io.BytesIO(callback_metrics))) # apply_to_collection(callback_metrics, bytes, bytes_to_tensor))
callback_metrics_bytes = extra["callback_metrics_bytes"]
callback_metrics = torch.load(io.BytesIO(callback_metrics_bytes))
trainer.callback_metrics.update(callback_metrics)
@override
def kill(self, signum: _SIGNUM) -> None: