added set_epoch for distributed sampler, fix for #224 (#225)

This commit is contained in:
Ananya Harsh Jha 2019-09-16 10:21:00 -04:00 committed by William Falcon
parent e339799a0a
commit c0f3b6b035
1 changed files with 4 additions and 0 deletions

View File

@ -929,6 +929,10 @@ class Trainer(TrainerIO):
def __train(self):
# run all epochs
for epoch_nb in range(self.current_epoch, self.max_nb_epochs):
# set seed for distributed sampler (enables shuffling for each epoch)
if self.use_ddp:
self.tng_dataloader.sampler.set_epoch(epoch_nb)
# get model
model = self.__get_model()