fixed error with shorter batch cycles
This commit is contained in:
parent
ed787fb061
commit
c973245ba1
|
@ -122,21 +122,21 @@ class Trainer(TrainerIO):
|
|||
self.tqdm_metrics = {}
|
||||
|
||||
# determine number of training batches
|
||||
nb_tng_batches = self.model.nb_batches(self.tng_dataloader)
|
||||
self.nb_tng_batches = int(nb_tng_batches * self.train_percent_check)
|
||||
self.nb_tng_batches = self.model.nb_batches(self.tng_dataloader)
|
||||
self.nb_tng_batches = int(self.nb_tng_batches * self.train_percent_check)
|
||||
|
||||
# determine number of validation batches
|
||||
nb_val_batches = self.model.nb_batches(self.val_dataloader)
|
||||
nb_val_batches = int(nb_val_batches * self.val_percent_check)
|
||||
nb_val_batches = max(1, nb_val_batches)
|
||||
self.nb_val_batches = nb_val_batches
|
||||
self.nb_val_batches = self.model.nb_batches(self.val_dataloader)
|
||||
self.nb_val_batches = int(self.nb_val_batches * self.val_percent_check)
|
||||
self.nb_val_batches = max(1, self.nb_val_batches)
|
||||
self.nb_val_batches = self.nb_val_batches
|
||||
|
||||
# determine number of test batches
|
||||
nb_test_batches = self.model.nb_batches(self.test_dataloader)
|
||||
self.nb_test_batches = int(nb_test_batches * self.test_percent_check)
|
||||
self.nb_test_batches = self.model.nb_batches(self.test_dataloader)
|
||||
self.nb_test_batches = int(self.nb_test_batches * self.test_percent_check)
|
||||
|
||||
# determine when to check validation
|
||||
self.val_check_batch = int(nb_tng_batches * self.val_check_interval)
|
||||
self.val_check_batch = int(self.nb_tng_batches * self.val_check_interval)
|
||||
|
||||
def __add_tqdm_metrics(self, metrics):
|
||||
for k, v in metrics.items():
|
||||
|
|
Loading…
Reference in New Issue