Update README.md
parent 3ab8120f27
commit 614d84e560

README.md: 22 lines changed
@@ -36,7 +36,7 @@ To use lightning do 2 things:
 2. [Define a LightningModel](https://github.com/williamFalcon/pytorch-lightning/blob/master/examples/new_project_templates/lightning_module_template.py).

 ## What does lightning control for me?

-Everything! Except the following three things:
+Everything! Except you define your data and what happens inside the training and validation loop.

 **What happens in the training loop**
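The line added above says the user supplies only the data and the loop bodies. For reference, here is a minimal sketch of the kind of module this README is describing. Only `training_step` and `validation_step`, with their `(data_batch, batch_nb)` signatures, are confirmed by this diff; the class name, the `pl` alias, `configure_optimizers`, and the layer sizes are assumptions taken from the linked lightning_module_template and may differ across versions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl  # alias and exposed attributes are assumptions; early releases differ


class CoolModel(pl.LightningModule):  # hypothetical class name
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 10)  # toy classifier standing in for "your model"

    def forward(self, x):
        return self.l1(x.view(x.size(0), -1))

    # you define what happens for one training batch
    def training_step(self, data_batch, batch_nb):
        x, y = data_batch
        loss = F.cross_entropy(self.forward(x), y)
        return {'loss': loss}

    # you define what happens for one validation batch
    def validation_step(self, data_batch, batch_nb):
        x, y = data_batch
        loss = F.cross_entropy(self.forward(x), y)
        return {'loss': loss}

    # method name taken from the linked template; may vary across versions
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    # ...plus the dataloader hooks, whose names changed across early releases
```

Everything else, per the "Everything!" answer above, is Lightning's job: it runs the training and validation loops and calls these hooks for each batch.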
@@ -46,8 +46,22 @@ def training_step(self, data_batch, batch_nb):
     x, y = data_batch

     # define your own forward and loss calculation
-    out = self.forward(x)
-    loss = my_loss(out, y)
+    hidden_states = self.encoder(x)
+
+    # even as complex as a seq2seq + attention model
+    # (this is just a toy, non-working example to illustrate)
+    start_token = '<SOS>'
+    last_hidden = torch.zeros(...)
+    loss = 0
+    for step in range(max_seq_len):
+        attn_context = self.attention_nn(hidden_states, start_token)
+        pred = self.decoder(start_token, attn_context, last_hidden)
+        last_hidden = pred
+        pred = self.predict_nn(pred)
+        loss += self.loss(pred, y[step])
+
+    # toy example as well
+    loss = loss / max_seq_len
     return {'loss': loss}
 ```
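The snippet above is labeled a non-working toy, so here is a self-contained, runnable stand-in for the same decoding loop, with plain torch modules playing the roles of `self.encoder`, `self.decoder`, and `self.predict_nn`, and dot-product attention standing in for `self.attention_nn`. All sizes, the fake data, and the cross-entropy loss are illustrative assumptions, not part of the commit.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_dim, vocab_size, max_seq_len, batch_size = 32, 100, 5, 4

encoder = nn.Linear(vocab_size, hidden_dim)      # plays the role of self.encoder
decoder = nn.GRUCell(hidden_dim, hidden_dim)     # plays the role of self.decoder, one step at a time
predict_nn = nn.Linear(hidden_dim, vocab_size)   # plays the role of self.predict_nn

x = torch.randn(batch_size, max_seq_len, vocab_size)      # fake source sequence
y = torch.randint(vocab_size, (max_seq_len, batch_size))  # fake target tokens, one per decoding step

hidden_states = encoder(x)                         # (batch, seq, hidden)
last_hidden = torch.zeros(batch_size, hidden_dim)  # initial decoder state
loss = 0.0
for step in range(max_seq_len):
    # dot-product attention over the encoder states, standing in for self.attention_nn
    scores = torch.einsum('bsh,bh->bs', hidden_states, last_hidden)
    attn_context = (F.softmax(scores, dim=1).unsqueeze(-1) * hidden_states).sum(dim=1)
    last_hidden = decoder(attn_context, last_hidden)
    logits = predict_nn(last_hidden)
    loss = loss + F.cross_entropy(logits, y[step])

loss = loss / max_seq_len  # average over decoding steps, as in the snippet above
# inside training_step you would simply `return {'loss': loss}` and let Lightning handle backward
```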
@@ -58,7 +72,7 @@ def training_step(self, data_batch, batch_nb):
 def validation_step(self, data_batch, batch_nb):
     x, y = data_batch

-    # define your own forward and loss calculation
+    # or as basic as a CNN classification
     out = self.forward(x)
     loss = my_loss(out, y)
     return {'loss': loss}
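Tying the two hooks back to the "Everything!" answer: once a module defines them, along with its data and optimizer, the Trainer runs the actual loops. A minimal usage sketch, reusing the hypothetical `CoolModel` from the earlier sketch; the import path shown matches later releases, and very early versions expose `Trainer` under a different module path:

```python
from pytorch_lightning import Trainer  # import path of later releases; early versions differ

model = CoolModel()   # hypothetical LightningModule implementing training_step / validation_step
trainer = Trainer()   # Lightning owns the epoch and batch loops around your hooks
trainer.fit(model)    # runs training, calling training_step / validation_step for each batch
```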