Convert tensors to bytes instead of numpy in multiprocessing result-queue (#20005)
This commit is contained in:
parent
e330da5870
commit
9304a2c72e
|
@ -11,6 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
|
@ -19,7 +20,6 @@ from contextlib import suppress
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.backends.cudnn
|
||||
import torch.multiprocessing as mp
|
||||
|
@ -226,7 +226,7 @@ class _MultiProcessingLauncher(_Launcher):
|
|||
|
||||
def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]:
|
||||
"""Gather extra state from the Trainer and return it as a dictionary for sending back to the main process. To
|
||||
avoid issues with memory sharing, we cast the data to numpy.
|
||||
avoid issues with memory sharing, we convert tensors to bytes.
|
||||
|
||||
Args:
|
||||
trainer: reference to the Trainer.
|
||||
|
@ -236,14 +236,15 @@ class _MultiProcessingLauncher(_Launcher):
|
|||
process this output.
|
||||
|
||||
"""
|
||||
callback_metrics: dict = apply_to_collection(
|
||||
trainer.callback_metrics, Tensor, lambda x: x.cpu().numpy()
|
||||
) # send as numpy to avoid issues with memory sharing
|
||||
return {"callback_metrics": callback_metrics}
|
||||
callback_metrics = apply_to_collection(trainer.callback_metrics, Tensor, lambda t: t.cpu())
|
||||
buffer = io.BytesIO()
|
||||
torch.save(callback_metrics, buffer)
|
||||
# send tensors as bytes to avoid issues with memory sharing
|
||||
return {"callback_metrics_bytes": buffer.getvalue()}
|
||||
|
||||
def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, Any]) -> None:
|
||||
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we
|
||||
cast back the data to ``torch.Tensor``.
|
||||
convert bytes back to ``torch.Tensor``.
|
||||
|
||||
Args:
|
||||
trainer: reference to the Trainer.
|
||||
|
@ -252,8 +253,9 @@ class _MultiProcessingLauncher(_Launcher):
|
|||
|
||||
"""
|
||||
# NOTE: `get_extra_results` needs to be called before this method; it produces the `extra` dictionary consumed here
|
||||
callback_metrics = extra["callback_metrics"]
|
||||
trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x)))
|
||||
callback_metrics_bytes = extra["callback_metrics_bytes"]
|
||||
callback_metrics = torch.load(io.BytesIO(callback_metrics_bytes))
|
||||
trainer.callback_metrics.update(callback_metrics)
|
||||
|
||||
@override
|
||||
def kill(self, signum: _SIGNUM) -> None:
|
||||
|
|
Loading…
Reference in New Issue