import torch

from pytorch_lightning import Trainer
from pytorch_lightning.metrics import Metric, MetricCollection
from tests.helpers.boring_model import BoringModel


class SumMetric(Metric):
    """Toy metric whose state ``x`` accumulates the sum of every value passed to ``update``."""

    def __init__(self):
        super().__init__()
        # Scalar state starting at 0; dist_reduce_fx="sum" combines per-process states by summation.
        self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x):
        self.x += x

    def compute(self):
        return self.x


class DiffMetric(Metric):
    """Toy metric whose state ``x`` accumulates the negated sum of every value passed to ``update``."""

    def __init__(self):
        super().__init__()
        # Scalar state starting at 0; dist_reduce_fx="sum" combines per-process states by summation.
        self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x):
        self.x -= x

    def compute(self):
        return self.x


def test_metric_lightning(tmpdir):
    """Test that a metric updated in ``training_step`` matches a manually tracked sum at epoch end."""

    class TestModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.metric = SumMetric()
            self.sum = 0.0

        def training_step(self, batch, batch_idx):
            x = batch
            self.metric(x.sum())
            self.sum += x.sum()
            return self.step(x)

        def training_epoch_end(self, outs):
            assert torch.allclose(self.sum, self.metric.compute())
            # Reset both accumulators so the next epoch's comparison starts from zero.
            self.sum = 0.0
            self.metric.reset()

    model = TestModel()
    model.val_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=2,
        log_every_n_steps=1,
        weights_summary=None,
    )
    trainer.fit(model)


def test_metric_lightning_log(tmpdir):
    """Test logging a metric object and that the metric state gets reset after each epoch."""

    class TestModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.metric_step = SumMetric()
            self.metric_epoch = SumMetric()
            self.sum = 0.0

        def on_epoch_start(self):
            # The reference sum only covers the current epoch.
            self.sum = 0.0

        def training_step(self, batch, batch_idx):
            x = batch
            self.metric_step(x.sum())
            self.sum += x.sum()
            # Log the Metric object itself (not a tensor), aggregated at epoch level only.
            self.log("sum_step", self.metric_step, on_epoch=True, on_step=False)
            return {'loss': self.step(x), 'data': x}

        def training_epoch_end(self, outs):
            self.log("sum_epoch", self.metric_epoch(torch.stack([o['data'] for o in outs]).sum()))

    model = TestModel()
    model.val_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=2,
        log_every_n_steps=1,
        weights_summary=None,
    )
    trainer.fit(model)

    logged = trainer.logged_metrics
    # model.sum holds only the last epoch's total (reset in on_epoch_start), so matching it
    # checks that the logged metric did not carry state over from the first epoch.
    # torch.as_tensor accepts both Python numbers and tensors without a copy-construct warning.
    assert torch.allclose(torch.as_tensor(logged["sum_step"]), model.sum)
    assert torch.allclose(torch.as_tensor(logged["sum_epoch"]), model.sum)


def test_scriptable(tmpdir):
    """Test that a module holding an (unused-in-forward) metric can still be exported to TorchScript."""

    class TestModel(BoringModel):

        def __init__(self):
            super().__init__()
            # the metric is not used in the module's `forward`
            # so the module should be exportable to TorchScript
            self.metric = SumMetric()
            self.sum = 0.0

        def training_step(self, batch, batch_idx):
            x = batch
            self.metric(x.sum())
            self.sum += x.sum()
            self.log("sum", self.metric, on_epoch=True, on_step=False)
            return self.step(x)

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        log_every_n_steps=1,
        weights_summary=None,
        logger=False,
        checkpoint_callback=False,
    )
    trainer.fit(model)
    rand_input = torch.randn(10, 32)

    script_model = model.to_torchscript()

    # test that we can still do inference
    output = model(rand_input)
    script_output = script_model(rand_input)
    assert torch.allclose(output, script_output)


def test_metric_collection_lightning_log(tmpdir):
    """Test logging the results of a ``MetricCollection`` at step and epoch level."""

    class TestModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.metric = MetricCollection([SumMetric(), DiffMetric()])
            self.sum = 0.0
            self.diff = 0.0

        def training_step(self, batch, batch_idx):
            x = batch
            metric_vals = self.metric(x.sum())
            self.sum += x.sum()
            self.diff -= x.sum()
            self.log_dict({f'{k}_step': v for k, v in metric_vals.items()})
            return self.step(x)

        def training_epoch_end(self, outputs):
            metric_vals = self.metric.compute()
            self.log_dict({f'{k}_epoch': v for k, v in metric_vals.items()})

    model = TestModel()
    model.val_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        log_every_n_steps=1,
        weights_summary=None,
    )
    trainer.fit(model)

    logged = trainer.logged_metrics
    # The collection keys results by metric class name, hence "SumMetric_epoch" / "DiffMetric_epoch".
    assert torch.allclose(torch.as_tensor(logged["SumMetric_epoch"]), model.sum)
    assert torch.allclose(torch.as_tensor(logged["DiffMetric_epoch"]), model.diff)