189 lines
5.0 KiB
Python
189 lines
5.0 KiB
Python
import torch
|
|
from torchmetrics import Metric, MetricCollection
|
|
|
|
from pytorch_lightning import Trainer
|
|
from tests.helpers.boring_model import BoringModel
|
|
|
|
|
|
class SumMetric(Metric):
    """Toy metric that keeps a running total of every value given to ``update``."""

    def __init__(self):
        super().__init__()
        # Scalar accumulator; when states are synced across processes the
        # per-process values are combined with a sum.
        initial_total = torch.tensor(0.0)
        self.add_state("x", initial_total, dist_reduce_fx="sum")

    def update(self, x):
        """Add ``x`` to the running total (in place)."""
        self.x += x

    def compute(self):
        """Return the running total accumulated so far."""
        total = self.x
        return total
|
|
|
|
|
|
class DiffMetric(Metric):
    """Toy metric that subtracts every value given to ``update`` from a running total."""

    def __init__(self):
        super().__init__()
        # Scalar accumulator starting at zero; cross-process sync sums the
        # per-process values.
        initial_total = torch.tensor(0.0)
        self.add_state("x", initial_total, dist_reduce_fx="sum")

    def update(self, x):
        """Subtract ``x`` from the running total (in place)."""
        self.x -= x

    def compute(self):
        """Return the running (negated) total accumulated so far."""
        total = self.x
        return total
|
|
|
|
|
|
def test_metric_lightning(tmpdir):
    """Check that a metric updated in ``training_step`` matches a hand-rolled sum every epoch."""

    class TestModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.metric = SumMetric()
            # Reference value accumulated manually alongside the metric.
            self.sum = 0.0

        def training_step(self, batch, batch_idx):
            batch_total = batch.sum()
            self.metric(batch_total)
            self.sum += batch_total
            return self.step(batch)

        def training_epoch_end(self, outs):
            expected = self.sum
            assert torch.allclose(expected, self.metric.compute())
            # Start the next epoch from a clean slate on both sides.
            self.metric.reset()
            self.sum = 0.0

    model = TestModel()
    # No validation loop: only the training-epoch bookkeeping is under test.
    model.val_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        log_every_n_steps=1,
        weights_summary=None,
    )
    trainer.fit(model)
|
|
|
|
|
|
def test_metric_lightning_log(tmpdir):
    """Test logging a metric object and that the metric state gets reset after each epoch."""

    class TestModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.metric_step = SumMetric()
            self.metric_epoch = SumMetric()
            # Manual per-epoch reference sum, compared against the logged values.
            self.sum = 0.0

        def on_train_epoch_start(self):
            # Use the train-specific hook: ``on_epoch_start`` also fires for
            # validation/test epochs, which would clobber the reference sum
            # whenever an evaluation loop runs.
            self.sum = 0.0

        def training_step(self, batch, batch_idx):
            x = batch
            self.metric_step(x.sum())
            self.sum += x.sum()
            # Log the metric object itself so Lightning aggregates it over the epoch.
            self.log("sum_step", self.metric_step, on_epoch=True, on_step=False)
            return {'loss': self.step(x), 'data': x}

        def training_epoch_end(self, outs):
            # Push the whole epoch's raw data through a second metric in one call.
            self.log("sum_epoch", self.metric_epoch(torch.stack([o['data'] for o in outs]).sum()))

    model = TestModel()
    model.val_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=2,
        log_every_n_steps=1,
        weights_summary=None,
    )
    trainer.fit(model)

    logged = trainer.logged_metrics
    # ``torch.as_tensor`` accepts either a Python float or a tensor; ``torch.tensor``
    # warns when asked to copy-construct from an existing tensor.
    assert torch.allclose(torch.as_tensor(logged["sum_step"]), model.sum)
    assert torch.allclose(torch.as_tensor(logged["sum_epoch"]), model.sum)
|
|
|
|
|
|
def test_scriptable(tmpdir):
    """Check that a module owning a metric (unused in ``forward``) still exports to TorchScript."""

    class TestModel(BoringModel):

        def __init__(self):
            super().__init__()
            # The metric is only touched in ``training_step``, never in
            # ``forward``, so scripting the module should still succeed.
            self.metric = SumMetric()
            self.sum = 0.0

        def training_step(self, batch, batch_idx):
            batch_total = batch.sum()
            self.metric(batch_total)
            self.sum += batch_total
            self.log("sum", self.metric, on_epoch=True, on_step=False)
            return self.step(batch)

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        log_every_n_steps=1,
        weights_summary=None,
        logger=False,
        checkpoint_callback=False,
    )
    trainer.fit(model)
    rand_input = torch.randn(10, 32)

    script_model = model.to_torchscript()

    # The scripted module must produce the same inference output as eager mode.
    eager_out = model(rand_input)
    scripted_out = script_model(rand_input)
    assert torch.allclose(eager_out, scripted_out)
|
|
|
|
|
|
def test_metric_collection_lightning_log(tmpdir):
    """Test logging the outputs of a ``MetricCollection`` per step and per epoch."""

    class TestModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.metric = MetricCollection([SumMetric(), DiffMetric()])
            # Manual references for the two metrics in the collection.
            self.sum = 0.0
            self.diff = 0.0

        def training_step(self, batch, batch_idx):
            x = batch
            metric_vals = self.metric(x.sum())
            self.sum += x.sum()
            self.diff -= x.sum()
            # Keys of the collection's output are the metric class names.
            self.log_dict({f'{k}_step': v for k, v in metric_vals.items()})
            return self.step(x)

        def training_epoch_end(self, outputs):
            metric_vals = self.metric.compute()
            self.log_dict({f'{k}_epoch': v for k, v in metric_vals.items()})

    model = TestModel()
    model.val_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        log_every_n_steps=1,
        weights_summary=None,
    )
    trainer.fit(model)

    logged = trainer.logged_metrics
    # ``torch.as_tensor`` accepts either a Python float or a tensor; ``torch.tensor``
    # warns when asked to copy-construct from an existing tensor.
    assert torch.allclose(torch.as_tensor(logged["SumMetric_epoch"]), model.sum)
    assert torch.allclose(torch.as_tensor(logged["DiffMetric_epoch"]), model.diff)
|