Merge branch 'master' of https://github.com/williamFalcon/pytorch-lightning
This commit is contained in:
commit
7225e5d6d5
|
@ -43,13 +43,20 @@ trainer = Trainer(test_percent_check=0.1)
|
|||
|
||||
---
|
||||
#### Set validation check frequency within 1 training epoch
|
||||
For large datasets it's often desirable to check validation multiple times within a training loop
|
||||
For large datasets it's often desirable to check validation multiple times within a training loop.
|
||||
Pass in a float to check that often within 1 training epoch.
|
||||
Pass in an int k to check every k training batches. Must use an int if using
|
||||
an IterableDataset.
|
||||
|
||||
``` {.python}
|
||||
# DEFAULT
|
||||
trainer = Trainer(val_check_interval=0.95)
|
||||
|
||||
# check every .25 of an epoch
|
||||
trainer = Trainer(val_check_interval=0.25)
|
||||
|
||||
# check every 100 train batches (ie: for IterableDatasets or fixed frequency)
|
||||
trainer = Trainer(val_check_interval=100)
|
||||
```
|
||||
|
||||
---
|
||||
|
|
|
@ -2,6 +2,9 @@ import warnings
|
|||
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
from pytorch_lightning.utilities.debugging import MisconfigurationException
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
|
@ -15,6 +18,9 @@ class TrainerDataLoadingMixin(object):
|
|||
def layout_bookeeping(self):
|
||||
|
||||
# determine number of training batches
|
||||
if isinstance(self.get_train_dataloader(), IterableDataset):
|
||||
self.nb_training_batches = float('inf')
|
||||
else:
|
||||
self.nb_training_batches = len(self.get_train_dataloader())
|
||||
self.nb_training_batches = int(self.nb_training_batches * self.train_percent_check)
|
||||
|
||||
|
@ -34,6 +40,11 @@ class TrainerDataLoadingMixin(object):
|
|||
self.nb_test_batches = max(1, self.nb_test_batches)
|
||||
|
||||
# determine when to check validation
|
||||
# if int passed in, val checks that often
|
||||
# otherwise, it checks in [0, 1.0] % range of a training epoch
|
||||
if isinstance(self.val_check_interval, int):
|
||||
self.val_check_batch = self.val_check_interval
|
||||
else:
|
||||
self.val_check_batch = int(self.nb_training_batches * self.val_check_interval)
|
||||
self.val_check_batch = max(1, self.val_check_batch)
|
||||
|
||||
|
@ -127,6 +138,16 @@ class TrainerDataLoadingMixin(object):
|
|||
self.get_test_dataloaders()
|
||||
self.get_val_dataloaders()
|
||||
|
||||
# support IterableDataset for train data
|
||||
self.is_iterable_train_dataloader = isinstance(self.get_train_dataloader(), IterableDataset)
|
||||
if self.is_iterable_train_dataloader and not isinstance(self.val_check_interval, int):
|
||||
m = '''
|
||||
When using an iterableDataset for train_dataloader,
|
||||
Trainer(val_check_interval) must be an int.
|
||||
An int k specifies checking validation every k training batches
|
||||
'''
|
||||
raise MisconfigurationException('when using ')
|
||||
|
||||
def determine_data_use_amount(self, train_percent_check, val_percent_check,
|
||||
test_percent_check, overfit_pct):
|
||||
"""
|
||||
|
|
|
@ -31,7 +31,12 @@ class TrainerTrainLoopMixin(object):
|
|||
|
||||
# init progress_bar when requested
|
||||
if self.show_progress_bar:
|
||||
self.progress_bar.reset(self.total_batches)
|
||||
nb_iterations = self.total_batches
|
||||
|
||||
# for iterable train loader, the progress bar never ends
|
||||
if self.is_iterable_train_dataloader:
|
||||
nb_iterations = float('inf')
|
||||
self.progress_bar.reset(nb_iterations)
|
||||
|
||||
# changing gradient according accumulation_scheduler
|
||||
self.accumulation_scheduler.on_epoch_begin(epoch_nb, self)
|
||||
|
|
|
@ -101,7 +101,7 @@ class Trainer(TrainerIOMixin,
|
|||
:param train_percent_check: int. How much of train set to check
|
||||
:param val_percent_check: int. How much of val set to check
|
||||
:param test_percent_check: int. How much of test set to check
|
||||
:param val_check_interval: int. Check val this frequently within a train epoch
|
||||
:param val_check_interval: float/int. If float, % of tng epoch. If int, check every n batch
|
||||
:param log_save_interval: int. Writes logs to disk this often
|
||||
:param row_log_interval: int. How often to add logging rows
|
||||
:param add_row_log_interval: int. How often to add logging rows. Deprecated.
|
||||
|
@ -160,6 +160,7 @@ class Trainer(TrainerIOMixin,
|
|||
self.get_train_dataloader = None
|
||||
self.get_test_dataloaders = None
|
||||
self.get_val_dataloaders = None
|
||||
self.is_iterable_train_dataloader = False
|
||||
|
||||
# training state
|
||||
self.model = None
|
||||
|
|
Loading…
Reference in New Issue