ref: inner train loop (intermediate step) 12/n (#3371)

* ref: inner train loop (intermediate step) 12/n

* ref: inner train loop (intermediate step) 12/n

* ref: inner train loop (intermediate step) 12/n
This commit is contained in:
William Falcon 2020-09-06 13:31:00 -04:00 committed by GitHub
parent 8eef97c76f
commit d091fafc12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 4 deletions

View File

@ -87,11 +87,10 @@ class TrainLoop:
model = self.trainer.get_model()
# set seed for distributed sampler (enables shuffling for each epoch)
# TODO: move to accelerators
if (self.trainer.use_ddp or self.trainer.use_horovod or self.trainer.on_tpu) \
and hasattr(self.trainer.train_dataloader, 'sampler') \
and hasattr(self.trainer.train_dataloader.sampler, 'set_epoch'):
try:
self.trainer.train_dataloader.sampler.set_epoch(epoch)
except Exception:
pass
# update training progress in trainer and model
model.current_epoch = epoch