diff --git a/tests/trainer/test_trainer_cli.py b/tests/trainer/test_trainer_cli.py index 93cbb8e2cc..b2d1da957b 100644 --- a/tests/trainer/test_trainer_cli.py +++ b/tests/trainer/test_trainer_cli.py @@ -1,6 +1,7 @@ import inspect from argparse import ArgumentParser, Namespace from unittest import mock +import pickle import pytest @@ -42,14 +43,14 @@ def test_add_argparse_args_redefined(cli_args): args = parser.parse_args(cli_args) + # make sure we can pickle args + pickle.dumps(args) + # Check few deprecated args are not in namespace: for depr_name in ('gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs'): assert depr_name not in args trainer = Trainer.from_argparse_args(args=args) - - # make sure trainer can be pickled - import pickle pickle.dumps(trainer) assert isinstance(trainer, Trainer)