fixes metrics pickle issue (#3921)
Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
This commit is contained in:
parent
db0e295f67
commit
4cd14c4237
|
@ -218,3 +218,13 @@ class Metric(nn.Module, ABC):
|
||||||
setattr(self, attr, deepcopy(default).to(current_val.device))
|
setattr(self, attr, deepcopy(default).to(current_val.device))
|
||||||
else:
|
else:
|
||||||
setattr(self, attr, deepcopy(default))
|
setattr(self, attr, deepcopy(default))
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
# ignore update and compute functions for pickling
|
||||||
|
return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]}
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
# manually restore update and compute functions for pickling
|
||||||
|
self.__dict__.update(state)
|
||||||
|
self.update = self._wrap_update(self.update)
|
||||||
|
self.compute = self._wrap_compute(self.compute)
|
||||||
|
|
|
@ -4,6 +4,9 @@ from pytorch_lightning.metrics.metric import Metric
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
import pickle
|
||||||
|
import cloudpickle
|
||||||
|
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
|
|
||||||
|
|
||||||
|
@ -106,3 +109,30 @@ def test_compute():
|
||||||
# called without update, should return cached value
|
# called without update, should return cached value
|
||||||
a._computed = 5
|
a._computed = 5
|
||||||
assert a.compute() == 5
|
assert a.compute() == 5
|
||||||
|
|
||||||
|
|
||||||
|
class ToPickle(Dummy):
|
||||||
|
def update(self, x):
|
||||||
|
self.x += x
|
||||||
|
|
||||||
|
def compute(self):
|
||||||
|
return self.x
|
||||||
|
|
||||||
|
|
||||||
|
def test_pickle(tmpdir):
|
||||||
|
# doesn't tests for DDP
|
||||||
|
a = ToPickle()
|
||||||
|
a.update(1)
|
||||||
|
|
||||||
|
metric_pickled = pickle.dumps(a)
|
||||||
|
metric_loaded = pickle.loads(metric_pickled)
|
||||||
|
|
||||||
|
assert metric_loaded.compute() == 1
|
||||||
|
|
||||||
|
metric_loaded.update(5)
|
||||||
|
assert metric_loaded.compute() == 5
|
||||||
|
|
||||||
|
metric_pickled = cloudpickle.dumps(a)
|
||||||
|
metric_loaded = cloudpickle.loads(metric_pickled)
|
||||||
|
|
||||||
|
assert metric_loaded.compute() == 1
|
||||||
|
|
|
@ -3,6 +3,7 @@ import numpy as np
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import pytest
|
import pytest
|
||||||
|
import pickle
|
||||||
|
|
||||||
NUM_PROCESSES = 2
|
NUM_PROCESSES = 2
|
||||||
NUM_BATCHES = 10
|
NUM_BATCHES = 10
|
||||||
|
@ -19,6 +20,10 @@ def _compute_batch(rank, preds, target, metric_class, sk_metric, ddp_sync_on_ste
|
||||||
|
|
||||||
metric = metric_class(compute_on_step=True, ddp_sync_on_step=ddp_sync_on_step, **metric_args)
|
metric = metric_class(compute_on_step=True, ddp_sync_on_step=ddp_sync_on_step, **metric_args)
|
||||||
|
|
||||||
|
# verify metrics work after being loaded from pickled state
|
||||||
|
pickled_metric = pickle.dumps(metric)
|
||||||
|
metric = pickle.loads(pickled_metric)
|
||||||
|
|
||||||
# Only use ddp if world size
|
# Only use ddp if world size
|
||||||
if worldsize > 1:
|
if worldsize > 1:
|
||||||
setup_ddp(rank, worldsize)
|
setup_ddp(rank, worldsize)
|
||||||
|
|
Loading…
Reference in New Issue