ref: inner train loop (intermediate step) 11/n (#3370)

* ref: inner train loop (intermediate step) 11/n

* ref: inner train loop (intermediate step) 11/n
Author: William Falcon (committed by GitHub)
Date: 2020-09-06 12:49:12 -04:00
Commit: 8eef97c76f (parent 85421466ab)
2 changed files with 30 additions and 30 deletions

pytorch_lightning/trainer/training_loop.py

@@ -176,10 +176,8 @@ from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
from pytorch_lightning.utilities import rank_zero_warn, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer.training_loop_temp import TrainLoop
@@ -279,7 +277,6 @@ class TrainerTrainLoopMixin(ABC):
    on_epoch_end: Callable
    on_validation_end: Callable
    on_keyboard_interrupt: Callable
    on_train_epoch_start: Callable
    on_train_epoch_end: Callable

    @abstractmethod
@@ -351,30 +348,15 @@ class TrainerTrainLoopMixin(ABC):
        try:
            # run all epochs
            for epoch in range(self.current_epoch, self.max_epochs):
                # reset train dataloader
                if self.reload_dataloaders_every_epoch:
                    self.reset_train_dataloader(model)

                # set seed for distributed sampler (enables shuffling for each epoch)
                if (self.use_ddp or self.use_horovod or self.on_tpu) \
                        and hasattr(self.train_dataloader, 'sampler') \
                        and hasattr(self.train_dataloader.sampler, 'set_epoch'):
                    self.train_dataloader.sampler.set_epoch(epoch)

                # update training progress in trainer and model
                model.current_epoch = epoch
                self.current_epoch = epoch

                # hook
                self.train_loop.on_train_epoch_start(epoch)

                # changing gradient according accumulation_scheduler
                self.accumulation_scheduler.on_epoch_start(self, self.get_model())

                # stores accumulated grad fractions per batch
                self.batch_loss_value = TensorRunningAccum(
                    window_length=self.accumulate_grad_batches
                )

                # -----------------
                # RUN TNG EPOCH
                # -----------------
                # run train epoch
                self.run_training_epoch()

                if self.max_steps and self.max_steps <= self.global_step:
@@ -419,9 +401,6 @@ class TrainerTrainLoopMixin(ABC):
        # get model
        model = self.get_model()

        # hook
        self.train_loop.on_train_epoch_start()

        # modify dataloader if needed (ddp, etc...)
        train_dataloader = self.accelerator_backend.process_dataloader(self.train_dataloader)
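
For context on the sampler block being relocated above: the set_epoch(epoch) call is PyTorch's standard DistributedSampler reseeding pattern. Below is a minimal, self-contained sketch of that pattern in plain PyTorch; it is not part of this diff, and the toy dataset plus the explicit num_replicas/rank arguments are illustrative assumptions.

# Standalone sketch (plain PyTorch, not Lightning code): why the loop calls
# sampler.set_epoch(epoch) before each epoch when a DistributedSampler is used.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(16).float().unsqueeze(1))

# Passing num_replicas/rank explicitly avoids needing an initialized process
# group for this illustration; under real DDP they come from the process group.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for epoch in range(2):
    # Without this call the sampler shuffles with the same seed every epoch,
    # so each rank would see an identical sample order across epochs.
    sampler.set_epoch(epoch)
    for (batch,) in loader:
        pass  # training step would go here
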

pytorch_lightning/trainer/training_loop_temp.py

@@ -3,7 +3,7 @@ import numpy as np
import torch
import torch.distributed as torch_distrib
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer.supporters import Accumulator
from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.memory import recursive_detach
@@ -83,10 +83,27 @@ class TrainLoop:
        checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
        [c.on_validation_end(self.trainer, model) for c in checkpoint_callbacks]

    def on_train_epoch_start(self):
        # hook
        self.trainer.call_hook('on_epoch_start')
        self.trainer.call_hook('on_train_epoch_start')

    def on_train_epoch_start(self, epoch):
        model = self.trainer.get_model()

        # set seed for distributed sampler (enables shuffling for each epoch)
        # TODO: move to accelerators
        if (self.trainer.use_ddp or self.trainer.use_horovod or self.trainer.on_tpu) \
                and hasattr(self.trainer.train_dataloader, 'sampler') \
                and hasattr(self.trainer.train_dataloader.sampler, 'set_epoch'):
            self.trainer.train_dataloader.sampler.set_epoch(epoch)

        # update training progress in trainer and model
        model.current_epoch = epoch
        self.trainer.current_epoch = epoch

        # changing gradient according accumulation_scheduler
        self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, self.trainer.get_model())

        # stores accumulated grad fractions per batch
        self.trainer.batch_loss_value = TensorRunningAccum(
            window_length=self.trainer.accumulate_grad_batches
        )

        # bookkeeping
        self.should_check_val = False
@@ -95,6 +112,10 @@ class TrainLoop:
        self.early_stopping_accumulator = Accumulator()
        self.checkpoint_accumulator = Accumulator()

        # hook
        self.trainer.call_hook('on_epoch_start')
        self.trainer.call_hook('on_train_epoch_start')

    def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
        # figure out what to track for epoch end
        self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)
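
For readers unfamiliar with the batch_loss_value object created in on_train_epoch_start above: TensorRunningAccum accumulates loss values over a fixed window (here sized by accumulate_grad_batches) so the trainer can report a smoothed loss. Below is a minimal sketch of that idea under a made-up class name; it is a simplified stand-in, not the library's actual implementation.

# Simplified stand-in (not the TensorRunningAccum implementation) showing the
# behaviour the trainer relies on: keep the last `window_length` loss tensors
# and expose a smoothed mean for logging.
import torch


class RunningWindowAccum:
    def __init__(self, window_length: int):
        self.window_length = window_length
        self.memory = []  # most recent values, oldest first

    def append(self, x: torch.Tensor) -> None:
        # keep only the last `window_length` entries
        self.memory.append(x.detach())
        if len(self.memory) > self.window_length:
            self.memory.pop(0)

    def mean(self) -> torch.Tensor:
        return torch.stack(self.memory).mean()


accum = RunningWindowAccum(window_length=4)
for step in range(10):
    accum.append(torch.tensor(1.0 / (step + 1)))
print(accum.mean())  # smoothed value over the last 4 steps
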