decouple returns from each step (#307)
* decoupled training metrics from logging metrics
* decoupled validation metrics from log metrics
* updated docs
* Fixed test
* merged master
parent 8f5a06bfb8 · commit 6cc3f1757f
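In short, `training_step` (and the `validation_end`/`test_end` aggregates) now return the loss together with two separate dicts: one for the progress bar and one for the logger. A minimal sketch of the new return contract, assuming a typical LightningModule (`my_loss` and the metric names are placeholders):

```python
def training_step(self, batch, batch_nb):
    x, y = batch
    y_hat = self.forward(x)
    loss = self.my_loss(y_hat, y)  # placeholder loss function

    # progress-bar metrics and logged metrics are now decoupled
    return {
        'loss': loss,                          # required
        'progress_bar': {'train_loss': loss},  # shown only in the progress bar
        'log': {'train_loss': loss}            # sent only to the logger
    }
```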
@@ -129,7 +129,8 @@ Dictionary or OrderedDict

| key | value | is required |
|---|---|---|
| loss | tensor scalar | Y |
| progress | Dict for progress bar display. Must have only tensors | N |
| progress_bar | Dict for progress bar display. Must have only tensors | N |
| log | Dict of metrics to add to logger. Must have only tensors (no images, etc) | N |

**Example**
@@ -144,7 +145,8 @@ def training_step(self, batch, batch_nb):
output = {
'loss': loss, # required
'progress': {'training_loss': loss} # optional (MUST ALL BE TENSORS)
'progress_bar': {'training_loss': loss}, # optional (MUST ALL BE TENSORS)
'log': {'training_loss': loss} # optional (MUST ALL BE TENSORS)
}

# return a dict
@@ -161,6 +163,9 @@ def training_step(self, batch, batch_nb, optimizer_idx):
# do training_step with decoder
```

You can also return a -1 instead of a dict to stop the current loop. This is useful if you want to
break out of the current training epoch early.

---
### train_dataloader
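As the docs above note, a step may also return -1 to end the current training epoch early. A hedged sketch of that escape hatch inside a LightningModule (the divergence check is illustrative only and assumes `torch` is imported):

```python
def training_step(self, batch, batch_nb):
    x, y = batch
    loss = self.my_loss(self.forward(x), y)  # placeholder loss function

    # returning -1 instead of a dict stops the current training epoch early
    if not torch.isfinite(loss):
        return -1

    return {'loss': loss}
```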
@@ -263,7 +268,7 @@ The dict you return here will be available in the `validation_end` method.

| Return | description | optional |
|---|---|---|
| dict | Dict or OrderedDict with metrics to display in progress bar. All keys must be tensors. | Y |
| dict | Dict or OrderedDict - passed to the validation_end step | N |

**Example**
@@ -327,9 +332,12 @@ The outputs here are strictly for the progress bar. If you don't need to display

**Return**

| Return | description | optional |
|---|---|---|
| dict | Dict of OrderedDict with metrics to display in progress bar | Y |

Dictionary or OrderedDict

| key | value | is required |
|---|---|---|
| progress_bar | Dict for progress bar display. Must have only tensors | N |
| log | Dict of metrics to add to logger. Must have only tensors (no images, etc) | N |

**Example**
@@ -351,7 +359,13 @@ def validation_end(self, outputs):
val_loss_mean /= len(outputs)
val_acc_mean /= len(outputs)
tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
return tqdm_dict

# show val_loss and val_acc in progress bar but only log val_loss
results = {
'progress_bar': tqdm_dict,
'log': {'val_loss': val_loss_mean.item()}
}
return results
```

With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
@@ -377,7 +391,13 @@ def validation_end(self, outputs):
val_loss_mean /= i
val_acc_mean /= i
tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
return tqdm_dict

# show val_loss and val_acc in progress bar but only log val_loss
results = {
'progress_bar': tqdm_dict,
'log': {'val_loss': val_loss_mean.item()}
}
return results
```

### test_step
@@ -490,7 +510,13 @@ def test_end(self, outputs):
test_loss_mean /= len(outputs)
test_acc_mean /= len(outputs)
tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
return tqdm_dict

# show test_loss and test_acc in progress bar but only log test_loss
results = {
'progress_bar': tqdm_dict,
'log': {'test_loss': val_loss_mean.item()}
}
return results
```

With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
@@ -516,7 +542,13 @@ def test_end(self, outputs):
test_loss_mean /= i
test_acc_mean /= i
tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
return tqdm_dict

# show test_loss and test_acc in progress bar but only log test_loss
results = {
'progress_bar': tqdm_dict,
'log': {'test_loss': val_loss_mean.item()}
}
return results
```

---
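Putting the validation/test docs above together, a minimal sketch of the new `validation_end` aggregate (metric names illustrative; `test_end` follows the same shape):

```python
def validation_end(self, outputs):
    # average the per-batch metrics returned by validation_step
    val_loss_mean = sum(out['val_loss'] for out in outputs) / len(outputs)

    tqdm_dict = {'val_loss': val_loss_mean}
    return {
        'progress_bar': tqdm_dict,           # shown in the progress bar
        'log': {'val_loss': val_loss_mean}   # written to the logger
    }
```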
@@ -58,16 +58,6 @@ def on_post_performance_check(self):
```

---
#### on_training_metrics
Called in the training loop, right before metrics are logged.
Although you can log at any time by using self.experiment, you can use
this callback to modify what will be logged.
```python
def on_training_metrics(self, metrics):
# do something before validation end
```

---
#### optimizer_step
Calls .step() and .zero_grad for each optimizer.
You can override this method to adjust how you do the optimizer step for each optimizer
@@ -168,7 +168,8 @@ class LightningTemplateModel(LightningModule):
val_loss_mean /= len(outputs)
val_acc_mean /= len(outputs)
tqdm_dict = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean}
return tqdm_dict
result = {'progress_bar': tqdm_dict}
return result

# ---------------------
# TRAINING SETUP
@@ -28,9 +28,6 @@ class ModelHooks(torch.nn.Module):
def on_post_performance_check(self):
pass

def on_training_metrics(self, metrics):
pass

def on_before_zero_grad(self, optimizer):
"""
Called after optimizer.step() and before optimizer.zero_grad()
@@ -104,8 +104,9 @@ class LightningTestModelBase(LightningModule):
if self.trainer.batch_nb % 1 == 0:
output = OrderedDict({
'loss': loss_val,
'progress': {'some_val': loss_val * loss_val}
'progress_bar': {'some_val': loss_val * loss_val}
})

return output
if self.trainer.batch_nb % 2 == 0:
return loss_val
@@ -105,7 +105,8 @@ class LightningValidationMixin(LightningValidationStepMixin):
val_acc_mean /= len(outputs)

tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
return tqdm_dict
results = {'progress_bar': tqdm_dict}
return results


class LightningValidationStepMultipleDataloadersMixin:
@@ -207,7 +208,8 @@ class LightningValidationMultipleDataloadersMixin(LightningValidationStepMultipl
val_acc_mean /= i

tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
return tqdm_dict
result = {'progress_bar': tqdm_dict}
return result


class LightningTestStepMixin:
@@ -291,7 +293,8 @@ class LightningTestMixin(LightningTestStepMixin):
test_acc_mean /= len(outputs)

tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
return tqdm_dict
result = {'progress_bar': tqdm_dict}
return result


class LightningTestStepMultipleDataloadersMixin:
@@ -384,4 +387,5 @@ class LightningTestMultipleDataloadersMixin(LightningTestStepMultipleDataloaders
test_acc_mean /= i

tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
return tqdm_dict
result = {'progress_bar': tqdm_dict}
return result
@@ -553,7 +553,6 @@ class Trainer(TrainerIO):
:param model: PT model
:param dataloaders: list of PT dataloaders
:param max_batches: Scalar
:param dataloader_idx:
:param test: boolean
:return:
"""
@@ -582,7 +581,10 @@ class Trainer(TrainerIO):
# -----------------
# RUN EVALUATION STEP
# -----------------
output = self.__evaluation_forward(model, batch, batch_idx, dataloader_idx,
output = self.__evaluation_forward(model,
batch,
batch_idx,
dataloader_idx,
test)

# track outputs for collation
@@ -704,8 +706,6 @@ class Trainer(TrainerIO):
task = int(os.environ['SLURM_LOCALID'])
self.ddp_train(task, model)
else:
nb_gpus = self.nb_requested_gpus
nb_tasks = self.nb_slurm_tasks
mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model, ))

# 1 gpu or dp option triggers training using DP module
@@ -1054,7 +1054,8 @@ class Trainer(TrainerIO):
# ---------------
# RUN TRAIN STEP
# ---------------
batch_result, grad_norm_dic = self.__run_training_batch(batch, batch_nb)
output = self.__run_training_batch(batch, batch_nb)
batch_result, grad_norm_dic, batch_step_metrics = output
early_stop_epoch = batch_result == -1

# ---------------
@@ -1073,29 +1074,9 @@ class Trainer(TrainerIO):

# when metrics should be logged
if batch_nb % self.row_log_interval == 0 or early_stop_epoch:
# count items in memory
# nb_params, nb_tensors = count_mem_items()

model = self.__get_model()
metrics = self.__training_tqdm_dict

# add gpu memory
if self.on_gpu and self.log_gpu_memory is not None:
mem_map = memory.get_memory_profile(mode=self.log_gpu_memory)
metrics.update(mem_map)

# add norms
metrics.update(grad_norm_dic)

if self.__is_function_implemented('on_training_metrics'):
model.on_training_metrics(metrics)

# log metrics
scalar_metrics = self.__metrics_to_scalars(
metrics, blacklist=self.__log_vals_blacklist())
if self.proc_rank == 0 and self.logger is not None:
self.logger.log_metrics(scalar_metrics, step_num=self.global_step)
self.logger.save()
# logs user requested information to logger
self.__log_metrics(batch_step_metrics, grad_norm_dic)

# end epoch early
if early_stop_epoch:
@@ -1106,6 +1087,32 @@ class Trainer(TrainerIO):
model = self.__get_model()
model.on_epoch_end()

def __log_metrics(self, metrics, grad_norm_dic):
"""
Logs the metric dict passed in
:param metrics:
:param grad_norm_dic:
:return:
"""
# added metrics by Lightning for convenience
metrics['epoch'] = self.current_epoch

# add gpu memory
if self.on_gpu and self.log_gpu_memory:
mem_map = memory.get_memory_profile()
metrics.update(mem_map)

# add norms
metrics.update(grad_norm_dic)

# turn all tensors to scalars
scalar_metrics = self.__metrics_to_scalars(metrics)

# log actual metrics
if self.proc_rank == 0 and self.logger is not None:
self.logger.log_metrics(scalar_metrics, step_num=self.global_step)
self.logger.save()

def test(self, model=None):
if model is not None:
self.testing = True
@@ -1113,7 +1120,7 @@ class Trainer(TrainerIO):
else:
self.__run_evaluation(test=True)

def __metrics_to_scalars(self, metrics, blacklist=set()):
def __metrics_to_scalars(self, metrics):
new_metrics = {}
for k, v in metrics.items():
if type(v) is torch.Tensor:
@@ -1122,9 +1129,6 @@ class Trainer(TrainerIO):
if type(v) is dict:
v = self.__metrics_to_scalars(v)

if k not in blacklist:
new_metrics[k] = float(v)

return new_metrics

def __log_vals_blacklist(self):
@@ -1193,41 +1197,64 @@ class Trainer(TrainerIO):
else:
output = self.model.training_step(*args)

# ---------------
# TQDM metrics
# ---------------
# format and reduce outputs accordingly
loss, progress_bar_metrics, log_metrics = self.__process_output(output, train=True)
return loss, progress_bar_metrics, log_metrics

def __process_output(self, output, train=False):
"""
Reduces output according to the training mode.
Separates loss from logging and tqdm metrics
:param output:
:return:
"""
try:
progress_output = output['progress']
progress_output = output['progress_bar']

# reduce progress metrics for tqdm when using dp
if self.use_dp or self.use_ddp2:
if train and self.use_dp or self.use_ddp2:
nb_gpus = self.num_gpus
progress_output = reduce_distributed_output(progress_output, nb_gpus)

model_specific_tqdm_metrics_dic = progress_output
progress_bar_metrics = progress_output
except Exception:
model_specific_tqdm_metrics_dic = {}
progress_bar_metrics = {}

# extract metrics to log to experiment
try:
log_output = output['log']

# reduce progress metrics for tqdm when using dp
if train and self.use_dp or self.use_ddp2:
nb_gpus = self.num_gpus
log_output = reduce_distributed_output(log_output, nb_gpus)

log_metrics = log_output
except Exception:
log_metrics = {}

# ---------------
# EXTRACT LOSS
# ---------------
# if output dict doesn't have the keyword loss
# then assume the output=loss if scalar
try:
loss = output['loss']
except Exception:
if type(output) is torch.Tensor:
loss = output
else:
raise RuntimeError(
'No `loss` value in the dictionary returned from `model.training_step()`.'
)
loss = None
if train:
try:
loss = output['loss']
except Exception:
if type(output) is torch.Tensor:
loss = output
else:
raise RuntimeError(
'No `loss` value in the dictionary returned from `model.training_step()`.'
)

# when using dp need to reduce the loss
if self.use_dp or self.use_ddp2:
loss = reduce_distributed_output(loss, self.num_gpus)
# when using dp need to reduce the loss
if self.use_dp or self.use_ddp2:
loss = reduce_distributed_output(loss, self.num_gpus)

return loss, model_specific_tqdm_metrics_dic
return loss, progress_bar_metrics, log_metrics

def __clip_gradients(self):
if self.gradient_clip_val > 0:
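To make the reworked `__process_output` contract concrete: for a dict returned by `training_step` it yields a `(loss, progress_bar_metrics, log_metrics)` tuple. A standalone sketch of that behavior under the assumptions in the diff above (an illustration only, not the Trainer's actual code path, and it only covers dict outputs):

```python
import torch

def process_output_sketch(output: dict, train: bool = True):
    """Illustrative only: split a step's output into (loss, progress_bar, log)."""
    progress_bar_metrics = output.get('progress_bar', {})
    log_metrics = output.get('log', {})
    loss = output['loss'] if train else None  # 'loss' is only required while training
    return loss, progress_bar_metrics, log_metrics

out = {'loss': torch.tensor(0.5),
       'progress_bar': {'train_loss': 0.5},
       'log': {'train_loss': 0.5}}
print(process_output_sketch(out))
# (tensor(0.5000), {'train_loss': 0.5}, {'train_loss': 0.5})
```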
@@ -1244,6 +1271,9 @@ class Trainer(TrainerIO):
# track grad norms
grad_norm_dic = {}

# track metrics to log
all_log_metrics = []

if batch is None:
return 0, grad_norm_dic
@@ -1265,10 +1295,12 @@ class Trainer(TrainerIO):
def optimizer_closure():
# forward pass
output = self.__training_forward(batch, batch_nb, opt_idx)
closure_loss, model_specific_tqdm_metrics = output
closure_loss, progress_bar_metrics, log_metrics = output

# track metrics
self.__add_tqdm_metrics(model_specific_tqdm_metrics)
# track progress bar metrics
self.__add_tqdm_metrics(progress_bar_metrics)

all_log_metrics.append(log_metrics)

# accumulate loss
# (if accumulate_grad_batches = 1 no effect)
@@ -1321,7 +1353,7 @@ class Trainer(TrainerIO):
self.batch_loss_value = 0
self.avg_loss = np.mean(self.running_loss[-100:])

# update progressbar
# update progress bar
if self.show_progress_bar:
# add model specific metrics
tqdm_metrics = self.__training_tqdm_dict
@@ -1332,7 +1364,10 @@ class Trainer(TrainerIO):
model = self.__get_model()
model.on_batch_end()

return 0, grad_norm_dic
# collapse all metrics into one dict
all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}

return 0, grad_norm_dic, all_log_metrics

def __run_evaluation(self, test=False):
# when testing make sure user defined a test step
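The `all_log_metrics` comprehension above collapses the per-closure log dicts into one dict before the batch result is returned; later keys overwrite earlier ones. A tiny illustration with made-up values:

```python
all_log_metrics = [{'train_loss': 0.9}, {'train_acc': 0.7}]
merged = {k: v for d in all_log_metrics for k, v in d.items()}
print(merged)  # {'train_loss': 0.9, 'train_acc': 0.7}
```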
@@ -1367,11 +1402,19 @@ class Trainer(TrainerIO):
if self.fast_dev_run:
max_batches = 1

eval_out_metrics = self.evaluate(self.model,
dataloaders,
max_batches,
test)
self.__add_tqdm_metrics(eval_out_metrics)
# run evaluation
eval_results = self.evaluate(self.model,
dataloaders,
max_batches,
test)

_, progress_bar_metrics, log_metrics = self.__process_output(eval_results)

# add metrics to prog bar
self.__add_tqdm_metrics(progress_bar_metrics)

# log metrics
self.__log_metrics(log_metrics, {})

# hook
model.on_post_performance_check()
189 tests/debug.py
@@ -14,6 +14,7 @@ from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import numpy as np
import pdb
from . import test_models


class CoolModel(pl.LightningModule):
@@ -59,156 +60,6 @@ class CoolModel(pl.LightningModule):
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)


def get_model(use_test_model=False):
# set up model with these hyperparams
hparams = get_hparams()

if use_test_model:
model = LightningTestModel(hparams)
else:
model = LightningTemplateModel(hparams)

return model, hparams


def get_exp(debug=True, version=None):
# set up exp object without actually saving logs
root_dir = os.path.dirname(os.path.realpath(__file__))
save_dir = os.path.join(root_dir, 'save_dir')
exp = Experiment(debug=debug, save_dir=save_dir, name='tests_tt_dir', version=version)
return exp


def init_save_dir():
root_dir = os.path.dirname(os.path.realpath(__file__))
save_dir = os.path.join(root_dir, 'save_dir')

if os.path.exists(save_dir):
shutil.rmtree(save_dir)

os.makedirs(save_dir, exist_ok=True)

return save_dir


def clear_save_dir():
root_dir = os.path.dirname(os.path.realpath(__file__))
save_dir = os.path.join(root_dir, 'save_dir')
if os.path.exists(save_dir):
shutil.rmtree(save_dir)


def load_model(exp, save_dir, on_gpu, map_location=None, module_class=LightningTemplateModel):

# load trained model
tags_path = exp.get_data_path(exp.name, exp.version)
tags_path = os.path.join(tags_path, 'meta_tags.csv')

checkpoints = [x for x in os.listdir(save_dir) if '.ckpt' in x]
weights_dir = os.path.join(save_dir, checkpoints[0])

trained_model = module_class.load_from_metrics(weights_path=weights_dir,
tags_csv=tags_path,
on_gpu=on_gpu,
)

assert trained_model is not None, 'loading model failed'

return trained_model


def run_prediction(dataloader, trained_model):
# run prediction on 1 batch
for batch in dataloader:
break

x, y = batch
x = x.view(x.size(0), -1)

y_hat = trained_model(x)

# acc
labels_hat = torch.argmax(y_hat, dim=1)
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
val_acc = torch.tensor(val_acc)
val_acc = val_acc.item()
assert val_acc > 0.70, 'this model is expected to get > 0.7 in test set (it got %f)' % val_acc


# ------------------------------------------------------------------------
def run_gpu_model_test(trainer_options, model, hparams, on_gpu=True):
save_dir = init_save_dir()

# exp file to get meta
exp = get_exp(False)
exp.argparse(hparams)
exp.save()

# exp file to get weights
checkpoint = ModelCheckpoint(save_dir)

# add these to the trainer options
trainer_options['checkpoint_callback'] = checkpoint
trainer_options['experiment'] = exp

# fit model
trainer = Trainer(**trainer_options)
result = trainer.fit(model)

# correct result and ok accuracy
assert result == 1, 'amp + ddp model failed sto complete'

# test model loading
pretrained_model = load_model(exp, save_dir, on_gpu)

# test new model accuracy
run_prediction(model.test_dataloader, pretrained_model)

if trainer.use_ddp:
# on hpc this would work fine... but need to hack it for the purpose of the test
trainer.model = pretrained_model
trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()

# test HPC loading / saving
trainer.hpc_save(save_dir, exp)
trainer.hpc_load(save_dir, on_gpu=on_gpu)

clear_save_dir()


def assert_ok_val_acc(trainer):
# this model should get 0.80+ acc
acc = trainer.training_tqdm_dict['val_acc']
assert acc > 0.50, f'model failed to get expected 0.50 validation accuracy. Got: {acc}'


def assert_ok_test_acc(trainer):
# this model should get 0.80+ acc
acc = trainer.training_tqdm_dict['test_acc']
assert acc > 0.50, f'model failed to get expected 0.50 validation accuracy. Got: {acc}'


def get_hparams(continue_training=False, hpc_exp_number=0):
root_dir = os.path.dirname(os.path.realpath(__file__))

args = {
'drop_prob': 0.2,
'batch_size': 32,
'in_features': 28 * 28,
'learning_rate': 0.001 * 8,
'optimizer_name': 'adam',
'data_root': os.path.join(root_dir, 'mnist'),
'out_features': 10,
'hidden_dim': 1000}

if continue_training:
args['test_tube_do_checkpoint_load'] = True
args['hpc_exp_number'] = hpc_exp_number

hparams = Namespace(**args)
return hparams


def main():
"""
Make sure DDP + AMP continue training correctly
@@ -218,19 +69,45 @@ def main():
Make sure DDP2 works
:return:
"""
os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
model, hparams = get_model()
hparams = test_models.get_hparams()
model = LightningTestModel(hparams)

save_dir = test_models.init_save_dir()

# logger file to get meta
logger = test_models.get_test_tube_logger(False)
logger.log_hyperparams(hparams)
logger.save()

# logger file to get weights
checkpoint = ModelCheckpoint(save_dir)

trainer_options = dict(
show_progress_bar=True,
max_nb_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
gpus=2,
print_weights_summary=True,
distributed_backend='ddp2'
checkpoint_callback=checkpoint,
logger=logger,
gpus=[0, 1],
distributed_backend='dp'
)

run_gpu_model_test(trainer_options, model, hparams)
# fit model
trainer = Trainer(**trainer_options)
result = trainer.fit(model)

# correct result and ok accuracy
assert result == 1, 'training failed to complete'
pretrained_model = test_models.load_model(logger.experiment, save_dir,
module_class=LightningTestModel)

new_trainer = Trainer(**trainer_options)
new_trainer.test(pretrained_model)

# test we have good test accuracy
test_models.assert_ok_test_acc(new_trainer)
test_models.clear_save_dir()


if __name__ == '__main__':
@@ -401,7 +401,7 @@ def test_running_test_pretrained_model_dp():
checkpoint = ModelCheckpoint(save_dir)

trainer_options = dict(
show_progress_bar=False,
show_progress_bar=True,
max_nb_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
@@ -615,7 +615,7 @@ def test_early_stopping_cpu_model():
overfit_pct=0.20,
track_grad_norm=2,
print_nan_grads=True,
show_progress_bar=False,
show_progress_bar=True,
logger=get_test_tube_logger(),
train_percent_check=0.1,
val_percent_check=0.1