updated readme
This commit is contained in:
parent
5e41159b16
commit
a59f351ef8
37
README.md
37
README.md
|
@ -93,7 +93,9 @@ class CoolSystem(pl.LightningModule):
|
|||
# REQUIRED
|
||||
x, y = batch
|
||||
y_hat = self.forward(x)
|
||||
return {'loss': F.cross_entropy(y_hat, y)}
|
||||
loss = F.cross_entropy(y_hat, y)
|
||||
tensorboard_logs = {'train_loss': loss}
|
||||
return {'loss': loss, 'log': tensorboard_logs}
|
||||
|
||||
def validation_step(self, batch, batch_nb):
|
||||
# OPTIONAL
|
||||
|
@ -104,7 +106,8 @@ class CoolSystem(pl.LightningModule):
|
|||
def validation_end(self, outputs):
|
||||
# OPTIONAL
|
||||
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
|
||||
return {'avg_val_loss': avg_loss}
|
||||
tensorboard_logs = {'val_loss': avg_loss}
|
||||
return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
|
||||
|
||||
def configure_optimizers(self):
|
||||
# REQUIRED
|
||||
|
@ -138,30 +141,27 @@ trainer = Trainer()
|
|||
trainer.fit(model)
|
||||
```
|
||||
|
||||
Or with tensorboard logger and some options turned on such as multi-gpu, etc...
|
||||
Trainer sets up a tensorboard logger, early stopping and checkpointing by default (you can modify all of them or
|
||||
use something other than tensorboard).
|
||||
|
||||
Here are more advanced examples
|
||||
```python
|
||||
from test_tube import Experiment
|
||||
|
||||
# PyTorch summarywriter with a few bells and whistles
|
||||
exp = Experiment(save_dir=os.getcwd())
|
||||
|
||||
# train on cpu using only 10% of the data (for demo purposes)
|
||||
# pass in experiment for automatic tensorboard logging.
|
||||
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1)
|
||||
trainer = Trainer(max_nb_epochs=1, train_percent_check=0.1)
|
||||
|
||||
# train on 4 gpus (lightning chooses GPUs for you)
|
||||
# trainer = Trainer(experiment=exp, max_nb_epochs=1, gpus=4)
|
||||
# trainer = Trainer(max_nb_epochs=1, gpus=4)
|
||||
|
||||
# train on 4 gpus (you choose GPUs)
|
||||
# trainer = Trainer(experiment=exp, max_nb_epochs=1, gpus=[0, 1, 3, 7])
|
||||
# trainer = Trainer(max_nb_epochs=1, gpus=[0, 1, 3, 7])
|
||||
|
||||
# train on 32 gpus across 4 nodes (make sure to submit appropriate SLURM job)
|
||||
# trainer = Trainer(experiment=exp, max_nb_epochs=1, gpus=8, nb_gpu_nodes=4)
|
||||
# trainer = Trainer(max_nb_epochs=1, gpus=8, nb_gpu_nodes=4)
|
||||
|
||||
# train (1 epoch only here for demo)
|
||||
trainer.fit(model)
|
||||
|
||||
# view tensorflow logs
|
||||
# view tensorboard logs
|
||||
print('View tensorboard logs by running\ntensorboard --logdir %s' % os.getcwd())
|
||||
print('and going to http://localhost:6006 on your browser')
|
||||
```
|
||||
|
@ -176,7 +176,7 @@ trainer.test()
|
|||
Everything in gray!
|
||||
You define the blue parts using the LightningModule interface:
|
||||
|
||||
![Ouverview](./docs/source/_static/overview_flat.jpg)
|
||||
![Overview](./docs/source/_static/overview_flat.jpg)
|
||||
|
||||
```python
|
||||
# what to do in the training loop
|
||||
|
@ -251,12 +251,13 @@ def validation_end(self, outputs):
|
|||
|
||||
val_loss_mean /= len(outputs)
|
||||
val_acc_mean /= len(outputs)
|
||||
tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
|
||||
return tqdm_dict
|
||||
logs = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
|
||||
result = {'log': logs}
|
||||
return result
|
||||
```
|
||||
|
||||
## Tensorboard
|
||||
Lightning is fully integrated with tensorboard.
|
||||
Lightning is fully integrated with tensorboard, MLFlow and supports any logging module.
|
||||
|
||||
![tensorboard-support](./docs/source/_static/tf_loss.png)
|
||||
|
||||
|
|
Loading…
Reference in New Issue