parent
de15759f76
commit
16a7326e52
|
@ -70,6 +70,7 @@ jobs:
|
|||
run: |
|
||||
python -c "req = open('requirements.txt').read().replace('>', '=') ; open('requirements.txt', 'w').write(req)"
|
||||
python -c "req = open('requirements-extra.txt').read().replace('>', '=') ; open('requirements-extra.txt', 'w').write(req)"
|
||||
python -c "req = open('tests/requirements-devel.txt').read().replace('>', '=') ; open('tests/requirements-devel.txt', 'w').write(req)"
|
||||
|
||||
# Note: This uses an internal pip API and may not always work
|
||||
# https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
|
||||
|
|
|
@ -3,6 +3,7 @@ import logging as log
|
|||
import os
|
||||
import pickle
|
||||
|
||||
import cloudpickle
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
@ -273,3 +274,4 @@ def test_model_saving_loading(tmpdir):
|
|||
def test_model_pickle(tmpdir):
|
||||
model = EvalModelTemplate()
|
||||
pickle.dumps(model)
|
||||
cloudpickle.dumps(model)
|
||||
|
|
|
@ -4,4 +4,4 @@
|
|||
# extended list of dependencies dor development and run lint and tests
|
||||
-r ./requirements.txt
|
||||
|
||||
cloudpickle
|
||||
cloudpickle>=1.2
|
|
@ -1,9 +1,11 @@
|
|||
import glob
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
import types
|
||||
from argparse import Namespace
|
||||
|
||||
import cloudpickle
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
@ -671,10 +673,12 @@ def test_gradient_clipping(tmpdir):
|
|||
grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
|
||||
assert (grad_norm - 1.0).abs() < 0.01, "Gradient norm != 1.0: {grad_norm}".format(grad_norm=grad_norm)
|
||||
|
||||
trainer = Trainer(max_steps=1,
|
||||
max_epochs=1,
|
||||
gradient_clip_val=1.0,
|
||||
default_root_dir=tmpdir)
|
||||
trainer = Trainer(
|
||||
max_steps=1,
|
||||
max_epochs=1,
|
||||
gradient_clip_val=1.0,
|
||||
default_root_dir=tmpdir
|
||||
)
|
||||
|
||||
# for the test
|
||||
model.optimizer_step = _optimizer_step
|
||||
|
@ -824,3 +828,12 @@ def test_trainer_subclassing():
|
|||
# when we pass in an unknown arg, the base class should complain
|
||||
with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'abcdefg'"):
|
||||
TrainerSubclass(abcdefg='unknown_arg')
|
||||
|
||||
|
||||
def test_trainer_pickle(tmpdir):
|
||||
trainer = Trainer(
|
||||
max_epochs=1,
|
||||
default_root_dir=tmpdir
|
||||
)
|
||||
pickle.dumps(trainer)
|
||||
cloudpickle.dumps(trainer)
|
||||
|
|
Loading…
Reference in New Issue