decouple returns from each step (#307)

* decoupled training metrics from logging metrics

* decoupled validation metrics from log metrics

* updated docs

* updated docs

* updated docs

* Fixed test

* merged master
William Falcon 2019-10-05 13:35:20 -04:00 committed by GitHub
parent 8f5a06bfb8
commit 6cc3f1757f
9 changed files with 193 additions and 248 deletions
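
In short, `training_step`, `validation_end`, and `test_end` now return a `progress_bar` dict and a `log` dict instead of a single `progress` dict, so what is shown in the progress bar is decoupled from what is sent to the logger. A minimal sketch of the new contract, assuming a simple classifier (`F` is `torch.nn.functional`; the forward pass and loss are illustrative only):

```python
def training_step(self, batch, batch_nb):
    x, y = batch
    y_hat = self.forward(x)           # illustrative forward pass
    loss = F.cross_entropy(y_hat, y)  # illustrative loss

    return {
        'loss': loss,                          # required
        'progress_bar': {'train_loss': loss},  # shown in the progress bar only
        'log': {'train_loss': loss},           # sent to the logger only
    }
```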


@ -129,7 +129,8 @@ Dictionary or OrderedDict
| key | value | is required |
|---|---|---|
| loss | tensor scalar | Y |
| progress | Dict for progress bar display. Must have only tensors | N |
| progress_bar | Dict for progress bar display. Must have only tensors | N |
| log | Dict of metrics to add to logger. Must have only tensors (no images, etc) | N |
**Example**
@ -144,7 +145,8 @@ def training_step(self, batch, batch_nb):
output = {
'loss': loss, # required
'progress': {'training_loss': loss} # optional (MUST ALL BE TENSORS)
'progress_bar': {'training_loss': loss}, # optional (MUST ALL BE TENSORS)
'log': {'training_loss': loss} # optional (MUST ALL BE TENSORS)
}
# return a dict
@ -161,6 +163,9 @@ def training_step(self, batch, batch_nb, optimizer_idx):
# do training_step with decoder
```
You can also return -1 instead of a dict to stop the current loop. This is useful if you want to
break out of the current training epoch early.
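For example (a minimal sketch; the divergence check is purely illustrative):
```python
def training_step(self, batch, batch_nb):
    x, y = batch
    loss = F.cross_entropy(self.forward(x), y)

    # illustrative: abandon the rest of the epoch if the loss diverges
    if torch.isnan(loss) or torch.isinf(loss):
        return -1

    return {'loss': loss}
```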
---
### train_dataloader
@ -263,7 +268,7 @@ The dict you return here will be available in the `validation_end` method.
| Return | description | optional |
|---|---|---|
| dict | Dict or OrderedDict with metrics to display in progress bar. All keys must be tensors. | Y |
| dict | Dict or OrderedDict - passed to the validation_end step | N |
**Example**
@ -327,9 +332,12 @@ The outputs here are strictly for the progress bar. If you don't need to display
**Return**
| Return | description | optional |
|---|---|---|
| dict | Dict of OrderedDict with metrics to display in progress bar | Y |
Dictionary or OrderedDict
| key | value | is required |
|---|---|---|
| progress_bar | Dict for progress bar display. Must have only tensors | N |
| log | Dict of metrics to add to logger. Must have only tensors (no images, etc) | N |
**Example**
@ -351,7 +359,13 @@ def validation_end(self, outputs):
val_loss_mean /= len(outputs)
val_acc_mean /= len(outputs)
tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
return tqdm_dict
# show val_loss and val_acc in progress bar but only log val_loss
results = {
'progress_bar': tqdm_dict,
'log': {'val_loss': val_loss_mean.item()}
}
return results
```
With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
@ -377,7 +391,13 @@ def validation_end(self, outputs):
val_loss_mean /= i
val_acc_mean /= i
tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
return tqdm_dict
# show val_loss and val_acc in progress bar but only log val_loss
results = {
'progress_bar': tqdm_dict,
'log': {'val_loss': val_loss_mean.item()}
}
return results
```
### test_step
@ -490,7 +510,13 @@ def test_end(self, outputs):
test_loss_mean /= len(outputs)
test_acc_mean /= len(outputs)
tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
return tqdm_dict
# show test_loss and test_acc in progress bar but only log test_loss
results = {
'progress_bar': tqdm_dict,
'log': {'test_loss': test_loss_mean.item()}
}
return results
```
With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
@ -516,7 +542,13 @@ def test_end(self, outputs):
test_loss_mean /= i
test_acc_mean /= i
tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
return tqdm_dict
# show test_loss and test_acc in progress bar but only log test_loss
results = {
'progress_bar': tqdm_dict,
'log': {'test_loss': test_loss_mean.item()}
}
return results
```
---


@ -58,16 +58,6 @@ def on_post_performance_check(self):
```
---
#### on_training_metrics
Called in the training loop, right before metrics are logged.
Although you can log at any time by using self.experiment, you can use
this callback to modify what will be logged.
```python
def on_training_metrics(self, metrics):
# do something before validation end
```
---
#### optimizer_step
Calls .step() and .zero_grad() for each optimizer.
You can override this method to adjust how the optimizer step is performed for each optimizer.
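A rough sketch of such an override (the argument names assume the hook signature of this version, and the learning-rate warm-up is illustrative only, not part of this commit):
```python
def optimizer_step(self, epoch_nb, batch_nb, optimizer, optimizer_i):
    # illustrative: linearly warm up the learning rate over the first 500 steps
    if self.trainer.global_step < 500:
        lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
        for pg in optimizer.param_groups:
            pg['lr'] = lr_scale * self.hparams.learning_rate

    # default behaviour: step the optimizer, then clear the gradients
    optimizer.step()
    optimizer.zero_grad()
```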


@ -168,7 +168,8 @@ class LightningTemplateModel(LightningModule):
val_loss_mean /= len(outputs)
val_acc_mean /= len(outputs)
tqdm_dict = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean}
return tqdm_dict
result = {'progress_bar': tqdm_dict}
return result
# ---------------------
# TRAINING SETUP


@ -28,9 +28,6 @@ class ModelHooks(torch.nn.Module):
def on_post_performance_check(self):
pass
def on_training_metrics(self, metrics):
pass
def on_before_zero_grad(self, optimizer):
"""
Called after optimizer.step() and before optimizer.zero_grad()


@ -104,8 +104,9 @@ class LightningTestModelBase(LightningModule):
if self.trainer.batch_nb % 1 == 0:
output = OrderedDict({
'loss': loss_val,
'progress': {'some_val': loss_val * loss_val}
'progress_bar': {'some_val': loss_val * loss_val}
})
return output
if self.trainer.batch_nb % 2 == 0:
return loss_val


@ -105,7 +105,8 @@ class LightningValidationMixin(LightningValidationStepMixin):
val_acc_mean /= len(outputs)
tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
return tqdm_dict
results = {'progress_bar': tqdm_dict}
return results
class LightningValidationStepMultipleDataloadersMixin:
@ -207,7 +208,8 @@ class LightningValidationMultipleDataloadersMixin(LightningValidationStepMultipl
val_acc_mean /= i
tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
return tqdm_dict
result = {'progress_bar': tqdm_dict}
return result
class LightningTestStepMixin:
@ -291,7 +293,8 @@ class LightningTestMixin(LightningTestStepMixin):
test_acc_mean /= len(outputs)
tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
return tqdm_dict
result = {'progress_bar': tqdm_dict}
return result
class LightningTestStepMultipleDataloadersMixin:
@ -384,4 +387,5 @@ class LightningTestMultipleDataloadersMixin(LightningTestStepMultipleDataloaders
test_acc_mean /= i
tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
return tqdm_dict
result = {'progress_bar': tqdm_dict}
return result


@ -553,7 +553,6 @@ class Trainer(TrainerIO):
:param model: PT model
:param dataloaders: list of PT dataloaders
:param max_batches: Scalar
:param dataloader_idx:
:param test: boolean
:return:
"""
@ -582,7 +581,10 @@ class Trainer(TrainerIO):
# -----------------
# RUN EVALUATION STEP
# -----------------
output = self.__evaluation_forward(model, batch, batch_idx, dataloader_idx,
output = self.__evaluation_forward(model,
batch,
batch_idx,
dataloader_idx,
test)
# track outputs for collation
@ -704,8 +706,6 @@ class Trainer(TrainerIO):
task = int(os.environ['SLURM_LOCALID'])
self.ddp_train(task, model)
else:
nb_gpus = self.nb_requested_gpus
nb_tasks = self.nb_slurm_tasks
mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model, ))
# 1 gpu or dp option triggers training using DP module
@ -1054,7 +1054,8 @@ class Trainer(TrainerIO):
# ---------------
# RUN TRAIN STEP
# ---------------
batch_result, grad_norm_dic = self.__run_training_batch(batch, batch_nb)
output = self.__run_training_batch(batch, batch_nb)
batch_result, grad_norm_dic, batch_step_metrics = output
early_stop_epoch = batch_result == -1
# ---------------
@ -1073,29 +1074,9 @@ class Trainer(TrainerIO):
# when metrics should be logged
if batch_nb % self.row_log_interval == 0 or early_stop_epoch:
# count items in memory
# nb_params, nb_tensors = count_mem_items()
model = self.__get_model()
metrics = self.__training_tqdm_dict
# add gpu memory
if self.on_gpu and self.log_gpu_memory is not None:
mem_map = memory.get_memory_profile(mode=self.log_gpu_memory)
metrics.update(mem_map)
# add norms
metrics.update(grad_norm_dic)
if self.__is_function_implemented('on_training_metrics'):
model.on_training_metrics(metrics)
# log metrics
scalar_metrics = self.__metrics_to_scalars(
metrics, blacklist=self.__log_vals_blacklist())
if self.proc_rank == 0 and self.logger is not None:
self.logger.log_metrics(scalar_metrics, step_num=self.global_step)
self.logger.save()
# logs user requested information to logger
self.__log_metrics(batch_step_metrics, grad_norm_dic)
# end epoch early
if early_stop_epoch:
@ -1106,6 +1087,32 @@ class Trainer(TrainerIO):
model = self.__get_model()
model.on_epoch_end()
def __log_metrics(self, metrics, grad_norm_dic):
"""
Logs the metric dict passed in
:param metrics:
:param grad_norm_dic:
:return:
"""
# added metrics by Lightning for convenience
metrics['epoch'] = self.current_epoch
# add gpu memory
if self.on_gpu and self.log_gpu_memory:
mem_map = memory.get_memory_profile()
metrics.update(mem_map)
# add norms
metrics.update(grad_norm_dic)
# turn all tensors to scalars
scalar_metrics = self.__metrics_to_scalars(metrics)
# log actual metrics
if self.proc_rank == 0 and self.logger is not None:
self.logger.log_metrics(scalar_metrics, step_num=self.global_step)
self.logger.save()
def test(self, model=None):
if model is not None:
self.testing = True
@ -1113,7 +1120,7 @@ class Trainer(TrainerIO):
else:
self.__run_evaluation(test=True)
def __metrics_to_scalars(self, metrics, blacklist=set()):
def __metrics_to_scalars(self, metrics):
new_metrics = {}
for k, v in metrics.items():
if type(v) is torch.Tensor:
@ -1122,9 +1129,6 @@ class Trainer(TrainerIO):
if type(v) is dict:
v = self.__metrics_to_scalars(v)
if k not in blacklist:
new_metrics[k] = float(v)
return new_metrics
def __log_vals_blacklist(self):
@ -1193,41 +1197,64 @@ class Trainer(TrainerIO):
else:
output = self.model.training_step(*args)
# ---------------
# TQDM metrics
# ---------------
# format and reduce outputs accordingly
loss, progress_bar_metrics, log_metrics = self.__process_output(output, train=True)
return loss, progress_bar_metrics, log_metrics
def __process_output(self, output, train=False):
"""
Reduces output according to the training mode.
Separates loss from logging and tqdm metrics
:param output:
:return:
"""
try:
progress_output = output['progress']
progress_output = output['progress_bar']
# reduce progress metrics for tqdm when using dp
if self.use_dp or self.use_ddp2:
if train and self.use_dp or self.use_ddp2:
nb_gpus = self.num_gpus
progress_output = reduce_distributed_output(progress_output, nb_gpus)
model_specific_tqdm_metrics_dic = progress_output
progress_bar_metrics = progress_output
except Exception:
model_specific_tqdm_metrics_dic = {}
progress_bar_metrics = {}
# extract metrics to log to experiment
try:
log_output = output['log']
# reduce progress metrics for tqdm when using dp
if train and self.use_dp or self.use_ddp2:
nb_gpus = self.num_gpus
log_output = reduce_distributed_output(log_output, nb_gpus)
log_metrics = log_output
except Exception:
log_metrics = {}
# ---------------
# EXTRACT LOSS
# ---------------
# if output dict doesn't have the keyword loss
# then assume the output=loss if scalar
try:
loss = output['loss']
except Exception:
if type(output) is torch.Tensor:
loss = output
else:
raise RuntimeError(
'No `loss` value in the dictionary returned from `model.training_step()`.'
)
loss = None
if train:
try:
loss = output['loss']
except Exception:
if type(output) is torch.Tensor:
loss = output
else:
raise RuntimeError(
'No `loss` value in the dictionary returned from `model.training_step()`.'
)
# when using dp need to reduce the loss
if self.use_dp or self.use_ddp2:
loss = reduce_distributed_output(loss, self.num_gpus)
# when using dp need to reduce the loss
if self.use_dp or self.use_ddp2:
loss = reduce_distributed_output(loss, self.num_gpus)
return loss, model_specific_tqdm_metrics_dic
return loss, progress_bar_metrics, log_metrics
def __clip_gradients(self):
if self.gradient_clip_val > 0:
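
To illustrate the split introduced above (a standalone sketch, not the Trainer code itself, assuming a single-process run with no dp/ddp reduction): `__process_output` separates whatever the step returned into `(loss, progress_bar_metrics, log_metrics)`, and only the `log` part is forwarded to the logger by `__log_metrics`.

```python
# standalone sketch of the new output split (no distributed reduction)
def process_output(output, train=True):
    progress_bar = output.get('progress_bar', {}) if isinstance(output, dict) else {}
    log = output.get('log', {}) if isinstance(output, dict) else {}

    loss = None
    if train:
        if isinstance(output, dict):
            loss = output['loss']   # the real code raises a RuntimeError if 'loss' is missing
        else:
            loss = output           # returning a bare loss tensor is still supported

    return loss, progress_bar, log

# usage: only the 'log' dict would reach the logger
loss, pbar, log = process_output({'loss': 0.3,
                                  'progress_bar': {'train_loss': 0.3},
                                  'log': {'train_loss': 0.3}})
```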
@ -1244,6 +1271,9 @@ class Trainer(TrainerIO):
# track grad norms
grad_norm_dic = {}
# track metrics to log
all_log_metrics = []
if batch is None:
return 0, grad_norm_dic
@ -1265,10 +1295,12 @@ class Trainer(TrainerIO):
def optimizer_closure():
# forward pass
output = self.__training_forward(batch, batch_nb, opt_idx)
closure_loss, model_specific_tqdm_metrics = output
closure_loss, progress_bar_metrics, log_metrics = output
# track metrics
self.__add_tqdm_metrics(model_specific_tqdm_metrics)
# track progress bar metrics
self.__add_tqdm_metrics(progress_bar_metrics)
all_log_metrics.append(log_metrics)
# accumulate loss
# (if accumulate_grad_batches = 1 no effect)
@ -1321,7 +1353,7 @@ class Trainer(TrainerIO):
self.batch_loss_value = 0
self.avg_loss = np.mean(self.running_loss[-100:])
# update progressbar
# update progress bar
if self.show_progress_bar:
# add model specific metrics
tqdm_metrics = self.__training_tqdm_dict
@ -1332,7 +1364,10 @@ class Trainer(TrainerIO):
model = self.__get_model()
model.on_batch_end()
return 0, grad_norm_dic
# collapse all metrics into one dict
all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}
return 0, grad_norm_dic, all_log_metrics
def __run_evaluation(self, test=False):
# when testing make sure user defined a test step
@ -1367,11 +1402,19 @@ class Trainer(TrainerIO):
if self.fast_dev_run:
max_batches = 1
eval_out_metrics = self.evaluate(self.model,
dataloaders,
max_batches,
test)
self.__add_tqdm_metrics(eval_out_metrics)
# run evaluation
eval_results = self.evaluate(self.model,
dataloaders,
max_batches,
test)
_, progress_bar_metrics, log_metrics = self.__process_output(eval_results)
# add metrics to prog bar
self.__add_tqdm_metrics(progress_bar_metrics)
# log metrics
self.__log_metrics(log_metrics, {})
# hook
model.on_post_performance_check()


@ -14,6 +14,7 @@ from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import numpy as np
import pdb
from . import test_models
class CoolModel(pl.LightningModule):
@ -59,156 +60,6 @@ class CoolModel(pl.LightningModule):
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
def get_model(use_test_model=False):
# set up model with these hyperparams
hparams = get_hparams()
if use_test_model:
model = LightningTestModel(hparams)
else:
model = LightningTemplateModel(hparams)
return model, hparams
def get_exp(debug=True, version=None):
# set up exp object without actually saving logs
root_dir = os.path.dirname(os.path.realpath(__file__))
save_dir = os.path.join(root_dir, 'save_dir')
exp = Experiment(debug=debug, save_dir=save_dir, name='tests_tt_dir', version=version)
return exp
def init_save_dir():
root_dir = os.path.dirname(os.path.realpath(__file__))
save_dir = os.path.join(root_dir, 'save_dir')
if os.path.exists(save_dir):
shutil.rmtree(save_dir)
os.makedirs(save_dir, exist_ok=True)
return save_dir
def clear_save_dir():
root_dir = os.path.dirname(os.path.realpath(__file__))
save_dir = os.path.join(root_dir, 'save_dir')
if os.path.exists(save_dir):
shutil.rmtree(save_dir)
def load_model(exp, save_dir, on_gpu, map_location=None, module_class=LightningTemplateModel):
# load trained model
tags_path = exp.get_data_path(exp.name, exp.version)
tags_path = os.path.join(tags_path, 'meta_tags.csv')
checkpoints = [x for x in os.listdir(save_dir) if '.ckpt' in x]
weights_dir = os.path.join(save_dir, checkpoints[0])
trained_model = module_class.load_from_metrics(weights_path=weights_dir,
tags_csv=tags_path,
on_gpu=on_gpu,
)
assert trained_model is not None, 'loading model failed'
return trained_model
def run_prediction(dataloader, trained_model):
# run prediction on 1 batch
for batch in dataloader:
break
x, y = batch
x = x.view(x.size(0), -1)
y_hat = trained_model(x)
# acc
labels_hat = torch.argmax(y_hat, dim=1)
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
val_acc = torch.tensor(val_acc)
val_acc = val_acc.item()
assert val_acc > 0.70, 'this model is expected to get > 0.7 in test set (it got %f)' % val_acc
# ------------------------------------------------------------------------
def run_gpu_model_test(trainer_options, model, hparams, on_gpu=True):
save_dir = init_save_dir()
# exp file to get meta
exp = get_exp(False)
exp.argparse(hparams)
exp.save()
# exp file to get weights
checkpoint = ModelCheckpoint(save_dir)
# add these to the trainer options
trainer_options['checkpoint_callback'] = checkpoint
trainer_options['experiment'] = exp
# fit model
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
# correct result and ok accuracy
assert result == 1, 'amp + ddp model failed to complete'
# test model loading
pretrained_model = load_model(exp, save_dir, on_gpu)
# test new model accuracy
run_prediction(model.test_dataloader, pretrained_model)
if trainer.use_ddp:
# on hpc this would work fine... but need to hack it for the purpose of the test
trainer.model = pretrained_model
trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()
# test HPC loading / saving
trainer.hpc_save(save_dir, exp)
trainer.hpc_load(save_dir, on_gpu=on_gpu)
clear_save_dir()
def assert_ok_val_acc(trainer):
# this model should get 0.80+ acc
acc = trainer.training_tqdm_dict['val_acc']
assert acc > 0.50, f'model failed to get expected 0.50 validation accuracy. Got: {acc}'
def assert_ok_test_acc(trainer):
# this model should get 0.80+ acc
acc = trainer.training_tqdm_dict['test_acc']
assert acc > 0.50, f'model failed to get expected 0.50 validation accuracy. Got: {acc}'
def get_hparams(continue_training=False, hpc_exp_number=0):
root_dir = os.path.dirname(os.path.realpath(__file__))
args = {
'drop_prob': 0.2,
'batch_size': 32,
'in_features': 28 * 28,
'learning_rate': 0.001 * 8,
'optimizer_name': 'adam',
'data_root': os.path.join(root_dir, 'mnist'),
'out_features': 10,
'hidden_dim': 1000}
if continue_training:
args['test_tube_do_checkpoint_load'] = True
args['hpc_exp_number'] = hpc_exp_number
hparams = Namespace(**args)
return hparams
def main():
"""
Make sure DDP + AMP continue training correctly
@ -218,19 +69,45 @@ def main():
Make sure DDP2 works
:return:
"""
os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
model, hparams = get_model()
hparams = test_models.get_hparams()
model = LightningTestModel(hparams)
save_dir = test_models.init_save_dir()
# logger file to get meta
logger = test_models.get_test_tube_logger(False)
logger.log_hyperparams(hparams)
logger.save()
# logger file to get weights
checkpoint = ModelCheckpoint(save_dir)
trainer_options = dict(
show_progress_bar=True,
max_nb_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
gpus=2,
print_weights_summary=True,
distributed_backend='ddp2'
checkpoint_callback=checkpoint,
logger=logger,
gpus=[0, 1],
distributed_backend='dp'
)
run_gpu_model_test(trainer_options, model, hparams)
# fit model
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
# correct result and ok accuracy
assert result == 1, 'training failed to complete'
pretrained_model = test_models.load_model(logger.experiment, save_dir,
module_class=LightningTestModel)
new_trainer = Trainer(**trainer_options)
new_trainer.test(pretrained_model)
# test we have good test accuracy
test_models.assert_ok_test_acc(new_trainer)
test_models.clear_save_dir()
if __name__ == '__main__':


@ -401,7 +401,7 @@ def test_running_test_pretrained_model_dp():
checkpoint = ModelCheckpoint(save_dir)
trainer_options = dict(
show_progress_bar=False,
show_progress_bar=True,
max_nb_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
@ -615,7 +615,7 @@ def test_early_stopping_cpu_model():
overfit_pct=0.20,
track_grad_norm=2,
print_nan_grads=True,
show_progress_bar=False,
show_progress_bar=True,
logger=get_test_tube_logger(),
train_percent_check=0.1,
val_percent_check=0.1