ref: inner train loop (intermediate step) 12/n (#3371)
* ref: inner train loop (intermediate step) 12/n * ref: inner train loop (intermediate step) 12/n * ref: inner train loop (intermediate step) 12/n
This commit is contained in:
parent
8eef97c76f
commit
d091fafc12
|
@ -87,11 +87,10 @@ class TrainLoop:
|
|||
model = self.trainer.get_model()
|
||||
|
||||
# set seed for distributed sampler (enables shuffling for each epoch)
|
||||
# TODO: move to accelerators
|
||||
if (self.trainer.use_ddp or self.trainer.use_horovod or self.trainer.on_tpu) \
|
||||
and hasattr(self.trainer.train_dataloader, 'sampler') \
|
||||
and hasattr(self.trainer.train_dataloader.sampler, 'set_epoch'):
|
||||
try:
|
||||
self.trainer.train_dataloader.sampler.set_epoch(epoch)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# update training progress in trainer and model
|
||||
model.current_epoch = epoch
|
||||
|
|
Loading…
Reference in New Issue