diff --git a/README.md b/README.md index 919bb72a3d..e85e73098a 100644 --- a/README.md +++ b/README.md @@ -38,13 +38,20 @@ To use lightning do 2 things: - Automatic training loop ```python # define what happens for training here -def training_step(self, data_batch, batch_nb): +def training_step(self, data_batch, batch_nb): + x, y = data_batch + out = self.forward(x) + loss = my_loss(out, y) + return {'loss': loss} ``` - Automatic validation loop ```python # define what happens for validation here -def validation_step(self, data_batch, batch_nb): +def validation_step(self, data_batch, batch_nb): x, y = data_batch + out = self.forward(x) + loss = my_loss(out, y) + return {'loss': loss} ``` - Automatic early stopping