diff --git a/docs/source-pytorch/common/evaluation_basic.rst b/docs/source-pytorch/common/evaluation_basic.rst index 3dc7867de1..b9cae8f826 100644 --- a/docs/source-pytorch/common/evaluation_basic.rst +++ b/docs/source-pytorch/common/evaluation_basic.rst @@ -49,7 +49,7 @@ To add a test loop, implement the **test_step** method of the LightningModule x = x.view(x.size(0), -1) z = self.encoder(x) x_hat = self.decoder(z) - test_loss = F.mse_loss(x_hat, x) + test_loss = F.mse_loss(x_hat, y) self.log("test_loss", test_loss) ---- @@ -109,7 +109,7 @@ To add a validation loop, implement the **validation_step** method of the Lightn x = x.view(x.size(0), -1) z = self.encoder(x) x_hat = self.decoder(z) - val_loss = F.mse_loss(x_hat, x) + val_loss = F.mse_loss(x_hat, y) self.log("val_loss", val_loss) ----