extract training teardown into method, catch KeyboardInterrupt (#856)

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Author: Jeremy Jordan
Date: 2020-02-22 17:06:48 -05:00 (committed by GitHub)
Parent: 446a1e23d7
Commit: e05586c4b2
2 changed files with 88 additions and 79 deletions
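
The pattern this commit applies can be sketched in isolation: run the epoch loop inside a try block, funnel every exit path (normal completion, early stop, Ctrl-C) through a single teardown method, and catch KeyboardInterrupt so an interrupted run still closes its progress bar, finalizes its logger, and reports its profile. The toy below is an invented, self-contained illustration of that shape only; ToyTrainer and its methods are not part of the project.

import logging as log

log.basicConfig(level=log.INFO)


class ToyTrainer:
    """Invented stand-in used only to illustrate the teardown pattern."""

    def __init__(self, max_epochs=3):
        self.max_epochs = max_epochs
        self.interrupted = False

    def run_epoch(self, epoch):
        # real training work happens here; a user pressing Ctrl-C raises
        # KeyboardInterrupt somewhere inside this call
        log.info('finished epoch %d', epoch)

    def should_stop_early(self, epoch):
        # placeholder early-stopping decision
        return False

    def teardown(self):
        # one place for cleanup: close progress bars, finalize loggers,
        # summarize profilers, and so on
        log.info('running teardown (close progress bar, finalize logger, ...)')

    def train(self):
        try:
            for epoch in range(self.max_epochs):
                self.run_epoch(epoch)
                if self.should_stop_early(epoch):
                    self.teardown()
                    return
            self.teardown()
        except KeyboardInterrupt:
            self.interrupted = True
            log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
            self.teardown()


if __name__ == '__main__':
    ToyTrainer().train()

Pressing Ctrl-C while train() is running lands in the except branch, so the same teardown() runs as on a normal exit.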


@@ -1060,9 +1060,6 @@ class Trainer(TrainerIOMixin,
         # CORE TRAINING LOOP
         self.train()
 
-        # summarize profile results
-        self.profiler.describe()
-
     def test(self, model=None):
         r"""


@@ -155,6 +155,7 @@ When this flag is enabled each batch is split into sequences of size truncated_b
 import copy
 import warnings
 from abc import ABC, abstractmethod
+import logging as log
 
 import numpy as np
@@ -307,98 +308,95 @@ class TrainerTrainLoopMixin(ABC):
    def train(self):
        warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
                      ' but will start from "0" in v0.8.0.', DeprecationWarning)

        # get model
        model = self.get_model()

        try:
            # run all epochs
            for epoch in range(self.current_epoch, self.max_epochs):
                # set seed for distributed sampler (enables shuffling for each epoch)
                if (self.use_ddp or self.use_tpu) \
                        and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
                    self.get_train_dataloader().sampler.set_epoch(epoch)

                # get model
                model = self.get_model()

                # update training progress in trainer and model
                model.current_epoch = epoch
                self.current_epoch = epoch

                total_val_batches = 0
                is_val_epoch = False
                if not self.disable_validation:
                    # val can be checked multiple times in epoch
                    is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
                    val_checks_per_epoch = self.num_training_batches // self.val_check_batch
                    val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
                    total_val_batches = self.num_val_batches * val_checks_per_epoch

                # total batches includes multiple val checks
                self.total_batches = self.num_training_batches + total_val_batches
                self.batch_loss_value = 0  # accumulated grads

                if self.fast_dev_run:
                    # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
                    num_iterations = 2
                elif self.is_iterable_train_dataloader:
                    # for iterable train loader, the progress bar never ends
                    num_iterations = None
                else:
                    num_iterations = self.total_batches

                # reset progress bar
                # .reset() doesn't work on disabled progress bar so we should check
                if not self.main_progress_bar.disable:
                    self.main_progress_bar.reset(num_iterations)
                desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
                self.main_progress_bar.set_description(desc)

                # changing gradient according accumulation_scheduler
                self.accumulation_scheduler.on_epoch_begin()

                # -----------------
                # RUN TNG EPOCH
                # -----------------
                self.run_training_epoch()

                # update LR schedulers
                if self.lr_schedulers is not None:
                    for lr_scheduler in self.lr_schedulers:
                        lr_scheduler.step()
                if self.reduce_lr_on_plateau_scheduler is not None:
                    val_loss = self.callback_metrics.get('val_loss')
                    if val_loss is None:
                        avail_metrics = ','.join(list(self.callback_metrics.keys()))
                        m = f'ReduceLROnPlateau conditioned on metric val_loss ' \
                            f'which is not available. Available metrics are: {avail_metrics}'
                        raise MisconfigurationException(m)
                    self.reduce_lr_on_plateau_scheduler.step(val_loss)

                if self.max_steps and self.max_steps == self.global_step:
                    self.main_progress_bar.close()
                    model.on_train_end()
                    return

                # early stopping
                met_min_epochs = epoch >= self.min_epochs - 1
                met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

                if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
                        ((met_min_epochs and met_min_steps) or self.fast_dev_run)):
                    should_stop = self.early_stop_callback.on_epoch_end()
                    # stop training
                    stop = should_stop and met_min_epochs
                    if stop:
                        self.run_training_teardown()
                        return

            self.run_training_teardown()

        except KeyboardInterrupt:
            log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
            self.run_training_teardown()

    def run_training_epoch(self):
        # before epoch hook
@@ -622,6 +620,20 @@ class TrainerTrainLoopMixin(ABC):
        return 0, grad_norm_dic, all_log_metrics

    def run_training_teardown(self):
        model = self.get_model()

        self.main_progress_bar.close()

        with self.profiler.profile('on_train_end'):
            model.on_train_end()

        if self.logger is not None:
            self.logger.finalize("success")

        # summarize profile results
        self.profiler.describe()
    def training_forward(self, batch, batch_idx, opt_idx, hiddens):
        """
        Handle forward for each training case (distributed, single gpu, etc...)
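
The new run_training_teardown() relies on two profiler entry points that the diff uses but does not define here: profile(action_name) as a context manager and describe() for the final summary. A minimal stand-in with that shape is sketched below; ToyProfiler is an invented illustration, not the project's profiler implementation.

import time
from contextlib import contextmanager


class ToyProfiler:
    """Invented stand-in mirroring the two calls used above: profile() and describe()."""

    def __init__(self):
        self.recorded = {}

    @contextmanager
    def profile(self, action_name):
        # time the wrapped block and store the duration under its action name
        start = time.monotonic()
        try:
            yield
        finally:
            self.recorded.setdefault(action_name, []).append(time.monotonic() - start)

    def describe(self):
        # print a small summary of everything that was profiled
        for action_name, durations in self.recorded.items():
            print(f'{action_name}: {sum(durations):.4f}s over {len(durations)} call(s)')


profiler = ToyProfiler()
with profiler.profile('on_train_end'):
    time.sleep(0.01)
profiler.describe()

Running this prints one timing line per profiled action, which is the kind of summary run_training_teardown() requests via self.profiler.describe().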