20 lines
554 B
Python
20 lines
554 B
Python
|
import pickle
|
||
|
|
||
|
import cloudpickle
|
||
|
import pytest
|
||
|
|
||
|
from tests.base.datasets import MNIST, TrialMNIST, AverageDataset
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize('dataset_cls', [MNIST, TrialMNIST, AverageDataset])
def test_pickling_dataset_mnist(tmpdir, dataset_cls):
    """Check that a test dataset survives a pickle and a cloudpickle round trip.

    The dataset is built with its default constructor arguments, serialized
    with each library in turn, and deserialized again. The test fails if
    either step raises or if the restored object is not an instance of the
    class that was serialized.

    Args:
        tmpdir: pytest temporary-directory fixture. Not used by the body;
            kept so the test signature is unchanged.
        dataset_cls: dataset class under test, supplied by ``parametrize``.
    """
    mnist = dataset_cls()

    # Both libraries expose the same ``dumps``/``loads`` pair, so one loop
    # covers the standard pickler and cloudpickle.
    for serializer in (pickle, cloudpickle):
        mnist_pickled = serializer.dumps(mnist)
        mnist_loaded = serializer.loads(mnist_pickled)

        assert isinstance(mnist_loaded, dataset_cls), (
            f"{serializer.__name__} round trip returned "
            f"{type(mnist_loaded).__name__}, expected {dataset_cls.__name__}"
        )
        # NOTE(review): a full ``vars(mnist) == vars(mnist_loaded)`` check was
        # disabled in the original test — presumably because the attributes
        # do not compare to a single bool with ``==``; confirm before
        # re-enabling.
|