This commit is contained in:
parent
2db8e089dd
commit
dc18138032
|
@ -236,24 +236,11 @@ class _MultiProcessingLauncher(_Launcher):
|
|||
process this output.
|
||||
|
||||
"""
|
||||
|
||||
def tensor_to_bytes(tensor: Tensor) -> bytes:
|
||||
buffer = io.BytesIO()
|
||||
torch.save(tensor.cpu().clone(), buffer)
|
||||
return buffer.getvalue()
|
||||
|
||||
callback_metrics = apply_to_collection(trainer.callback_metrics, Tensor, lambda t: t.cpu())
|
||||
# send tensors as bytes to avoid issues with memory sharing
|
||||
# print("callback metrics trainer", trainer.callback_metrics)
|
||||
callback_metrics = {'foo': torch.tensor(3.), 'foo_2': torch.tensor(4.), 'foo_3': torch.tensor(2.), 'foo_4': torch.tensor(1.5000), 'foo_5': torch.tensor(3.),
|
||||
'foo_11': torch.tensor(1.5000), 'foo_11_step': torch.tensor(2.5000), 'bar': torch.tensor(6.), 'bar_2': torch.tensor(1.),
|
||||
'bar_3': torch.tensor(3.), 'foo_6': torch.tensor(9.), 'foo_7': torch.tensor(12.), 'foo_8': torch.tensor(2.), 'foo_9': torch.tensor(1.5000),
|
||||
'foo_10': torch.tensor(2.), 'foo_11_epoch': torch.tensor(1.5000)}
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.save(callback_metrics, buffer)
|
||||
# callback_metrics = apply_to_collection(callback_metrics, Tensor, tensor_to_bytes)
|
||||
# print("callback metrics bytes", callback_metrics)
|
||||
return {"callback_metrics": buffer.getvalue()}
|
||||
return {"callback_metrics_bytes": buffer.getvalue()}
|
||||
|
||||
def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, Any]) -> None:
|
||||
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we
|
||||
|
@ -265,14 +252,10 @@ class _MultiProcessingLauncher(_Launcher):
|
|||
on the current trainer.
|
||||
|
||||
"""
|
||||
|
||||
def bytes_to_tensor(tensor_bytes: bytes) -> Tensor:
|
||||
return torch.load(io.BytesIO(tensor_bytes))
|
||||
|
||||
# NOTE: `get_extra_results` needs to be called before
|
||||
callback_metrics = extra["callback_metrics"]
|
||||
print("received callback metrics bytes", callback_metrics)
|
||||
trainer.callback_metrics.update(torch.load(io.BytesIO(callback_metrics))) # apply_to_collection(callback_metrics, bytes, bytes_to_tensor))
|
||||
callback_metrics_bytes = extra["callback_metrics_bytes"]
|
||||
callback_metrics = torch.load(io.BytesIO(callback_metrics_bytes))
|
||||
trainer.callback_metrics.update(callback_metrics)
|
||||
|
||||
@override
|
||||
def kill(self, signum: _SIGNUM) -> None:
|
||||
|
|
Loading…
Reference in New Issue