ref: inner train loop (intermediate step) 11/n (#3370)
* ref: inner train loop (intermediate step) 11/n
* ref: inner train loop (intermediate step) 11/n
parent 85421466ab
commit 8eef97c76f
@@ -176,10 +176,8 @@ from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
from pytorch_lightning.utilities import rank_zero_warn, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer.training_loop_temp import TrainLoop
@@ -279,7 +277,6 @@ class TrainerTrainLoopMixin(ABC):
    on_epoch_end: Callable
    on_validation_end: Callable
    on_keyboard_interrupt: Callable
    on_train_epoch_start: Callable
    on_train_epoch_end: Callable

    @abstractmethod
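For readers outside the codebase: the Callable annotations above are how the mixin declares hooks it expects the concrete Trainer class to supply at runtime. A minimal, self-contained sketch of that pattern follows; the class names are illustrative, not Lightning's.

# Sketch of the "declare a Callable, let the composing class supply it" mixin
# pattern used by TrainerTrainLoopMixin. Illustrative names only.
from abc import ABC, abstractmethod
from typing import Callable


class LoopMixinSketch(ABC):
    # declared so type checkers know the attribute exists; the real
    # implementation comes from whatever class mixes this in
    on_train_epoch_start: Callable

    @abstractmethod
    def run_training_epoch(self):
        """Provided by the composing class."""

    def fit_one_epoch(self, epoch: int):
        self.on_train_epoch_start(epoch)
        self.run_training_epoch()


class TinyTrainer(LoopMixinSketch):
    def on_train_epoch_start(self, epoch: int):
        print(f"epoch {epoch} starting")

    def run_training_epoch(self):
        print("running epoch")


TinyTrainer().fit_one_epoch(0)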
@@ -351,30 +348,15 @@ class TrainerTrainLoopMixin(ABC):
        try:
            # run all epochs
            for epoch in range(self.current_epoch, self.max_epochs):

                # reset train dataloader
                if self.reload_dataloaders_every_epoch:
                    self.reset_train_dataloader(model)
                # set seed for distributed sampler (enables shuffling for each epoch)
                if (self.use_ddp or self.use_horovod or self.on_tpu) \
                        and hasattr(self.train_dataloader, 'sampler') \
                        and hasattr(self.train_dataloader.sampler, 'set_epoch'):
                    self.train_dataloader.sampler.set_epoch(epoch)

                # update training progress in trainer and model
                model.current_epoch = epoch
                self.current_epoch = epoch
                # hook
                self.train_loop.on_train_epoch_start(epoch)

                # changing gradient according to the accumulation_scheduler
                self.accumulation_scheduler.on_epoch_start(self, self.get_model())

                # stores accumulated grad fractions per batch
                self.batch_loss_value = TensorRunningAccum(
                    window_length=self.accumulate_grad_batches
                )

                # -----------------
                # RUN TNG EPOCH
                # -----------------
                # run train epoch
                self.run_training_epoch()

                if self.max_steps and self.max_steps <= self.global_step:
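The sampler.set_epoch(epoch) call above is what makes shuffling differ between epochs under DDP/Horovod/TPU, because DistributedSampler derives its shuffle seed from its stored epoch counter. Here is a standalone PyTorch sketch of the same idea, independent of Lightning; the toy dataset and the explicit num_replicas/rank values are made up for illustration so it runs without init_process_group.

# Why the loop calls sampler.set_epoch(epoch): without it, DistributedSampler
# replays the same shuffled order every epoch.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

for epoch in range(2):
    # reseed the shuffle for this epoch
    sampler.set_epoch(epoch)
    for (batch,) in loader:
        pass  # the training step would run here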
@@ -419,9 +401,6 @@ class TrainerTrainLoopMixin(ABC):
        # get model
        model = self.get_model()

        # hook
        self.train_loop.on_train_epoch_start()

        # modify dataloader if needed (ddp, etc...)
        train_dataloader = self.accelerator_backend.process_dataloader(self.train_dataloader)
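process_dataloader is the point where an accelerator backend can return the dataloader unchanged or rebuild it for its device and distribution strategy. Below is a hedged sketch of that shape with invented backend classes; these are not Lightning's actual accelerator implementations.

# Illustrative accelerator-style hook: one backend passes the loader through,
# another rebuilds it with a DistributedSampler.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


class CPUBackendSketch:
    def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
        # nothing to do for a single CPU process
        return dataloader


class ShardingBackendSketch:
    def __init__(self, world_size: int, rank: int):
        self.world_size = world_size
        self.rank = rank

    def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
        # rebuild the loader so each rank sees its own shard
        sampler = DistributedSampler(
            dataloader.dataset, num_replicas=self.world_size, rank=self.rank
        )
        return DataLoader(dataloader.dataset, batch_size=dataloader.batch_size, sampler=sampler)


loader = DataLoader(TensorDataset(torch.arange(8)), batch_size=2)
loader = ShardingBackendSketch(world_size=2, rank=0).process_dataloader(loader)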
@@ -3,7 +3,7 @@ import numpy as np
import torch
import torch.distributed as torch_distrib
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer.supporters import Accumulator
from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.memory import recursive_detach
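Among these imports, is_overridden is the helper the loop uses elsewhere to check whether the user's LightningModule actually implements an optional hook. A rough sketch of the underlying idea follows; the overrides helper and toy classes are my own, not the real is_overridden implementation.

# Detecting whether a subclass replaced a base-class hook.
class Base:
    def training_epoch_end(self, outputs):
        pass


class UserModule(Base):
    def training_epoch_end(self, outputs):
        return sum(outputs)


def overrides(obj, method_name: str, base=Base) -> bool:
    # the hook counts as overridden if the subclass bound a different function
    sub_fn = getattr(type(obj), method_name, None)
    base_fn = getattr(base, method_name, None)
    return sub_fn is not None and sub_fn is not base_fn


print(overrides(UserModule(), 'training_epoch_end'))  # True
print(overrides(Base(), 'training_epoch_end'))        # False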
@@ -83,10 +83,27 @@ class TrainLoop:
        checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
        [c.on_validation_end(self.trainer, model) for c in checkpoint_callbacks]

    def on_train_epoch_start(self):
        # hook
        self.trainer.call_hook('on_epoch_start')
        self.trainer.call_hook('on_train_epoch_start')
    def on_train_epoch_start(self, epoch):
        model = self.trainer.get_model()

        # set seed for distributed sampler (enables shuffling for each epoch)
        # TODO: move to accelerators
        if (self.trainer.use_ddp or self.trainer.use_horovod or self.trainer.on_tpu) \
                and hasattr(self.trainer.train_dataloader, 'sampler') \
                and hasattr(self.trainer.train_dataloader.sampler, 'set_epoch'):
            self.trainer.train_dataloader.sampler.set_epoch(epoch)

        # update training progress in trainer and model
        model.current_epoch = epoch
        self.trainer.current_epoch = epoch

        # changing gradient according to the accumulation_scheduler
        self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, self.trainer.get_model())

        # stores accumulated grad fractions per batch
        self.trainer.batch_loss_value = TensorRunningAccum(
            window_length=self.trainer.accumulate_grad_batches
        )

        # bookkeeping
        self.should_check_val = False
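TensorRunningAccum(window_length=...) holds a small rolling window of tensors, sized here to the number of gradient-accumulation batches, so partial losses can be combined across accumulation steps. A minimal sketch of what a windowed accumulator of that kind can look like is below; it is my own class, and the real one in pytorch_lightning.trainer.supporters may expose a different API.

# Fixed-size circular buffer of scalar tensors with a running mean.
import torch


class RunningAccumSketch:
    def __init__(self, window_length: int):
        self.window = torch.zeros(window_length)
        self.idx = 0
        self.filled = 0

    def append(self, value: torch.Tensor) -> None:
        # store a detached copy so the buffer never keeps autograd graphs alive
        self.window[self.idx] = value.detach()
        self.idx = (self.idx + 1) % len(self.window)
        self.filled = min(self.filled + 1, len(self.window))

    def mean(self) -> torch.Tensor:
        return self.window[: self.filled].mean()


acc = RunningAccumSketch(window_length=4)
for loss in [1.0, 2.0, 3.0]:
    acc.append(torch.tensor(loss))
print(acc.mean())  # tensor(2.)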
@@ -95,6 +112,10 @@ class TrainLoop:
        self.early_stopping_accumulator = Accumulator()
        self.checkpoint_accumulator = Accumulator()

        # hook
        self.trainer.call_hook('on_epoch_start')
        self.trainer.call_hook('on_train_epoch_start')

    def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
        # figure out what to track for epoch end
        self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)
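The paired call_hook('on_epoch_start') and call_hook('on_train_epoch_start') calls fan a single hook name out to the registered callbacks and to the LightningModule itself. The sketch below shows a simplified version of that dispatch pattern; it is illustrative only and not the actual Trainer.call_hook, which also handles profiling, hook arguments, and return values.

# Toy hook dispatcher: callbacks first, then the module's own hook if defined.
class HookDispatchSketch:
    def __init__(self, model, callbacks):
        self.model = model
        self.callbacks = callbacks

    def call_hook(self, hook_name, *args, **kwargs):
        # callbacks receive the trainer and the module
        for callback in self.callbacks:
            fn = getattr(callback, hook_name, None)
            if callable(fn):
                fn(self, self.model, *args, **kwargs)

        # then the module's own hook, if it defines one
        fn = getattr(self.model, hook_name, None)
        if callable(fn):
            return fn(*args, **kwargs)


class PrintCallback:
    def on_train_epoch_start(self, trainer, pl_module):
        print("callback: epoch start")


class ToyModule:
    def on_train_epoch_start(self):
        print("module: epoch start")


HookDispatchSketch(ToyModule(), [PrintCallback()]).call_hook('on_train_epoch_start')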