This commit is contained in:
parent
2db8e089dd
commit
dc18138032
|
@ -236,24 +236,11 @@ class _MultiProcessingLauncher(_Launcher):
|
||||||
process this output.
|
process this output.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
callback_metrics = apply_to_collection(trainer.callback_metrics, Tensor, lambda t: t.cpu())
|
||||||
def tensor_to_bytes(tensor: Tensor) -> bytes:
|
|
||||||
buffer = io.BytesIO()
|
|
||||||
torch.save(tensor.cpu().clone(), buffer)
|
|
||||||
return buffer.getvalue()
|
|
||||||
|
|
||||||
# send tensors as bytes to avoid issues with memory sharing
|
# send tensors as bytes to avoid issues with memory sharing
|
||||||
# print("callback metrics trainer", trainer.callback_metrics)
|
|
||||||
callback_metrics = {'foo': torch.tensor(3.), 'foo_2': torch.tensor(4.), 'foo_3': torch.tensor(2.), 'foo_4': torch.tensor(1.5000), 'foo_5': torch.tensor(3.),
|
|
||||||
'foo_11': torch.tensor(1.5000), 'foo_11_step': torch.tensor(2.5000), 'bar': torch.tensor(6.), 'bar_2': torch.tensor(1.),
|
|
||||||
'bar_3': torch.tensor(3.), 'foo_6': torch.tensor(9.), 'foo_7': torch.tensor(12.), 'foo_8': torch.tensor(2.), 'foo_9': torch.tensor(1.5000),
|
|
||||||
'foo_10': torch.tensor(2.), 'foo_11_epoch': torch.tensor(1.5000)}
|
|
||||||
|
|
||||||
buffer = io.BytesIO()
|
buffer = io.BytesIO()
|
||||||
torch.save(callback_metrics, buffer)
|
torch.save(callback_metrics, buffer)
|
||||||
# callback_metrics = apply_to_collection(callback_metrics, Tensor, tensor_to_bytes)
|
return {"callback_metrics_bytes": buffer.getvalue()}
|
||||||
# print("callback metrics bytes", callback_metrics)
|
|
||||||
return {"callback_metrics": buffer.getvalue()}
|
|
||||||
|
|
||||||
def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, Any]) -> None:
|
def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, Any]) -> None:
|
||||||
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we
|
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we
|
||||||
|
@ -265,14 +252,10 @@ class _MultiProcessingLauncher(_Launcher):
|
||||||
on the current trainer.
|
on the current trainer.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def bytes_to_tensor(tensor_bytes: bytes) -> Tensor:
|
|
||||||
return torch.load(io.BytesIO(tensor_bytes))
|
|
||||||
|
|
||||||
# NOTE: `get_extra_results` needs to be called before
|
# NOTE: `get_extra_results` needs to be called before
|
||||||
callback_metrics = extra["callback_metrics"]
|
callback_metrics_bytes = extra["callback_metrics_bytes"]
|
||||||
print("received callback metrics bytes", callback_metrics)
|
callback_metrics = torch.load(io.BytesIO(callback_metrics_bytes))
|
||||||
trainer.callback_metrics.update(torch.load(io.BytesIO(callback_metrics))) # apply_to_collection(callback_metrics, bytes, bytes_to_tensor))
|
trainer.callback_metrics.update(callback_metrics)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def kill(self, signum: _SIGNUM) -> None:
|
def kill(self, signum: _SIGNUM) -> None:
|
||||||
|
|
Loading…
Reference in New Issue