Update README.md

This commit is contained in:
William Falcon 2019-06-29 18:05:17 -04:00 committed by GitHub
parent 3ab8120f27
commit 614d84e560
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 18 additions and 4 deletions

View File

@ -36,7 +36,7 @@ To use lightning do 2 things:
2. [Define a LightningModel](https://github.com/williamFalcon/pytorch-lightning/blob/master/examples/new_project_templates/lightning_module_template.py).
## What does lightning control for me?
Everything! Except the following three things:
Everything! Except you define your data and what happens inside the training and validation loop.
**What happens in the training loop**
@ -46,8 +46,22 @@ def training_step(self, data_batch, batch_nb):
x, y = data_batch
# define your own forward and loss calculation
out = self.forward(x)
loss = my_loss(out, y)
hidden_states = self.encoder(x)
# even as complex as a seq-2seq + attn model
# (this is just a toy, non-working example to illustrate)
start_token = '<SOS>'
last_hidden = torch.zeros(...)
loss = 0
for step in range(max_seq_len):
attn_context = self.attention_nn(hidden_states, start_token)
pred = self.decoder(start_token, attn_context, last_hidden)
last_hidden = pred
pred = self.predict_nn(pred)
loss += self.loss(last_hidden, y[step])
#toy example as well
loss = loss / max_seq_len
return {'loss': loss}
```
@ -58,7 +72,7 @@ def training_step(self, data_batch, batch_nb):
def validation_step(self, data_batch, batch_nb):
x, y = data_batch
# define your own forward and loss calculation
# or as basic as a CNN classification
out = self.forward(x)
loss = my_loss(out, y)
return {'loss': loss}