fix ONNX model save on GPU (#3145)

* added to(device)

* added test

* fix test on gpu

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* remove multi gpu check

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* updated message

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* updated test

* onxx to onnx

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update tests/models/test_onnx.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* add no grad

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* add isinstance back

* chlog

* error if input_sample is not Tensor

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
This commit is contained in:
Lezwon Castelino 2020-08-26 21:52:19 +05:30 committed by GitHub
parent bd35c869ee
commit d9ea25590e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 44 additions and 11 deletions

View File

@@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed RMSLE metric ([#3188](https://github.com/PyTorchLightning/pytorch-lightning/pull/3188))
- Fixed ONNX model save on GPU ([#3145](https://github.com/PyTorchLightning/pytorch-lightning/pull/3145))
## [0.9.0] - YYYY-MM-DD
### Added

View File

@@ -1716,11 +1716,16 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
elif self.example_input_array is not None:
input_data = self.example_input_array
else:
raise ValueError('`input_sample` and `example_input_array` tensors are both missing.')
if input_sample is not None:
raise ValueError(f'Received `input_sample` of type {type(input_sample)}. Expected type is `Tensor`')
else:
raise ValueError('Could not export to ONNX since neither `input_sample` nor'
' `model.example_input_array` attribute is set.')
input_data = input_data.to(self.device)
if 'example_outputs' not in kwargs:
self.eval()
kwargs['example_outputs'] = self(input_data)
with torch.no_grad():
kwargs['example_outputs'] = self(input_data)
torch.onnx.export(self, input_data, file_path, **kwargs)

View File

@@ -17,7 +17,21 @@ def test_model_saves_with_input_sample(tmpdir):
trainer = Trainer(max_epochs=1)
trainer.fit(model)
file_path = os.path.join(tmpdir, "model.onxx")
file_path = os.path.join(tmpdir, "model.onnx")
input_sample = torch.randn((1, 28 * 28))
model.to_onnx(file_path, input_sample)
assert os.path.isfile(file_path)
assert os.path.getsize(file_path) > 3e+06
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_model_saves_on_gpu(tmpdir):
"""Test that model saves on gpu"""
model = EvalModelTemplate()
trainer = Trainer(gpus=1, max_epochs=1)
trainer.fit(model)
file_path = os.path.join(tmpdir, "model.onnx")
input_sample = torch.randn((1, 28 * 28))
model.to_onnx(file_path, input_sample)
assert os.path.isfile(file_path)
@@ -30,7 +44,7 @@ def test_model_saves_with_example_output(tmpdir):
trainer = Trainer(max_epochs=1)
trainer.fit(model)
file_path = os.path.join(tmpdir, "model.onxx")
file_path = os.path.join(tmpdir, "model.onnx")
input_sample = torch.randn((1, 28 * 28))
model.eval()
example_outputs = model.forward(input_sample)
@@ -41,7 +55,7 @@ def test_model_saves_with_example_output(tmpdir):
def test_model_saves_with_example_input_array(tmpdir):
"""Test that ONNX model saves with_example_input_array and size is greater than 3 MB"""
model = EvalModelTemplate()
file_path = os.path.join(tmpdir, "model.onxx")
file_path = os.path.join(tmpdir, "model.onnx")
model.to_onnx(file_path)
assert os.path.exists(file_path) is True
assert os.path.getsize(file_path) > 3e+06
@@ -66,7 +80,7 @@ def test_model_saves_on_multi_gpu(tmpdir):
tpipes.run_model_test(trainer_options, model)
file_path = os.path.join(tmpdir, "model.onxx")
file_path = os.path.join(tmpdir, "model.onnx")
model.to_onnx(file_path)
assert os.path.exists(file_path) is True
@@ -74,7 +88,7 @@ def test_model_saves_on_multi_gpu(tmpdir):
def test_verbose_param(tmpdir, capsys):
"""Test that output is present when verbose parameter is set"""
model = EvalModelTemplate()
file_path = os.path.join(tmpdir, "model.onxx")
file_path = os.path.join(tmpdir, "model.onnx")
model.to_onnx(file_path, verbose=True)
captured = capsys.readouterr()
assert "graph(%" in captured.out
@@ -84,11 +98,23 @@ def test_error_if_no_input(tmpdir):
"""Test that an exception is thrown when there is no input tensor"""
model = EvalModelTemplate()
model.example_input_array = None
file_path = os.path.join(tmpdir, "model.onxx")
with pytest.raises(ValueError, match=r'`input_sample` and `example_input_array` tensors are both missing'):
file_path = os.path.join(tmpdir, "model.onnx")
with pytest.raises(ValueError, match=r'Could not export to ONNX since neither `input_sample` nor'
r' `model.example_input_array` attribute is set.'):
model.to_onnx(file_path)
def test_error_if_input_sample_is_not_tensor(tmpdir):
"""Test that an exception is thrown when there is no input tensor"""
model = EvalModelTemplate()
model.example_input_array = None
file_path = os.path.join(tmpdir, "model.onnx")
input_sample = np.random.randn(1, 28 * 28)
with pytest.raises(ValueError, match=f'Received `input_sample` of type {type(input_sample)}. Expected type is '
f'`Tensor`'):
model.to_onnx(file_path, input_sample)
def test_if_inference_output_is_valid(tmpdir):
"""Test that the output inferred from ONNX model is same as from PyTorch"""
model = EvalModelTemplate()
@@ -99,7 +125,7 @@ def test_if_inference_output_is_valid(tmpdir):
with torch.no_grad():
torch_out = model(model.example_input_array)
file_path = os.path.join(tmpdir, "model.onxx")
file_path = os.path.join(tmpdir, "model.onnx")
model.to_onnx(file_path, model.example_input_array, export_params=True)
ort_session = onnxruntime.InferenceSession(file_path)