fixes metrics pickle issue (#3921)

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
This commit is contained in:
Ananya Harsh Jha 2020-10-06 20:33:57 -04:00 committed by GitHub
parent db0e295f67
commit 4cd14c4237
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 0 deletions

View File

@ -218,3 +218,13 @@ class Metric(nn.Module, ABC):
setattr(self, attr, deepcopy(default).to(current_val.device))
else:
setattr(self, attr, deepcopy(default))
def __getstate__(self):
# ignore update and compute functions for pickling
return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]}
def __setstate__(self, state):
# manually restore update and compute functions for pickling
self.__dict__.update(state)
self.update = self._wrap_update(self.update)
self.compute = self._wrap_compute(self.compute)

View File

@ -4,6 +4,9 @@ from pytorch_lightning.metrics.metric import Metric
import os
import numpy as np
import pickle
import cloudpickle
torch.manual_seed(42)
@ -106,3 +109,30 @@ def test_compute():
# called without update, should return cached value
a._computed = 5
assert a.compute() == 5
class ToPickle(Dummy):
def update(self, x):
self.x += x
def compute(self):
return self.x
def test_pickle(tmpdir):
# doesn't tests for DDP
a = ToPickle()
a.update(1)
metric_pickled = pickle.dumps(a)
metric_loaded = pickle.loads(metric_pickled)
assert metric_loaded.compute() == 1
metric_loaded.update(5)
assert metric_loaded.compute() == 5
metric_pickled = cloudpickle.dumps(a)
metric_loaded = cloudpickle.loads(metric_pickled)
assert metric_loaded.compute() == 1

View File

@ -3,6 +3,7 @@ import numpy as np
import os
import sys
import pytest
import pickle
NUM_PROCESSES = 2
NUM_BATCHES = 10
@ -19,6 +20,10 @@ def _compute_batch(rank, preds, target, metric_class, sk_metric, ddp_sync_on_ste
metric = metric_class(compute_on_step=True, ddp_sync_on_step=ddp_sync_on_step, **metric_args)
# verify metrics work after being loaded from pickled state
pickled_metric = pickle.dumps(metric)
metric = pickle.loads(pickled_metric)
# Only use ddp if world size
if worldsize > 1:
setup_ddp(rank, worldsize)