Convert tensors to bytes instead of numpy in multiprocessing result-queue (#20005)

This commit is contained in:
awaelchli 2024-06-23 19:36:57 +02:00 committed by GitHub
parent e330da5870
commit 9304a2c72e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 11 additions and 9 deletions

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import os
import queue
@ -19,7 +20,6 @@ from contextlib import suppress
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional, Union
import numpy as np
import torch
import torch.backends.cudnn
import torch.multiprocessing as mp
@ -226,7 +226,7 @@ class _MultiProcessingLauncher(_Launcher):
def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]:
"""Gather extra state from the Trainer and return it as a dictionary for sending back to the main process. To
avoid issues with memory sharing, we cast the data to numpy.
avoid issues with memory sharing, we convert tensors to bytes.
Args:
trainer: reference to the Trainer.
@ -236,14 +236,15 @@ class _MultiProcessingLauncher(_Launcher):
process this output.
"""
callback_metrics: dict = apply_to_collection(
trainer.callback_metrics, Tensor, lambda x: x.cpu().numpy()
) # send as numpy to avoid issues with memory sharing
return {"callback_metrics": callback_metrics}
callback_metrics = apply_to_collection(trainer.callback_metrics, Tensor, lambda t: t.cpu())
buffer = io.BytesIO()
torch.save(callback_metrics, buffer)
# send tensors as bytes to avoid issues with memory sharing
return {"callback_metrics_bytes": buffer.getvalue()}
def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, Any]) -> None:
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we
cast back the data to ``torch.Tensor``.
convert bytes back to ``torch.Tensor``.
Args:
trainer: reference to the Trainer.
@ -252,8 +253,9 @@ class _MultiProcessingLauncher(_Launcher):
"""
# NOTE: `get_extra_results` needs to be called before
callback_metrics = extra["callback_metrics"]
trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x)))
callback_metrics_bytes = extra["callback_metrics_bytes"]
callback_metrics = torch.load(io.BytesIO(callback_metrics_bytes))
trainer.callback_metrics.update(callback_metrics)
@override
def kill(self, signum: _SIGNUM) -> None: