Rename variables (#124)
- data_batch → batch
- batch_i → batch_idx
- dataloader_i → dataloader_idx
- tng → training
- training_dataloader → train_dataloader
- add_log_row_interval → row_log_interval
- gradient_clip → gradient_clip_val
- prog → progress
- tqdm_dic → tqdm_dict

A minimal sketch of how a model and trainer read after this rename is shown below.
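The sketch below only illustrates the renamed names from the list above; the model, dataset choice, helper code, and exact import paths are assumptions for illustration, not code from this commit.

```python
# Hypothetical example showing the renamed API (names on the right of each
# comment are the old names removed by this commit).
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST


class RenamedAPIModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):      # was: (data_batch, batch_i)
        x, y = batch                                # was: x, y = data_batch
        loss = F.cross_entropy(self.forward(x), y)
        # 'prog' -> 'progress', 'tng_loss' -> 'training_loss'
        return {'loss': loss, 'progress': {'training_loss': loss}}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    @pl.data_loader
    def train_dataloader(self):                     # was: tng_dataloader
        return DataLoader(
            MNIST('.', train=True, download=True, transform=transforms.ToTensor()),
            batch_size=32)


# Trainer flags pick up the new names as well.
trainer = pl.Trainer(gradient_clip_val=0.5,         # was: gradient_clip
                     row_log_interval=10)           # was: add_log_row_interval
```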
This commit is contained in:
parent 3d16a686b3
commit b0a0a47a0b
README.md
@@ -110,7 +110,7 @@ class CoolSystem(pl.LightningModule):
return torch.optim.Adam(self.parameters(), lr=0.02)
@pl.data_loader
def tng_dataloader(self):
def train_dataloader(self):
# REQUIRED
return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

@@ -177,16 +177,16 @@ You define the blue parts using the LightningModule interface:
```python
# what to do in the training loop
def training_step(self, data_batch, batch_nb):
def training_step(self, batch, batch_nb):
# what to do in the validation loop
def validation_step(self, data_batch, batch_nb):
def validation_step(self, batch, batch_nb):
# how to aggregate validation_step outputs
def validation_end(self, outputs):
# and your dataloaders
def tng_dataloader():
def train_dataloader():
def val_dataloader():
def test_dataloader():
```

@@ -195,8 +195,8 @@ def test_dataloader():
```python
# define what happens for training here
def training_step(self, data_batch, batch_nb):
x, y = data_batch
def training_step(self, batch, batch_nb):
x, y = batch
# define your own forward and loss calculation
hidden_states = self.encoder(x)

@@ -222,8 +222,8 @@ def training_step(self, data_batch, batch_nb):
```python
# define what happens for validation here
def validation_step(self, data_batch, batch_nb):
x, y = data_batch
def validation_step(self, batch, batch_nb):
x, y = batch
# or as basic as a CNN classification
out = self.forward(x)

@@ -248,8 +248,8 @@ def validation_end(self, outputs):
val_loss_mean /= len(outputs)
val_acc_mean /= len(outputs)
tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
return tqdm_dic
tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
return tqdm_dict
```

## Tensorboard
@@ -10,7 +10,7 @@ Otherwise, to Define a Lightning Module, implement the following methods:
**Required**:
- [training_step](RequiredTrainerInterface.md#training_step)
- [tng_dataloader](RequiredTrainerInterface.md#tng_dataloader)
- [train_dataloader](RequiredTrainerInterface.md#train_dataloader)
- [configure_optimizers](RequiredTrainerInterface.md#configure_optimizers)
**Optional**:

@@ -23,7 +23,7 @@ Otherwise, to Define a Lightning Module, implement the following methods:
- [test_dataloader](RequiredTrainerInterface.md#test_dataloader)
- [on_save_checkpoint](RequiredTrainerInterface.md#on_save_checkpoint)
- [on_load_checkpoint](RequiredTrainerInterface.md#on_load_checkpoint)
- [update_tng_log_metrics](RequiredTrainerInterface.md#update_tng_log_metrics)
- [update_training_log_metrics](RequiredTrainerInterface.md#update_training_log_metrics)
- [add_model_specific_args](RequiredTrainerInterface.md#add_model_specific_args)
---

@@ -81,7 +81,7 @@ class CoolModel(pl.LightningModule):
return [torch.optim.Adam(self.parameters(), lr=0.02)]
@pl.data_loader
def tng_dataloader(self):
def train_dataloader(self):
return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
@pl.data_loader

@@ -111,7 +111,7 @@ The LightningModule interface is on the right. Each method corresponds to a part
### training_step
``` {.python}
def training_step(self, data_batch, batch_nb)
def training_step(self, batch, batch_nb)
```
In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something specific to your model.

@@ -120,7 +120,7 @@ In this step you'd normally do the forward pass and calculate the loss for a bat
| Param | description |
|---|---|
| data_batch | The output of your dataloader. A tensor, tuple or list |
| batch | The output of your dataloader. A tensor, tuple or list |
| batch_nb | Integer displaying which batch this is |
**Return**

@@ -130,14 +130,14 @@ Dictionary or OrderedDict
| key | value | is required |
|---|---|---|
| loss | tensor scalar | Y |
| prog | Dict for progress bar display. Must have only tensors | N |
| progress | Dict for progress bar display. Must have only tensors | N |
**Example**
``` {.python}
def training_step(self, data_batch, batch_nb):
x, y, z = data_batch
def training_step(self, batch, batch_nb):
x, y, z = batch
# implement your own
out = self.forward(x)

@@ -145,7 +145,7 @@ def training_step(self, data_batch, batch_nb):
output = {
'loss': loss, # required
'prog': {'tng_loss': loss, 'batch_nb': batch_nb} # optional
'progress': {'training_loss': loss, 'batch_nb': batch_nb} # optional
}
# return a dict

@@ -155,7 +155,7 @@ def training_step(self, data_batch, batch_nb):
If you define multiple optimizers, this step will also be called with an additional ```optimizer_idx``` param.
``` {.python}
# Multiple optimizers (ie: GANs)
def training_step(self, data_batch, batch_nb, optimizer_idx):
def training_step(self, batch, batch_nb, optimizer_idx):
if optimizer_idx == 0:
# do training_step with encoder
if optimizer_idx == 1:

@@ -163,11 +163,11 @@ def training_step(self, data_batch, batch_nb, optimizer_idx):
```
---
### tng_dataloader
### train_dataloader
``` {.python}
@pl.data_loader
def tng_dataloader(self)
def train_dataloader(self)
```
Called by lightning during training loop. Make sure to use the @pl.data_loader decorator, this ensures not calling this function until the data are needed.

@@ -178,7 +178,7 @@ PyTorch DataLoader
``` {.python}
@pl.data_loader
def tng_dataloader(self):
def train_dataloader(self):
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform, download=True)
loader = torch.utils.data.DataLoader(

@@ -240,10 +240,10 @@ the [optimizer_step](https://williamfalcon.github.io/pytorch-lightning/Trainer/h
``` {.python}
# if you have one val dataloader:
def validation_step(self, data_batch, batch_nb)
def validation_step(self, batch, batch_nb)
# if you have multiple val dataloaders:
def validation_step(self, data_batch, batch_nb, dataloader_idx)
def validation_step(self, batch, batch_nb, dataloader_idxdx)
```
**OPTIONAL**
If you don't need to validate you don't need to implement this method. In this step you'd normally generate examples or calculate anything of interest such as accuracy.

@@ -256,9 +256,9 @@ The dict you return here will be available in the `validation_end` method.
| Param | description |
|---|---|
| data_batch | The output of your dataloader. A tensor, tuple or list |
| batch | The output of your dataloader. A tensor, tuple or list |
| batch_nb | Integer displaying which batch this is |
| dataloader_i | Integer displaying which dataloader this is (only if multiple val datasets used) |
| dataloader_idx | Integer displaying which dataloader this is (only if multiple val datasets used) |
**Return**

@@ -270,8 +270,8 @@ The dict you return here will be available in the `validation_end` method.
``` {.python}
# CASE 1: A single validation dataset
def validation_step(self, data_batch, batch_nb):
x, y = data_batch
def validation_step(self, batch, batch_nb):
x, y = batch
# implement your own
out = self.forward(x)

@@ -302,7 +302,7 @@ If you pass in multiple validation datasets, validation_step will have an additi
```python
# CASE 2: multiple validation datasets
def validation_step(self, data_batch, batch_nb, dataset_idx):
def validation_step(self, batch, batch_nb, dataset_idx):
# dataset_idx tells you which dataset this is.
```

@@ -351,8 +351,8 @@ def validation_end(self, outputs):
val_loss_mean /= len(outputs)
val_acc_mean /= len(outputs)
tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
return tqdm_dic
tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
return tqdm_dict
```
With multiple dataloaders, `outputs` will be a list of lists. The outer list contains

@@ -377,18 +377,18 @@ def validation_end(self, outputs):
val_loss_mean /= i
val_acc_mean /= i
tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
return tqdm_dic
tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
return tqdm_dict
```
### test_step
``` {.python}
# if you have one test dataloader:
def test_step(self, data_batch, batch_nb)
def test_step(self, batch, batch_nb)
# if you have multiple test dataloaders:
def test_step(self, data_batch, batch_nb, dataloader_idx)
def test_step(self, batch, batch_nb, dataloader_idxdx)
```
**OPTIONAL**
If you don't need to test you don't need to implement this method. In this step you'd normally generate examples or calculate anything of interest such as accuracy.

@@ -403,9 +403,9 @@ This function is used when you execute `trainer.test()`.
| Param | description |
|---|---|
| data_batch | The output of your dataloader. A tensor, tuple or list |
| batch | The output of your dataloader. A tensor, tuple or list |
| batch_nb | Integer displaying which batch this is |
| dataloader_i | Integer displaying which dataloader this is (only if multiple test datasets used) |
| dataloader_idx | Integer displaying which dataloader this is (only if multiple test datasets used) |
**Return**

@@ -417,8 +417,8 @@ This function is used when you execute `trainer.test()`.
``` {.python}
# CASE 1: A single test dataset
def test_step(self, data_batch, batch_nb):
x, y = data_batch
def test_step(self, batch, batch_nb):
x, y = batch
# implement your own
out = self.forward(x)

@@ -443,7 +443,7 @@ If you pass in multiple test datasets, test_step will have an additional argumen
```python
# CASE 2: multiple test datasets
def test_step(self, data_batch, batch_nb, dataset_idx):
def test_step(self, batch, batch_nb, dataset_idx):
# dataset_idx tells you which dataset this is.
```

@@ -490,8 +490,8 @@ def test_end(self, outputs):
test_loss_mean /= len(outputs)
test_acc_mean /= len(outputs)
tqdm_dic = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
return tqdm_dic
tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
return tqdm_dict
```
With multiple dataloaders, `outputs` will be a list of lists. The outer list contains

@@ -516,8 +516,8 @@ def test_end(self, outputs):
test_loss_mean /= i
test_acc_mean /= i
tqdm_dic = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
return tqdm_dic
tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
return tqdm_dict
```
---

@@ -633,10 +633,10 @@ def test_dataloader(self):
```
---
### update_tng_log_metrics
### update_training_log_metrics
``` {.python}
def update_tng_log_metrics(self, logs)
def update_training_log_metrics(self, logs)
```
Called by lightning right before it logs metrics for this batch.
This is a chance to amend or add to the metrics about to be logged.

@@ -647,7 +647,7 @@ Dict
**Example**
``` {.python}
def update_tng_log_metrics(self, logs):
def update_training_log_metrics(self, logs):
# modify or add to logs
return logs
```

@@ -674,7 +674,7 @@ def add_model_specific_args(parent_parser, root_dir):
parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
# param overwrites
# parser.set_defaults(gradient_clip=5.0)
# parser.set_defaults(gradient_clip_val=5.0)
# network params
parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False)
@@ -22,7 +22,7 @@ self.experiment.add_scalars(...)
Total training batches seen across all epochs
---
#### gradient_clip
#### gradient_clip_val
The current gradient clip value
---
@@ -13,7 +13,7 @@ trainer = Trainer(show_progress_bar=True)
Every k batches lightning will make an entry in the metrics log
``` {.python}
# DEFAULT (ie: save a .csv log file every 10 batches)
trainer = Trainer(add_log_row_interval=10)
trainer = Trainer(row_log_interval=10)
```
---
@@ -52,10 +52,10 @@ Specifically, this will [clip the gradient norm computed over all model paramete
``` {.python}
# DEFAULT (ie: don't clip)
trainer = Trainer(gradient_clip=0)
trainer = Trainer(gradient_clip_val=0)
# clip gradients with norm above 0.5
trainer = Trainer(gradient_clip=0.5)
trainer = Trainer(gradient_clip_val=0.5)
```
---
@@ -58,12 +58,12 @@ def on_post_performance_check(self):
```
---
#### on_tng_metrics
#### on_training_metrics
Called in the training loop, right before metrics are logged.
Although you can log at any time by using self.experiment, you can use
this callback to modify what will be logged.
```python
def on_tng_metrics(self, metrics):
def on_training_metrics(self, metrics):
# do something before validation end
```
@@ -79,14 +79,14 @@ class LightningTemplateModel(LightningModule):
nll = F.nll_loss(logits, labels)
return nll
def training_step(self, data_batch, batch_i):
def training_step(self, batch, batch_idx):
"""
Lightning calls this inside the training loop
:param data_batch:
:param batch:
:return:
"""
# forward pass
x, y = data_batch
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self.forward(x)

@@ -105,13 +105,13 @@ class LightningTemplateModel(LightningModule):
# can also return just a scalar instead of a dict (return loss_val)
return output
def validation_step(self, data_batch, batch_i):
def validation_step(self, batch, batch_idx):
"""
Lightning calls this inside the validation loop
:param data_batch:
:param batch:
:return:
"""
x, y = data_batch
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self.forward(x)

@@ -167,8 +167,8 @@ class LightningTemplateModel(LightningModule):
val_loss_mean /= len(outputs)
val_acc_mean /= len(outputs)
tqdm_dic = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean}
return tqdm_dic
tqdm_dict = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean}
return tqdm_dict
# ---------------------
# TRAINING SETUP

@@ -208,8 +208,8 @@ class LightningTemplateModel(LightningModule):
return loader
@pl.data_loader
def tng_dataloader(self):
print('tng data loader called')
def train_dataloader(self):
print('training data loader called')
return self.__dataloader(train=True)
@pl.data_loader

@@ -233,7 +233,7 @@ class LightningTemplateModel(LightningModule):
parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
# param overwrites
# parser.set_defaults(gradient_clip=5.0)
# parser.set_defaults(gradient_clip_val=5.0)
# network params
parser.add_argument('--in_features', default=28 * 28, type=int)
@@ -146,7 +146,7 @@ class GAN(pl.LightningModule):
return [opt_g, opt_d], []
@pl.data_loader
def tng_dataloader(self):
def train_dataloader(self):
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])])
dataset = MNIST(os.getcwd(), train=True, download=True, transform=transform)
@@ -10,7 +10,7 @@ class ModelHooks(torch.nn.Module):
"""
pass
def on_batch_start(self, data_batch):
def on_batch_start(self, batch):
pass
def on_batch_end(self):

@@ -28,7 +28,7 @@ class ModelHooks(torch.nn.Module):
def on_post_performance_check(self):
pass
def on_tng_metrics(self, metrics):
def on_training_metrics(self, metrics):
pass
def on_before_zero_grad(self, optimizer):
@@ -106,7 +106,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
optimizer.zero_grad()
@data_loader
def tng_dataloader(self):
def train_dataloader(self):
"""
Implement a PyTorch DataLoader
:return:
@@ -23,5 +23,5 @@ class LightningTestModel(LightningValidationMixin, LightningTestMixin, Lightning
Most common test case. Validation and test dataloaders
"""
def on_tng_metrics(self, logs):
def on_training_metrics(self, logs):
logs['some_tensor_to_test'] = torch.rand(1)
@@ -81,14 +81,14 @@ class LightningTestModelBase(LightningModule):
nll = F.nll_loss(logits, labels)
return nll
def training_step(self, data_batch, batch_i):
def training_step(self, batch, batch_idx):
"""
Lightning calls this inside the training loop
:param data_batch:
:param batch:
:return:
"""
# forward pass
x, y = data_batch
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self.forward(x)

@@ -104,7 +104,7 @@ class LightningTestModelBase(LightningModule):
if self.trainer.batch_nb % 1 == 0:
output = OrderedDict({
'loss': loss_val,
'prog': {'some_val': loss_val * loss_val}
'progress': {'some_val': loss_val * loss_val}
})
return output
if self.trainer.batch_nb % 2 == 0:

@@ -153,7 +153,7 @@ class LightningTestModelBase(LightningModule):
return loader
@data_loader
def tng_dataloader(self):
def train_dataloader(self):
return self._dataloader(train=True)
@staticmethod

@@ -167,7 +167,7 @@ class LightningTestModelBase(LightningModule):
parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
# param overwrites
# parser.set_defaults(gradient_clip=5.0)
# parser.set_defaults(gradient_clip_val=5.0)
# network params
parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False)
@@ -25,13 +25,13 @@ class LightningValidationStepMixin:
def val_dataloader(self):
return self._dataloader(train=False)
def validation_step(self, data_batch, batch_i):
def validation_step(self, batch, batch_idx):
"""
Lightning calls this inside the validation loop
:param data_batch:
:param batch:
:return:
"""
x, y = data_batch
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self.forward(x)

@@ -51,16 +51,16 @@ class LightningValidationStepMixin:
val_acc = val_acc.unsqueeze(0)
# alternate possible outputs to test
if batch_i % 1 == 0:
if batch_idx % 1 == 0:
output = OrderedDict({
'val_loss': loss_val,
'val_acc': val_acc,
})
return output
if batch_i % 2 == 0:
if batch_idx % 2 == 0:
return val_acc
if batch_i % 3 == 0:
if batch_idx % 3 == 0:
output = OrderedDict({
'val_loss': loss_val,
'val_acc': val_acc,

@@ -104,8 +104,8 @@ class LightningValidationMixin(LightningValidationStepMixin):
val_loss_mean /= len(outputs)
val_acc_mean /= len(outputs)
tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
return tqdm_dic
tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
return tqdm_dict
class LightningValidationStepMultipleDataloadersMixin:

@@ -118,13 +118,13 @@ class LightningValidationStepMultipleDataloadersMixin:
def val_dataloader(self):
return [self._dataloader(train=False), self._dataloader(train=False)]
def validation_step(self, data_batch, batch_i, dataloader_i):
def validation_step(self, batch, batch_idx, dataloader_idx):
"""
Lightning calls this inside the validation loop
:param data_batch:
:param batch:
:return:
"""
x, y = data_batch
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self.forward(x)

@@ -144,26 +144,26 @@ class LightningValidationStepMultipleDataloadersMixin:
val_acc = val_acc.unsqueeze(0)
# alternate possible outputs to test
if batch_i % 1 == 0:
if batch_idx % 1 == 0:
output = OrderedDict({
'val_loss': loss_val,
'val_acc': val_acc,
})
return output
if batch_i % 2 == 0:
if batch_idx % 2 == 0:
return val_acc
if batch_i % 3 == 0:
if batch_idx % 3 == 0:
output = OrderedDict({
'val_loss': loss_val,
'val_acc': val_acc,
'test_dic': {'val_loss_a': loss_val}
})
return output
if batch_i % 5 == 0:
if batch_idx % 5 == 0:
output = OrderedDict({
f'val_loss_{dataloader_i}': loss_val,
f'val_acc_{dataloader_i}': val_acc,
f'val_loss_{dataloader_idx}': loss_val,
f'val_acc_{dataloader_idx}': val_acc,
})
return output

@@ -206,8 +206,8 @@ class LightningValidationMultipleDataloadersMixin(LightningValidationStepMultipl
val_loss_mean /= i
val_acc_mean /= i
tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
return tqdm_dic
tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
return tqdm_dict
class LightningTestStepMixin:

@@ -216,13 +216,13 @@ class LightningTestStepMixin:
def test_dataloader(self):
return self._dataloader(train=False)
def test_step(self, data_batch, batch_i):
def test_step(self, batch, batch_idx):
"""
Lightning calls this inside the validation loop
:param data_batch:
:param batch:
:return:
"""
x, y = data_batch
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self.forward(x)

@@ -242,16 +242,16 @@ class LightningTestStepMixin:
test_acc = test_acc.unsqueeze(0)
# alternate possible outputs to test
if batch_i % 1 == 0:
if batch_idx % 1 == 0:
output = OrderedDict({
'test_loss': loss_test,
'test_acc': test_acc,
})
return output
if batch_i % 2 == 0:
if batch_idx % 2 == 0:
return test_acc
if batch_i % 3 == 0:
if batch_idx % 3 == 0:
output = OrderedDict({
'test_loss': loss_test,
'test_acc': test_acc,

@@ -290,8 +290,8 @@ class LightningTestMixin(LightningTestStepMixin):
test_loss_mean /= len(outputs)
test_acc_mean /= len(outputs)
tqdm_dic = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
return tqdm_dic
tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
return tqdm_dict
class LightningTestStepMultipleDataloadersMixin:

@@ -300,13 +300,13 @@ class LightningTestStepMultipleDataloadersMixin:
def test_dataloader(self):
return [self._dataloader(train=False), self._dataloader(train=False)]
def test_step(self, data_batch, batch_i, dataloader_i):
def test_step(self, batch, batch_idx, dataloader_idx):
"""
Lightning calls this inside the validation loop
:param data_batch:
:param batch:
:return:
"""
x, y = data_batch
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self.forward(x)

@@ -326,26 +326,26 @@ class LightningTestStepMultipleDataloadersMixin:
test_acc = test_acc.unsqueeze(0)
# alternate possible outputs to test
if batch_i % 1 == 0:
if batch_idx % 1 == 0:
output = OrderedDict({
'test_loss': loss_test,
'test_acc': test_acc,
})
return output
if batch_i % 2 == 0:
if batch_idx % 2 == 0:
return test_acc
if batch_i % 3 == 0:
if batch_idx % 3 == 0:
output = OrderedDict({
'test_loss': loss_test,
'test_acc': test_acc,
'test_dic': {'test_loss_a': loss_test}
})
return output
if batch_i % 5 == 0:
if batch_idx % 5 == 0:
output = OrderedDict({
f'test_loss_{dataloader_i}': loss_test,
f'test_acc_{dataloader_i}': test_acc,
f'test_loss_{dataloader_idx}': loss_test,
f'test_acc_{dataloader_idx}': test_acc,
})
return output

@@ -383,5 +383,5 @@ class LightningTestMultipleDataloadersMixin(LightningTestStepMultipleDataloaders
test_loss_mean /= i
test_acc_mean /= i
tqdm_dic = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
return tqdm_dic
tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
return tqdm_dict
@@ -57,7 +57,7 @@ class Trainer(TrainerIO):
experiment=None,
early_stop_callback=None,
checkpoint_callback=None,
gradient_clip=0,
gradient_clip_val=0,
process_position=0,
nb_gpu_nodes=1,
gpus=None,

@@ -75,7 +75,7 @@ class Trainer(TrainerIO):
test_percent_check=1.0,
val_check_interval=1.0,
log_save_interval=100,
add_log_row_interval=10,
row_log_interval=10,
distributed_backend=None,
use_amp=False,
print_nan_grads=False,

@@ -88,7 +88,7 @@ class Trainer(TrainerIO):
:param experiment: Test-tube experiment
:param early_stop_callback: Callback for early stopping
:param checkpoint_callback: Callback for checkpointing
:param gradient_clip: int. 0 means don't clip.
:param gradient_clip_val: int. 0 means don't clip.
:param process_position: shown in the tqdm bar
:param nb_gpu_nodes: number of GPU nodes
:param gpus: int. (ie: 2 gpus) OR list to specify which GPUs [0, 1] or '0,1'

@@ -106,7 +106,7 @@ class Trainer(TrainerIO):
:param test_percent_check: int. How much of test set to check
:param val_check_interval: int. Check val this frequently within a train epoch
:param log_save_interval: int. Writes logs to disk this often
:param add_log_row_interval: int. How often to add logging rows
:param row_log_interval: int. How often to add logging rows
:param distributed_backend: str. dp, or ddp.
:param use_amp: Bool. If true uses apex for 16bit precision
:param print_nan_grads: Bool. Prints nan gradients

@@ -118,7 +118,7 @@ class Trainer(TrainerIO):
# Transfer params
self.nb_gpu_nodes = nb_gpu_nodes
self.log_gpu_memory = log_gpu_memory
self.gradient_clip = gradient_clip
self.gradient_clip_val = gradient_clip_val
self.check_val_every_n_epoch = check_val_every_n_epoch
self.enable_early_stop = early_stop_callback is not None
self.track_grad_norm = track_grad_norm

@@ -138,9 +138,9 @@ class Trainer(TrainerIO):
self.batch_nb = 0
self.tqdm_metrics = {}
self.nb_val_batches = 0
self.nb_tng_batches = 0
self.nb_training_batches = 0
self.nb_test_batches = 0
self.tng_dataloader = None
self.train_dataloader = None
self.test_dataloader = None
self.val_dataloader = None

@@ -188,13 +188,13 @@ class Trainer(TrainerIO):
self.__set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)
# can't init progress bar here because starting a new process
# means the prog_bar won't survive pickling
# means the progress_bar won't survive pickling
self.show_progress_bar = show_progress_bar
# logging
self.log_save_interval = log_save_interval
self.val_check_interval = val_check_interval
self.add_log_row_interval = add_log_row_interval
self.row_log_interval = row_log_interval
# how much of the data to use
self.__determine_data_use_amount(train_percent_check, val_percent_check,

@@ -405,36 +405,36 @@ class Trainer(TrainerIO):
return is_overriden
@property
def __tng_tqdm_dic(self):
tqdm_dic = {
def __training_tqdm_dict(self):
tqdm_dict = {
'loss': '{0:.3f}'.format(self.avg_loss),
'epoch': '{}'.format(self.current_epoch),
'batch_nb': '{}'.format(self.batch_nb),
}
if self.experiment is not None:
tqdm_dic['v_nb'] = self.experiment.version
tqdm_dict['v_nb'] = self.experiment.version
tqdm_dic.update(self.tqdm_metrics)
tqdm_dict.update(self.tqdm_metrics)
if self.on_gpu:
tqdm_dic['gpu'] = '{}'.format(torch.cuda.current_device())
tqdm_dict['gpu'] = '{}'.format(torch.cuda.current_device())
return tqdm_dic
return tqdm_dict
@property
def tng_tqdm_dic(self):
def training_tqdm_dict(self):
"""
Read-only for tqdm metrics
:return:
"""
return self.__tng_tqdm_dic
return self.__training_tqdm_dict
def __layout_bookeeping(self):
# determine number of training batches
self.nb_tng_batches = len(self.tng_dataloader)
self.nb_tng_batches = int(self.nb_tng_batches * self.train_percent_check)
self.nb_training_batches = len(self.train_dataloader)
self.nb_training_batches = int(self.nb_training_batches * self.train_percent_check)
# determine number of validation batches
# val datasets could be none, 1 or 2+

@@ -450,7 +450,7 @@ class Trainer(TrainerIO):
self.nb_test_batches = max(1, self.nb_test_batches)
# determine when to check validation
self.val_check_batch = int(self.nb_tng_batches * self.val_check_interval)
self.val_check_batch = int(self.nb_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)
def __add_tqdm_metrics(self, metrics):

@@ -460,15 +460,15 @@ class Trainer(TrainerIO):
self.tqdm_metrics[k] = v
def __evaluation_forward(self, model, data_batch, batch_i, dataloader_i, test=False):
# make dataloader_i arg in validation_step optional
args = [data_batch, batch_i]
def __evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False):
# make dataloader_idx arg in validation_step optional
args = [batch, batch_idx]
if test and len(self.test_dataloader) > 1:
args.append(dataloader_i)
args.append(dataloader_idx)
elif not test and len(self.val_dataloader) > 1:
args.append(dataloader_i)
args.append(dataloader_idx)
# handle DP, DDP forward
if self.use_ddp or self.use_dp:

@@ -481,8 +481,8 @@ class Trainer(TrainerIO):
root_gpu = 0
if type(self.data_parallel_device_ids) is list:
root_gpu = self.data_parallel_device_ids[0]
data_batch = self.transfer_batch_to_gpu(data_batch, root_gpu)
args[0] = data_batch
batch = self.transfer_batch_to_gpu(batch, root_gpu)
args[0] = batch
if test:
output = model.test_step(*args)

@@ -497,7 +497,7 @@ class Trainer(TrainerIO):
:param model: PT model
:param dataloaders: list of PT dataloaders
:param max_batches: Scalar
:param dataloader_i:
:param dataloader_idx:
:param test: boolean
:return:
"""

@@ -512,21 +512,21 @@ class Trainer(TrainerIO):
outputs = []
# run training
for dataloader_i, dl in enumerate(dataloaders):
for dataloader_idx, dl in enumerate(dataloaders):
dl_outputs = []
for batch_i, data_batch in enumerate(dl):
for batch_idx, batch in enumerate(dl):
if data_batch is None:  # pragma: no cover
if batch is None:  # pragma: no cover
continue
# stop short when on fast_dev_run (sets max_batch=1)
if batch_i >= max_batches:
if batch_idx >= max_batches:
break
# -----------------
# RUN EVALUATION STEP
# -----------------
output = self.__evaluation_forward(model, data_batch, batch_i, dataloader_i,
output = self.__evaluation_forward(model, batch, batch_idx, dataloader_idx,
test)
# track outputs for collation

@@ -563,7 +563,7 @@ class Trainer(TrainerIO):
:return:
"""
self.tng_dataloader = model.tng_dataloader
self.train_dataloader = model.train_dataloader
self.test_dataloader = model.test_dataloader
self.val_dataloader = model.val_dataloader

@@ -576,7 +576,7 @@ class Trainer(TrainerIO):
if have_val_loaders and not isinstance(self.val_dataloader, list):
self.val_dataloader = [self.val_dataloader]
if self.use_ddp and not isinstance(self.tng_dataloader.sampler, DistributedSampler):
if self.use_ddp and not isinstance(self.train_dataloader.sampler, DistributedSampler):
msg = """
You're using multiple gpus and multiple nodes without using a DistributedSampler
to assign a subset of your data to each process. To silence this warning, pass a

@@ -767,7 +767,7 @@ class Trainer(TrainerIO):
except Exception:
self.node_rank = 0
# show progbar only on prog_rank 0
# show progressbar only on progress_rank 0
self.show_progress_bar = self.show_progress_bar and self.node_rank == 0 and gpu_nb == 0
# determine which process we are and world size

@@ -929,7 +929,7 @@ class Trainer(TrainerIO):
for epoch_nb in range(self.current_epoch, self.max_nb_epochs):
# set seed for distributed sampler (enables shuffling for each epoch)
if self.use_ddp:
self.tng_dataloader.sampler.set_epoch(epoch_nb)
self.train_dataloader.sampler.set_epoch(epoch_nb)
# get model
model = self.__get_model()

@@ -937,7 +937,7 @@ class Trainer(TrainerIO):
# update training progress in trainer and model
model.current_epoch = epoch_nb
self.current_epoch = epoch_nb
self.total_batches = self.nb_tng_batches + self.nb_val_batches
self.total_batches = self.nb_training_batches + self.nb_val_batches
self.batch_loss_value = 0  # accumulated grads
# init progress_bar when requested

@@ -950,7 +950,7 @@ class Trainer(TrainerIO):
# -----------------
# RUN TNG EPOCH
# -----------------
self.run_tng_epoch()
self.run_training_epoch()
# update LR schedulers
if self.lr_schedulers is not None:

@@ -961,20 +961,20 @@ class Trainer(TrainerIO):
met_min_epochs = epoch_nb > self.min_nb_epochs
if self.enable_early_stop and met_min_epochs:
should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch_nb,
logs=self.__tng_tqdm_dic)
logs=self.__training_tqdm_dict)
# stop training
stop = should_stop and met_min_epochs
if stop:
return
def run_tng_epoch(self):
def run_training_epoch(self):
# before epoch hook
if self.__is_function_implemented('on_epoch_start'):
model = self.__get_model()
model.on_epoch_start()
# run epoch
for batch_nb, data_batch in enumerate(self.tng_dataloader):
for batch_nb, batch in enumerate(self.train_dataloader):
self.batch_nb = batch_nb
self.global_step += 1

@@ -984,14 +984,14 @@ class Trainer(TrainerIO):
# stop when the flag is changed or we've gone past the amount
# requested in the batches
self.total_batch_nb += 1
met_batch_limit = batch_nb > self.nb_tng_batches
met_batch_limit = batch_nb > self.nb_training_batches
if met_batch_limit:
break
# ---------------
# RUN TRAIN STEP
# ---------------
batch_result = self.__run_tng_batch(data_batch, batch_nb)
batch_result = self.__run_training_batch(batch, batch_nb)
early_stop_epoch = batch_result == -1
# ---------------

@@ -1009,12 +1009,12 @@ class Trainer(TrainerIO):
self.experiment.save()
# when metrics should be logged
if batch_nb % self.add_log_row_interval == 0 or early_stop_epoch:
if batch_nb % self.row_log_interval == 0 or early_stop_epoch:
# count items in memory
# nb_params, nb_tensors = count_mem_items()
model = self.__get_model()
metrics = self.__tng_tqdm_dic
metrics = self.__training_tqdm_dict
# add gpu memory
if self.on_gpu and self.log_gpu_memory:

@@ -1027,8 +1027,8 @@ class Trainer(TrainerIO):
grad_norm_dic = model.grad_norm(self.track_grad_norm)
metrics.update(grad_norm_dic)
if self.__is_function_implemented('on_tng_metrics'):
model.on_tng_metrics(metrics)
if self.__is_function_implemented('on_training_metrics'):
model.on_training_metrics(metrics)
# log metrics
scalar_metrics = self.__metrics_to_scalars(

@@ -1103,10 +1103,10 @@ class Trainer(TrainerIO):
# nothing matches, return the value as is without transform
return batch
def __tng_forward(self, data_batch, batch_nb, opt_idx):
def __training_forward(self, batch, batch_nb, opt_idx):
"""
Handle forward for each training case (distributed, single gpu, etc...)
:param data_batch:
:param batch:
:param batch_nb:
:return:
"""

@@ -1114,7 +1114,7 @@ class Trainer(TrainerIO):
# FORWARD
# ---------------
# enable not needing to add opt_idx to training_step
args = [data_batch, batch_nb]
args = [batch, batch_nb]
if len(self.optimizers) > 1:
args.append(opt_idx)

@@ -1126,8 +1126,8 @@ class Trainer(TrainerIO):
gpu_id = 0
if type(self.data_parallel_device_ids) is list:
gpu_id = self.data_parallel_device_ids[0]
data_batch = self.transfer_batch_to_gpu(data_batch, gpu_id)
args[0] = data_batch
batch = self.transfer_batch_to_gpu(batch, gpu_id)
args[0] = batch
output = self.model.training_step(*args)
else:

@@ -1137,14 +1137,14 @@ class Trainer(TrainerIO):
# TQDM metrics
# ---------------
try:
prog_output = output['prog']
progress_output = output['progress']
# reduce prog metrics for tqdm when using dp
# reduce progress metrics for tqdm when using dp
if self.use_dp:
nb_gpus = self.num_gpus
prog_output = reduce_distributed_output(prog_output, nb_gpus)
progress_output = reduce_distributed_output(progress_output, nb_gpus)
model_specific_tqdm_metrics_dic = prog_output
model_specific_tqdm_metrics_dic = progress_output
except Exception:
model_specific_tqdm_metrics_dic = {}

@@ -1166,9 +1166,9 @@ class Trainer(TrainerIO):
return loss, model_specific_tqdm_metrics_dic
def __clip_gradients(self):
if self.gradient_clip > 0:
if self.gradient_clip_val > 0:
model = self.__get_model()
torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip)
torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip_val)
def __print_nan_grads(self):
model = self.__get_model()

@@ -1176,14 +1176,14 @@ class Trainer(TrainerIO):
if torch.isnan(param.grad.float()).any():
print(param, param.grad)
def __run_tng_batch(self, data_batch, batch_nb):
if data_batch is None:
def __run_training_batch(self, batch, batch_nb):
if batch is None:
return 0
# hook
if self.__is_function_implemented('on_batch_start'):
model_ref = self.__get_model()
response = model_ref.on_batch_start(data_batch)
response = model_ref.on_batch_start(batch)
if response == -1:
return -1

@@ -1195,7 +1195,7 @@ class Trainer(TrainerIO):
for opt_idx, optimizer in enumerate(self.optimizers):
# forward pass
loss, model_specific_tqdm_metrics = self.__tng_forward(data_batch, batch_nb, opt_idx)
loss, model_specific_tqdm_metrics = self.__training_forward(batch, batch_nb, opt_idx)
# track metrics
self.__add_tqdm_metrics(model_specific_tqdm_metrics)

@@ -1238,10 +1238,10 @@ class Trainer(TrainerIO):
self.batch_loss_value = 0
self.avg_loss = np.mean(self.running_loss[-100:])
# update progbar
# update progressbar
if self.show_progress_bar:
# add model specific metrics
tqdm_metrics = self.__tng_tqdm_dic
tqdm_metrics = self.__training_tqdm_dict
self.progress_bar.set_postfix(**tqdm_metrics)
# activate batch end hook

@@ -1296,11 +1296,11 @@ class Trainer(TrainerIO):
if self.show_progress_bar:
# add model specific metrics
tqdm_metrics = self.__tng_tqdm_dic
tqdm_metrics = self.__training_tqdm_dict
self.progress_bar.set_postfix(**tqdm_metrics)
# model checkpointing
if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
print('save callback...')
self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch,
logs=self.__tng_tqdm_dic)
logs=self.__training_tqdm_dict)
@@ -8,7 +8,7 @@ import os
def add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None):
# tng, test, val check intervals
# training, test, val check intervals
parser.add_argument('--eval_test_set', dest='eval_test_set', action='store_true',
help='true = run test set also')
parser.add_argument('--check_val_every_n_epoch', default=1, type=int,

@@ -19,7 +19,7 @@ def add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None
parser.add_argument('--max_nb_epochs', default=200, type=int, help='cap epochs')
parser.add_argument('--min_nb_epochs', default=2, type=int, help='min epochs')
parser.add_argument('--train_percent_check', default=1.0, type=float,
help='how much of tng set to check')
help='how much of training set to check')
parser.add_argument('--val_percent_check', default=1.0, type=float,
help='how much of val set to check')
parser.add_argument('--test_percent_check', default=1.0, type=float,

@@ -29,7 +29,7 @@ def add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None
help='how much within 1 epoch to check val')
parser.add_argument('--log_save_interval', default=100, type=int,
help='how many batches between log saves')
parser.add_argument('--add_log_row_interval', default=100, type=int,
parser.add_argument('--row_log_interval', default=100, type=int,
help='add log every k batches')
# early stopping

@@ -40,7 +40,7 @@ def add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None
help='number of epochs until stop')
# gradient handling
parser.add_argument('--gradient_clip', default=-1, type=int)
parser.add_argument('--gradient_clip_val', default=-1, type=int)
parser.add_argument('--track_grad_norm', default=-1, type=int,
help='if > 0, will track this grad norm')

@@ -78,9 +78,9 @@ def add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None
# FAST training
# use these settings to make sure network has no bugs without running a full dataset
parser.add_argument('--fast_dev_run', dest='fast_dev_run', default=False, action='store_true',
help='runs validation after 1 tng step')
help='runs validation after 1 training step')
parser.add_argument('--enable_tqdm', dest='enable_tqdm', default=False, action='store_true',
help='false removes the prog bar')
help='false removes the progress bar')
parser.add_argument('--overfit', default=-1, type=float,
help='% of dataset to use with this option. float, or -1 for none')

@@ -93,7 +93,7 @@ def add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None
parser.add_argument('--debug', dest='debug', action='store_true',
help='enables/disables test tube')
parser.add_argument('--local', dest='local', action='store_true',
help='enables local tng')
help='enables local training')
# optimizer
parser.add_argument('--lr_scheduler_milestones', default=None, type=str)
@@ -32,7 +32,7 @@ class CoolModel(pl.LightningModule):
def training_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'tng_loss': self.my_loss(y_hat, y)}
return {'training_loss': self.my_loss(y_hat, y)}
def validation_step(self, batch, batch_nb):
x, y = batch

@@ -47,7 +47,7 @@ class CoolModel(pl.LightningModule):
return [torch.optim.Adam(self.parameters(), lr=0.02)]
@pl.data_loader
def tng_dataloader(self):
def train_dataloader(self):
return DataLoader(MNIST('path/to/save', train=True), batch_size=32)
@pl.data_loader
|
|||
|
||||
def assert_ok_val_acc(trainer):
|
||||
# this model should get 0.80+ acc
|
||||
acc = trainer.tng_tqdm_dic['val_acc']
|
||||
acc = trainer.training_tqdm_dict['val_acc']
|
||||
assert acc > 0.50, f'model failed to get expected 0.50 validation accuracy. Got: {acc}'
|
||||
|
||||
|
||||
def assert_ok_test_acc(trainer):
|
||||
# this model should get 0.80+ acc
|
||||
acc = trainer.tng_tqdm_dic['test_acc']
|
||||
acc = trainer.training_tqdm_dict['test_acc']
|
||||
assert acc > 0.50, f'model failed to get expected 0.50 validation accuracy. Got: {acc}'
|
||||
|
||||
|
||||
|
|
|
@@ -511,7 +511,7 @@ def test_early_stopping_cpu_model():
stopping = EarlyStopping(monitor='val_loss')
trainer_options = dict(
early_stop_callback=stopping,
gradient_clip=1.0,
gradient_clip_val=1.0,
overfit_pct=0.20,
track_grad_norm=2,
print_nan_grads=True,

@@ -1098,7 +1098,7 @@ def test_all_features_cpu_model():
"""
trainer_options = dict(
gradient_clip=1.0,
gradient_clip_val=1.0,
overfit_pct=0.20,
track_grad_norm=2,
print_nan_grads=True,

@@ -1256,7 +1256,7 @@ def test_multiple_val_dataloader():
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
# verify tng completed
# verify training completed
assert result == 1
# verify there are 2 val loaders

@@ -1450,13 +1450,13 @@ def run_prediction(dataloader, trained_model, dp=False):
def assert_ok_val_acc(trainer):
# this model should get 0.80+ acc
acc = trainer.tng_tqdm_dic['val_acc']
acc = trainer.training_tqdm_dict['val_acc']
assert acc > 0.50, f'model failed to get expected 0.50 validation accuracy. Got: {acc}'
def assert_ok_test_acc(trainer):
# this model should get 0.80+ acc
acc = trainer.tng_tqdm_dic['test_acc']
acc = trainer.training_tqdm_dict['test_acc']
assert acc > 0.50, f'model failed to get expected 0.50 validation accuracy. Got: {acc}'