Re-enable passing BytesIO as path in `.to_onnx()` (#20172)
This commit is contained in:
parent
be0ae06596
commit
828fd99896
|
@ -17,6 +17,7 @@ import logging
|
|||
import numbers
|
||||
import weakref
|
||||
from contextlib import contextmanager
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
IO,
|
||||
|
@ -1364,7 +1365,7 @@ class LightningModule(
|
|||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
|
||||
def to_onnx(self, file_path: Union[str, Path, BytesIO], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
|
||||
"""Saves the model in ONNX format.
|
||||
|
||||
Args:
|
||||
|
@ -1403,7 +1404,8 @@ class LightningModule(
|
|||
input_sample = self._on_before_batch_transfer(input_sample)
|
||||
input_sample = self._apply_batch_transfer_handler(input_sample)
|
||||
|
||||
torch.onnx.export(self, input_sample, str(file_path), **kwargs)
|
||||
file_path = str(file_path) if isinstance(file_path, Path) else file_path
|
||||
torch.onnx.export(self, input_sample, file_path, **kwargs)
|
||||
self.train(mode)
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
import operator
|
||||
import os
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
|
@ -45,6 +46,10 @@ def test_model_saves_with_input_sample(tmp_path):
|
|||
assert os.path.isfile(file_path)
|
||||
assert os.path.getsize(file_path) > 4e2
|
||||
|
||||
file_path = BytesIO()
|
||||
model.to_onnx(file_path=file_path, input_sample=input_sample)
|
||||
assert len(file_path.getvalue()) > 4e2
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"accelerator", [pytest.param("mps", marks=RunIf(mps=True)), pytest.param("gpu", marks=RunIf(min_cuda_gpus=True))]
|
||||
|
|
Loading…
Reference in New Issue