extract training teardown into method, catch KeyboardInterrupt (#856)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent
446a1e23d7
commit
e05586c4b2
|
@ -1060,9 +1060,6 @@ class Trainer(TrainerIOMixin,
|
|||
# CORE TRAINING LOOP
|
||||
self.train()
|
||||
|
||||
# summarize profile results
|
||||
self.profiler.describe()
|
||||
|
||||
def test(self, model=None):
|
||||
r"""
|
||||
|
||||
|
|
|
@ -155,6 +155,7 @@ When this flag is enabled each batch is split into sequences of size truncated_b
|
|||
import copy
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
import logging as log
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -307,98 +308,95 @@ class TrainerTrainLoopMixin(ABC):
|
|||
def train(self):
    """Run the epoch loop from ``self.current_epoch`` up to ``self.max_epochs``.

    For every epoch this method:

    * calls ``set_epoch`` on the train dataloader's sampler when running
      under DDP/TPU and the sampler exposes that method,
    * records the epoch on both the trainer and the model,
    * computes ``self.total_batches`` (training batches plus any validation
      batches scheduled for this epoch) and resets the main progress bar,
    * runs ``self.run_training_epoch()``,
    * steps the LR schedulers.

    The loop ends when the epochs are exhausted, when ``self.max_steps`` has
    been reached, or when the early-stop callback asks to stop. Every exit,
    including a ``KeyboardInterrupt`` raised anywhere inside the loop, goes
    through ``self.run_training_teardown()``.

    Raises:
        MisconfigurationException: a reduce-on-plateau scheduler is
            configured but ``'val_loss'`` is missing from
            ``self.callback_metrics``.
    """
    warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
                  ' but will start from "0" in v0.8.0.', DeprecationWarning)
    # get model
    model = self.get_model()
    try:
        # run all epochs
        for epoch in range(self.current_epoch, self.max_epochs):
            # set seed for distributed sampler (enables shuffling for each epoch)
            if (self.use_ddp or self.use_tpu) \
                    and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
                self.get_train_dataloader().sampler.set_epoch(epoch)

            # get model
            model = self.get_model()

            # update training progress in trainer and model
            model.current_epoch = epoch
            self.current_epoch = epoch

            total_val_batches = 0
            is_val_epoch = False
            if not self.disable_validation:
                # val can be checked multiple times in epoch
                is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
                val_checks_per_epoch = self.num_training_batches // self.val_check_batch
                val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
                total_val_batches = self.num_val_batches * val_checks_per_epoch

            # total batches includes multiple val checks
            self.total_batches = self.num_training_batches + total_val_batches
            self.batch_loss_value = 0  # accumulated grads

            if self.fast_dev_run:
                # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
                num_iterations = 2
            elif self.is_iterable_train_dataloader:
                # for iterable train loader, the progress bar never ends
                num_iterations = None
            else:
                num_iterations = self.total_batches

            # reset progress bar
            # .reset() doesn't work on disabled progress bar so we should check
            if not self.main_progress_bar.disable:
                self.main_progress_bar.reset(num_iterations)
            desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
            self.main_progress_bar.set_description(desc)

            # changing gradient according accumulation_scheduler
            self.accumulation_scheduler.on_epoch_begin()

            # -----------------
            # RUN TNG EPOCH
            # -----------------
            self.run_training_epoch()

            # update LR schedulers
            if self.lr_schedulers is not None:
                for lr_scheduler in self.lr_schedulers:
                    lr_scheduler.step()
            if self.reduce_lr_on_plateau_scheduler is not None:
                val_loss = self.callback_metrics.get('val_loss')
                if val_loss is None:
                    avail_metrics = ','.join(list(self.callback_metrics.keys()))
                    m = f'ReduceLROnPlateau conditioned on metric val_loss ' \
                        f'which is not available. Available metrics are: {avail_metrics}'
                    raise MisconfigurationException(m)
                self.reduce_lr_on_plateau_scheduler.step(val_loss)

            # step budget exhausted: leave through the shared teardown so this
            # exit finalizes the logger and prints the profiler summary just
            # like the early-stop and end-of-loop exits do
            if self.max_steps and self.max_steps == self.global_step:
                self.run_training_teardown()
                return

            # early stopping
            met_min_epochs = epoch >= self.min_epochs - 1
            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

            if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
                    ((met_min_epochs and met_min_steps) or self.fast_dev_run)):
                should_stop = self.early_stop_callback.on_epoch_end()
                # stop training
                stop = should_stop and met_min_epochs
                if stop:
                    self.run_training_teardown()
                    return

        self.run_training_teardown()
    except KeyboardInterrupt:
        # Ctrl-C is swallowed here on purpose: run the teardown and return
        # normally instead of propagating the interrupt to the caller
        log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
        self.run_training_teardown()
|
||||
|
||||
def run_training_epoch(self):
|
||||
# before epoch hook
|
||||
|
@ -622,6 +620,20 @@ class TrainerTrainLoopMixin(ABC):
|
|||
|
||||
return 0, grad_norm_dic, all_log_metrics
|
||||
|
||||
def run_training_teardown(self):
    """Run the end-of-training steps shared by every exit from ``train``.

    In order: close the main progress bar, invoke the model's
    ``on_train_end`` hook inside the ``'on_train_end'`` profiler section,
    finalize the logger (when one is configured) and print the profiler
    summary.

    NOTE(review): the logger is always finalized with ``"success"``, also
    when the caller got here after a ``KeyboardInterrupt`` -- confirm that
    this status is intended for interrupted runs.
    """
    model = self.get_model()

    self.main_progress_bar.close()

    with self.profiler.profile('on_train_end'):
        model.on_train_end()

    if self.logger is not None:
        self.logger.finalize("success")

    # summarize profile results
    self.profiler.describe()
|
||||
|
||||
def training_forward(self, batch, batch_idx, opt_idx, hiddens):
|
||||
"""
|
||||
Handle forward for each training case (distributed, single gpu, etc...)
|
||||
|
|
Loading…
Reference in New Issue