2020-09-02 02:06:15 +00:00
import subprocess
2020-09-01 22:03:28 +00:00
import numpy as np
2020-09-02 02:06:15 +00:00
import torch
import torch.distributed as torch_distrib
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer.supporters import Accumulator
from pytorch_lightning.callbacks import ModelCheckpoint
2020-09-02 01:06:40 +00:00
from pytorch_lightning.core.step_result import Result
2020-09-02 02:06:15 +00:00
from pytorch_lightning import _logger as log
2020-09-01 22:03:28 +00:00
class TrainLoop:
def __init__(self, trainer):
    """Bind this loop to ``trainer`` and set up its per-run state.

    Args:
        trainer: the owning trainer object, kept as ``self.trainer``.
    """
    self.trainer = trainer
    # reset to False again by on_train_epoch_start at each epoch
    self.should_check_val = False
    # both accumulators are (re)created by on_train_epoch_start
    self.early_stopping_accumulator = self.checkpoint_accumulator = None
    # read and set by on_train_end (teardown-once guard)
    self._teardown_already_run = False
2020-09-01 22:03:28 +00:00
def num_optimizers(self):
    """Return how many ``(idx, optimizer)`` pairs apply to the current batch.

    This is simply the length of ``self.get_optimizers_iterable()``.
    """
    return len(self.get_optimizers_iterable())
2020-09-02 02:06:15 +00:00
def on_train_start(self):
# Pre-training setup. On GPU runs with a root GPU set, it enters that GPU's
# device context so the cache clearing described below targets it.
# NOTE(review): this view of the method is incomplete. The body of the
# `with` block and whatever followed "# hook" are not present. Presumably
# the `with` body empties the CUDA cache and a trainer hook is called
# afterwards; confirm against the full file.
# clear cache before training
if self.trainer.on_gpu and self.trainer.root_gpu is not None:
# use context because of:
# https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
# (presumably so the cache operation runs on root_gpu instead of the
# default device; see the linked thread)
with torch.cuda.device(f'cuda:{self.trainer.root_gpu}'):
# hook
def on_train_end(self):
# End-of-training teardown, guarded by `self._teardown_already_run`.
# The steps labelled below are:
# - save the latest checkpoint (log message)
# - hook
# - kill loggers
# - summarize profile results
# - rank-0 iteration over `interactive_ddp_procs`
# - DDP/DDP2 process-group cleanup
# - GPU memory cleanup
# NOTE(review): many statements are missing from this view. Most comment
# labels below have no body shown. As shown, `self._teardown_already_run = True`
# directly follows the `if`, which would only set the flag when it is already
# set. The original presumably returns early when the flag is set and assigns
# True afterwards; confirm against the full file.
if self._teardown_already_run:
self._teardown_already_run = True
# Save latest checkpoint
log.info('Saving latest checkpoint..')
# hook
# kill loggers
if self.trainer.logger is not None:
# summarize profile results
# NOTE(review): two identical rank-0 checks follow; presumably the first
# guards the profiler summary and the second the proc loop (bodies missing).
if self.trainer.global_rank == 0:
if self.trainer.global_rank == 0:
for proc in self.trainer.interactive_ddp_procs:
# clean up dist group
if self.trainer.use_ddp or self.trainer.use_ddp2:
# clear mem
if self.trainer.on_gpu:
model = self.trainer.get_model()
def check_checkpoint_callback(self, should_check_val):
    """Trigger ModelCheckpoint callbacks directly when validation won't.

    If the model does not override ``validation_step`` and
    ``should_check_val`` is falsy, this calls
    ``on_validation_end(trainer, model)`` on every ``ModelCheckpoint``
    in ``trainer.callbacks``. That way checkpoints are still written.

    Args:
        should_check_val: whether a validation check is scheduled.
    """
    model = self.trainer.get_model()

    # when no val loop is present or fast-dev-run still need to call checkpoints
    # TODO bake this logic into the checkpoint callback
    should_activate = not is_overridden('validation_step', model) and not should_check_val
    if should_activate:
        checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
        # plain loop: the calls are made only for their side effects
        for callback in checkpoint_callbacks:
            callback.on_validation_end(self.trainer, model)
2020-09-01 22:03:28 +00:00
def on_train_epoch_start(self):
    """Reset per-epoch training state.

    Clears ``should_check_val`` and gives the early-stopping and checkpoint
    callbacks fresh ``Accumulator`` instances for the new epoch.
    """
    # NOTE(review): the extracted source has a "# hook" marker here but no
    # visible hook call; the original presumably invoked a trainer hook at
    # this point. Confirm against the full file.
    self.should_check_val = False

    # new structured-result accumulators for callbacks, one per epoch
    self.early_stopping_accumulator, self.checkpoint_accumulator = Accumulator(), Accumulator()
2020-09-02 02:06:15 +00:00
2020-09-02 01:06:40 +00:00
def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
    """Finish bookkeeping for one training batch.

    First passes ``epoch_end_outputs`` to ``track_epoch_end_reduce_metrics``,
    which collects outputs to reduce at epoch end. Then fires the trainer's
    ``on_train_batch_end`` hook with the batch, its index and the
    dataloader index.
    """
    track = self.track_epoch_end_reduce_metrics
    track(epoch_output, epoch_end_outputs)

    hook_args = (batch, batch_idx, dataloader_idx)
    self.trainer.call_hook('on_train_batch_end', *hook_args)
2020-09-02 02:06:15 +00:00
def reset_train_val_dataloaders(self, model):
# Presumably resets the train dataloader (first `if`) and, when none is
# loaded yet, the val dataloader (second `if`). Both are skipped when the
# trainer reloads dataloaders every epoch.
# NOTE(review): the bodies of both `if` statements are missing from this
# view. With indentation lost, it is also unclear whether the second `if`
# is nested inside the first. Confirm against the full file.
if not self.trainer.reload_dataloaders_every_epoch:
if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
2020-09-02 01:06:40 +00:00
def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs):
# Walk one batch's per-optimizer outputs (`epoch_end_outputs`, position =
# opt_idx) so they can be reduced at the end of the epoch.
# NOTE(review): the statement that stores `opt_outputs` is missing from
# this view; as shown, `epoch_output` is never used. It presumably stores
# into `epoch_output[opt_idx]`; confirm against the full file.
# track the outputs to reduce at the end of the epoch
for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
# with 1 step (no tbptt) don't use a sequence at epoch end
# i.e. a one-element list is unwrapped to its item, unless that item is a Result
if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
opt_outputs = opt_outputs[0]
2020-09-01 22:03:28 +00:00
def get_optimizers_iterable(self):
    """Generates an iterable with (idx, optimizer) for each optimizer.

    Without ``trainer.optimizer_frequencies`` (None or empty), every
    optimizer is returned, so ``training_step`` runs once per optimizer.

    With frequencies, the optimizers take turns: optimizer ``i`` is active
    for ``optimizer_frequencies[i]`` consecutive batches. The cycle is
    driven by ``trainer.total_batch_idx``, and only the active pair is
    returned.

    Returns:
        list of ``(opt_idx, optimizer)`` tuples, where ``opt_idx`` is a
        plain ``int``.
    """
    if not self.trainer.optimizer_frequencies:
        # call training_step once per optimizer
        return list(enumerate(self.trainer.optimizers))

    optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
    optimizers_loop_length = optimizer_freq_cumsum[-1]
    current_place_in_loop = self.trainer.total_batch_idx % optimizers_loop_length

    # find optimizer index by looking for the first {item > current_place} in the cumsum list;
    # np.argmax returns a numpy integer, so cast to match the int indices of the branch above
    opt_idx = int(np.argmax(optimizer_freq_cumsum > current_place_in_loop))
    return [(opt_idx, self.trainer.optimizers[opt_idx])]