diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 635355f0bd..859ffb682b 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -176,10 +176,8 @@ from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.step_result import EvalResult, Result
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
 from pytorch_lightning.utilities import rank_zero_warn, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.parsing import AttributeDict
 from pytorch_lightning.utilities.model_utils import is_overridden
 from pytorch_lightning.trainer.training_loop_temp import TrainLoop
@@ -279,7 +277,6 @@ class TrainerTrainLoopMixin(ABC):
     on_epoch_end: Callable
     on_validation_end: Callable
     on_keyboard_interrupt: Callable
-    on_train_epoch_start: Callable
     on_train_epoch_end: Callable
 
     @abstractmethod
@@ -351,30 +348,15 @@ class TrainerTrainLoopMixin(ABC):
         try:
             # run all epochs
             for epoch in range(self.current_epoch, self.max_epochs):
+                # reset train dataloader
                 if self.reload_dataloaders_every_epoch:
                     self.reset_train_dataloader(model)
 
-                # set seed for distributed sampler (enables shuffling for each epoch)
-                if (self.use_ddp or self.use_horovod or self.on_tpu) \
-                        and hasattr(self.train_dataloader, 'sampler') \
-                        and hasattr(self.train_dataloader.sampler, 'set_epoch'):
-                    self.train_dataloader.sampler.set_epoch(epoch)
-
-                # update training progress in trainer and model
-                model.current_epoch = epoch
-                self.current_epoch = epoch
+                # hook
+                self.train_loop.on_train_epoch_start(epoch)
 
-                # changing gradient according accumulation_scheduler
-                self.accumulation_scheduler.on_epoch_start(self, self.get_model())
-
-                # stores accumulated grad fractions per batch
-                self.batch_loss_value = TensorRunningAccum(
-                    window_length=self.accumulate_grad_batches
-                )
-
-                # -----------------
-                # RUN TNG EPOCH
-                # -----------------
+                # run train epoch
                 self.run_training_epoch()
 
                 if self.max_steps and self.max_steps <= self.global_step:
@@ -419,9 +401,6 @@ class TrainerTrainLoopMixin(ABC):
         # get model
         model = self.get_model()
 
-        # hook
-        self.train_loop.on_train_epoch_start()
-
         # modify dataloader if needed (ddp, etc...)
         train_dataloader = self.accelerator_backend.process_dataloader(self.train_dataloader)
 
diff --git a/pytorch_lightning/trainer/training_loop_temp.py b/pytorch_lightning/trainer/training_loop_temp.py
index 82e14ca3d8..a43d9c256c 100644
--- a/pytorch_lightning/trainer/training_loop_temp.py
+++ b/pytorch_lightning/trainer/training_loop_temp.py
@@ -3,7 +3,7 @@ import numpy as np
 import torch
 import torch.distributed as torch_distrib
 from pytorch_lightning.utilities.model_utils import is_overridden
-from pytorch_lightning.trainer.supporters import Accumulator
+from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities.memory import recursive_detach
@@ -83,10 +83,27 @@ class TrainLoop:
         checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
         [c.on_validation_end(self.trainer, model) for c in checkpoint_callbacks]
 
-    def on_train_epoch_start(self):
-        # hook
-        self.trainer.call_hook('on_epoch_start')
-        self.trainer.call_hook('on_train_epoch_start')
+    def on_train_epoch_start(self, epoch):
+        model = self.trainer.get_model()
+
+        # set seed for distributed sampler (enables shuffling for each epoch)
+        # TODO: move to accelerators
+        if (self.trainer.use_ddp or self.trainer.use_horovod or self.trainer.on_tpu) \
+                and hasattr(self.trainer.train_dataloader, 'sampler') \
+                and hasattr(self.trainer.train_dataloader.sampler, 'set_epoch'):
+            self.trainer.train_dataloader.sampler.set_epoch(epoch)
+
+        # update training progress in trainer and model
+        model.current_epoch = epoch
+        self.trainer.current_epoch = epoch
+
+        # changing gradient according accumulation_scheduler
+        self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, self.trainer.get_model())
+
+        # stores accumulated grad fractions per batch
+        self.trainer.batch_loss_value = TensorRunningAccum(
+            window_length=self.trainer.accumulate_grad_batches
+        )
 
         # bookkeeping
         self.should_check_val = False
@@ -95,6 +112,10 @@ class TrainLoop:
         self.early_stopping_accumulator = Accumulator()
         self.checkpoint_accumulator = Accumulator()
 
+        # hook
+        self.trainer.call_hook('on_epoch_start')
+        self.trainer.call_hook('on_train_epoch_start')
+
     def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
         # figure out what to track for epoch end
         self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)
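The net effect of the diff is that the per-epoch setup (distributed-sampler reshuffling, epoch bookkeeping, gradient-accumulation state) moves out of the trainer's epoch loop into `TrainLoop.on_train_epoch_start(epoch)`, and the `on_epoch_start` / `on_train_epoch_start` hooks now fire only after that setup. The standalone sketch below mirrors just that call ordering; `_TrainerSketch`, `_TrainLoopSketch`, and `_SamplerStub` are illustrative stand-ins, not the real Lightning classes.

```python
# Minimal sketch (assumed stand-in classes, not pytorch_lightning itself) of the
# call order after this refactor: the trainer loop delegates per-epoch setup to
# TrainLoop.on_train_epoch_start(epoch) before running the epoch.

class _SamplerStub:
    """Stands in for a DistributedSampler: only set_epoch() is needed here."""
    def set_epoch(self, epoch):
        print(f"sampler.set_epoch({epoch})")


class _TrainLoopSketch:
    def __init__(self, trainer):
        self.trainer = trainer

    def on_train_epoch_start(self, epoch):
        # 1. reshuffle the distributed sampler (previously done in Trainer.train)
        sampler = getattr(self.trainer.train_dataloader, "sampler", None)
        if sampler is not None and hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)

        # 2. update epoch bookkeeping on the trainer (and the model, in the real code)
        self.trainer.current_epoch = epoch

        # 3. per-epoch accumulation state (TensorRunningAccum, Accumulator) is rebuilt here
        # 4. only then fire the user-facing hooks, matching the new ordering
        print("call_hook('on_epoch_start')")
        print("call_hook('on_train_epoch_start')")


class _TrainerSketch:
    def __init__(self, max_epochs=2):
        self.max_epochs = max_epochs
        self.current_epoch = 0
        self.train_dataloader = type("DL", (), {"sampler": _SamplerStub()})()
        self.train_loop = _TrainLoopSketch(self)

    def train(self):
        for epoch in range(self.current_epoch, self.max_epochs):
            # hook + setup now live in the train loop object
            self.train_loop.on_train_epoch_start(epoch)
            print(f"run_training_epoch() for epoch {epoch}")


if __name__ == "__main__":
    _TrainerSketch().train()
```

Running the sketch prints the sampler reshuffle, the two hook calls, and then the epoch run for each epoch, which is the ordering the diff establishes.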