fixes metrics pickle issue (#3921)
Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
This commit is contained in:
parent
db0e295f67
commit
4cd14c4237
|
@ -218,3 +218,13 @@ class Metric(nn.Module, ABC):
|
|||
setattr(self, attr, deepcopy(default).to(current_val.device))
|
||||
else:
|
||||
setattr(self, attr, deepcopy(default))
|
||||
|
||||
def __getstate__(self):
|
||||
# ignore update and compute functions for pickling
|
||||
return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]}
|
||||
|
||||
def __setstate__(self, state):
|
||||
# manually restore update and compute functions for pickling
|
||||
self.__dict__.update(state)
|
||||
self.update = self._wrap_update(self.update)
|
||||
self.compute = self._wrap_compute(self.compute)
|
||||
|
|
|
@ -4,6 +4,9 @@ from pytorch_lightning.metrics.metric import Metric
|
|||
import os
|
||||
import numpy as np
|
||||
|
||||
import pickle
|
||||
import cloudpickle
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
|
||||
|
@ -106,3 +109,30 @@ def test_compute():
|
|||
# called without update, should return cached value
|
||||
a._computed = 5
|
||||
assert a.compute() == 5
|
||||
|
||||
|
||||
class ToPickle(Dummy):
|
||||
def update(self, x):
|
||||
self.x += x
|
||||
|
||||
def compute(self):
|
||||
return self.x
|
||||
|
||||
|
||||
def test_pickle(tmpdir):
|
||||
# doesn't tests for DDP
|
||||
a = ToPickle()
|
||||
a.update(1)
|
||||
|
||||
metric_pickled = pickle.dumps(a)
|
||||
metric_loaded = pickle.loads(metric_pickled)
|
||||
|
||||
assert metric_loaded.compute() == 1
|
||||
|
||||
metric_loaded.update(5)
|
||||
assert metric_loaded.compute() == 5
|
||||
|
||||
metric_pickled = cloudpickle.dumps(a)
|
||||
metric_loaded = cloudpickle.loads(metric_pickled)
|
||||
|
||||
assert metric_loaded.compute() == 1
|
||||
|
|
|
@ -3,6 +3,7 @@ import numpy as np
|
|||
import os
|
||||
import sys
|
||||
import pytest
|
||||
import pickle
|
||||
|
||||
NUM_PROCESSES = 2
|
||||
NUM_BATCHES = 10
|
||||
|
@ -19,6 +20,10 @@ def _compute_batch(rank, preds, target, metric_class, sk_metric, ddp_sync_on_ste
|
|||
|
||||
metric = metric_class(compute_on_step=True, ddp_sync_on_step=ddp_sync_on_step, **metric_args)
|
||||
|
||||
# verify metrics work after being loaded from pickled state
|
||||
pickled_metric = pickle.dumps(metric)
|
||||
metric = pickle.loads(pickled_metric)
|
||||
|
||||
# Only use ddp if world size
|
||||
if worldsize > 1:
|
||||
setup_ddp(rank, worldsize)
|
||||
|
|
Loading…
Reference in New Issue