test cloudpickle (#2105)

* cloudpickle

* ci tests
This commit is contained in:
Jirka Borovec 2020-06-09 22:51:30 +02:00 committed by GitHub
parent de15759f76
commit 16a7326e52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 21 additions and 5 deletions

View File

@ -70,6 +70,7 @@ jobs:
run: |
python -c "req = open('requirements.txt').read().replace('>', '=') ; open('requirements.txt', 'w').write(req)"
python -c "req = open('requirements-extra.txt').read().replace('>', '=') ; open('requirements-extra.txt', 'w').write(req)"
python -c "req = open('tests/requirements-devel.txt').read().replace('>', '=') ; open('tests/requirements-devel.txt', 'w').write(req)"
# Note: This uses an internal pip API and may not always work
# https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow

View File

@ -3,6 +3,7 @@ import logging as log
import os
import pickle
import cloudpickle
import pytest
import torch
@ -273,3 +274,4 @@ def test_model_saving_loading(tmpdir):
def test_model_pickle(tmpdir):
model = EvalModelTemplate()
pickle.dumps(model)
cloudpickle.dumps(model)

View File

@ -4,4 +4,4 @@
# extended list of dependencies dor development and run lint and tests
-r ./requirements.txt
cloudpickle
cloudpickle>=1.2

View File

@ -1,9 +1,11 @@
import glob
import math
import os
import pickle
import types
from argparse import Namespace
import cloudpickle
import pytest
import torch
@ -671,10 +673,12 @@ def test_gradient_clipping(tmpdir):
grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
assert (grad_norm - 1.0).abs() < 0.01, "Gradient norm != 1.0: {grad_norm}".format(grad_norm=grad_norm)
trainer = Trainer(max_steps=1,
max_epochs=1,
gradient_clip_val=1.0,
default_root_dir=tmpdir)
trainer = Trainer(
max_steps=1,
max_epochs=1,
gradient_clip_val=1.0,
default_root_dir=tmpdir
)
# for the test
model.optimizer_step = _optimizer_step
@ -824,3 +828,12 @@ def test_trainer_subclassing():
# when we pass in an unknown arg, the base class should complain
with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'abcdefg'"):
TrainerSubclass(abcdefg='unknown_arg')
def test_trainer_pickle(tmpdir):
trainer = Trainer(
max_epochs=1,
default_root_dir=tmpdir
)
pickle.dumps(trainer)
cloudpickle.dumps(trainer)