2019-11-28 17:48:55 +00:00
|
|
|
"""
|
2020-01-16 12:22:29 +00:00
|
|
|
Validation loop
|
|
|
|
===============
|
2019-11-28 17:48:55 +00:00
|
|
|
|
|
|
|
The lightning validation loop handles everything except the actual computations of your model.
|
|
|
|
To decide what will happen in your validation loop, define the `validation_step` function.
|
|
|
|
Below are all the things lightning automates for you in the validation loop.
|
|
|
|
|
|
|
|
.. note:: Lightning will run 5 steps of validation in the beginning of training as a sanity
|
|
|
|
check so you don't have to wait until a full epoch to catch possible validation issues.
|
|
|
|
|
|
|
|
Check validation every n epochs
|
|
|
|
-------------------------------
|
|
|
|
|
|
|
|
If you have a small dataset you might want to check validation every n epochs
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
# DEFAULT
|
|
|
|
trainer = Trainer(check_val_every_n_epoch=1)
|
|
|
|
|
|
|
|
Set how much of the validation set to check
|
|
|
|
-------------------------------------------
|
|
|
|
|
|
|
|
If you don't want to check 100% of the validation set (for debugging or if it's huge), set this flag
|
|
|
|
|
|
|
|
val_percent_check will be overwritten by overfit_pct if `overfit_pct > 0`
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
# DEFAULT
|
|
|
|
trainer = Trainer(val_percent_check=1.0)
|
|
|
|
|
|
|
|
# check 10% only
|
|
|
|
trainer = Trainer(val_percent_check=0.1)
|
|
|
|
|
|
|
|
Set how much of the test set to check
|
|
|
|
-------------------------------------
|
|
|
|
|
|
|
|
If you don't want to check 100% of the test set (for debugging or if it's huge), set this flag
|
|
|
|
|
|
|
|
test_percent_check will be overwritten by overfit_pct if `overfit_pct > 0`
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
# DEFAULT
|
|
|
|
trainer = Trainer(test_percent_check=1.0)
|
|
|
|
|
|
|
|
# check 10% only
|
|
|
|
trainer = Trainer(test_percent_check=0.1)
|
|
|
|
|
|
|
|
Set validation check frequency within 1 training epoch
|
|
|
|
------------------------------------------------------
|
|
|
|
|
|
|
|
For large datasets it's often desirable to check validation multiple times within a training loop.
|
|
|
|
Pass in a float to check that often within 1 training epoch.
|
|
|
|
Pass in an int k to check every k training batches. Must use an int if using an IterableDataset.
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
# DEFAULT
|
|
|
|
trainer = Trainer(val_check_interval=0.95)
|
|
|
|
|
|
|
|
# check every .25 of an epoch
|
|
|
|
trainer = Trainer(val_check_interval=0.25)
|
|
|
|
|
|
|
|
# check every 100 train batches (e.g. for IterableDatasets or a fixed frequency)
|
|
|
|
trainer = Trainer(val_check_interval=100)
|
|
|
|
|
|
|
|
|
|
|
|
Set the number of validation sanity steps
|
|
|
|
-----------------------------------------
|
|
|
|
|
|
|
|
Lightning runs a few steps of validation in the beginning of training.
|
|
|
|
This avoids crashing in the validation loop sometime deep into a lengthy training loop.
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
# DEFAULT
|
2019-12-04 11:57:10 +00:00
|
|
|
trainer = Trainer(num_sanity_val_steps=5)
|
2019-11-28 17:48:55 +00:00
|
|
|
|
|
|
|
|
2019-12-04 11:57:10 +00:00
|
|
|
You can use `Trainer(num_sanity_val_steps=0)` to skip the sanity check.
|
2019-11-28 17:48:55 +00:00
|
|
|
|
|
|
|
Testing loop
============
|
|
|
|
|
|
|
|
To ensure you don't accidentally use test data to guide training decisions Lightning
|
|
|
|
makes running the test set deliberate.
|
|
|
|
|
|
|
|
**test**
|
|
|
|
|
|
|
|
You have two options to run the test set.
|
|
|
|
First case is where you test right after a full training routine.
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
# run full training
|
|
|
|
trainer.fit(model)
|
|
|
|
|
|
|
|
# run test set
|
|
|
|
trainer.test()
|
|
|
|
|
|
|
|
|
|
|
|
The second case is where you load a model and run the test set.
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
model = MyLightningModule.load_from_metrics(
|
|
|
|
weights_path='/path/to/pytorch_checkpoint.ckpt',
|
|
|
|
tags_csv='/path/to/test_tube/experiment/version/meta_tags.csv',
|
|
|
|
on_gpu=True,
|
|
|
|
map_location=None
|
|
|
|
)
|
|
|
|
|
|
|
|
# init trainer with whatever options
|
|
|
|
trainer = Trainer(...)
|
|
|
|
|
|
|
|
# test (pass in the model)
|
|
|
|
trainer.test(model)
|
|
|
|
|
|
|
|
In this second case, the options you pass to trainer will be used when running
|
|
|
|
the test set (e.g. 16-bit, dp, ddp).
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
2020-01-20 19:50:31 +00:00
|
|
|
import sys
|
2020-03-20 19:51:14 +00:00
|
|
|
import warnings
|
2019-12-04 15:57:32 +00:00
|
|
|
from abc import ABC, abstractmethod
|
2020-03-24 18:52:57 +00:00
|
|
|
from pprint import pprint
|
2020-03-12 16:41:37 +00:00
|
|
|
from typing import Callable
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2019-10-22 01:16:51 +00:00
|
|
|
import torch
|
2020-02-27 21:21:14 +00:00
|
|
|
from torch.utils.data import DataLoader
|
2020-01-26 15:19:09 +00:00
|
|
|
from tqdm.auto import tqdm
|
2019-10-22 08:32:40 +00:00
|
|
|
|
2020-02-28 23:48:07 +00:00
|
|
|
from pytorch_lightning.core.lightning import LightningModule
|
2020-03-20 19:51:14 +00:00
|
|
|
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
|
2020-03-31 12:57:48 +00:00
|
|
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
2019-10-22 01:16:51 +00:00
|
|
|
|
2020-02-17 21:01:20 +00:00
|
|
|
# Optional TPU backend: assume torch_xla is importable and downgrade the flag
# if either of its modules cannot be imported.
XLA_AVAILABLE = True
try:
    import torch_xla.distributed.parallel_loader as xla_pl
    import torch_xla.core.xla_model as xm
except ImportError:
    XLA_AVAILABLE = False
|
2020-02-17 21:01:20 +00:00
|
|
|
|
2019-10-22 01:16:51 +00:00
|
|
|
|
2019-12-04 15:57:32 +00:00
|
|
|
class TrainerEvaluationLoopMixin(ABC):
    """Validation / test loop of the trainer.

    The mixin only implements the evaluation flow. The attributes summarised
    below and the abstract methods must be provided by the class it is mixed
    into.
    """

    # this is just a summary on variables used in this abstract class,
    #  the proper values/initialisation should be done in child class
    test_progress_bar: ...
    val_progress_bar: ...
    main_progress_bar: ...
    use_ddp: bool
    use_dp: bool
    use_ddp2: bool
    single_gpu: bool
    data_parallel_device_ids: ...
    model: LightningModule
    num_test_batches: int
    num_val_batches: int
    fast_dev_run: ...
    process_position: ...
    process_output: ...
    training_tqdm_dict: ...
    proc_rank: int
    current_epoch: int
    callback_metrics: ...
    test_dataloaders: DataLoader
    val_dataloaders: DataLoader
    use_tpu: bool
    reload_dataloaders_every_epoch: ...
    progress_bar_refresh_rate: ...
    # used by `_evaluate` to time the `*_step_end` hooks
    profiler: ...

    # Callback system
    on_validation_start: Callable
    on_validation_end: Callable
    on_test_start: Callable
    on_test_end: Callable

    @abstractmethod
    def copy_trainer_model_properties(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def get_model(self):
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def is_overriden(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def transfer_batch_to_tpu(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def transfer_batch_to_gpu(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def add_tqdm_metrics(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def log_metrics(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def reset_test_dataloader(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def reset_val_dataloader(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_mode: bool = False):
        """Run evaluation code.

        Puts the model in eval mode with autograd disabled, runs the
        test/validation step over every dataloader, then hands the collected
        step outputs to the model's epoch-level hook. Train mode and autograd
        are restored before returning, including when a step or hook raises.

        Args:
            model: PT model
            dataloaders: list of PT dataloaders
            max_batches: batch index at which each dataloader is cut short
            test_mode: if True run the ``test_*`` hooks and update the test
                progress bar, otherwise the ``validation_*`` hooks and the
                validation/main progress bars

        Returns:
            Whatever the overridden ``*_end`` / ``*_epoch_end`` hook returns,
            or an empty dict when neither hook is overridden.
        """
        # the test and validation flows only differ in the hook names they dispatch to
        prefix = 'test' if test_mode else 'validation'
        step_end_hook = f'{prefix}_step_end'
        deprecated_end_hook = f'{prefix}_end'
        epoch_end_hook = f'{prefix}_epoch_end'

        # enable eval mode
        model.zero_grad()
        model.eval()

        try:
            # copy properties for forward overrides
            self.copy_trainer_model_properties(model)

            # disable gradients to save memory
            torch.set_grad_enabled(False)

            # bookkeeping: one list of step outputs per dataloader
            outputs = []

            # run validation
            for dataloader_idx, dataloader in enumerate(dataloaders):
                dl_outputs = []

                # on TPU we have to wrap it under the ParallelLoader
                if self.use_tpu:
                    device = xm.xla_device()
                    dataloader = xla_pl.ParallelLoader(dataloader, [device])
                    dataloader = dataloader.per_device_loader(device)

                for batch_idx, batch in enumerate(dataloader):
                    if batch is None:
                        continue

                    # stop short when on fast_dev_run (sets max_batch=1)
                    if batch_idx >= max_batches:
                        break

                    # -----------------
                    # RUN EVALUATION STEP
                    # -----------------
                    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)

                    # on dp / ddp2 might still want to do something with the batch parts
                    if self.is_overriden(step_end_hook):
                        model_ref = self.get_model()
                        with self.profiler.profile(step_end_hook):
                            output = getattr(model_ref, step_end_hook)(output)

                    # track outputs for collation
                    dl_outputs.append(output)

                    # batch done
                    if self.progress_bar_refresh_rate >= 1 and batch_idx % self.progress_bar_refresh_rate == 0:
                        if test_mode:
                            self.test_progress_bar.update(self.progress_bar_refresh_rate)
                        else:
                            self.val_progress_bar.update(self.progress_bar_refresh_rate)
                            self.main_progress_bar.update(self.progress_bar_refresh_rate)
                outputs.append(dl_outputs)

            eval_results = {}

            # with a single dataloader don't pass an array
            if len(dataloaders) == 1:
                outputs = outputs[0]

            # give model a chance to do something with the outputs (and method defined)
            if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
                model = model.module

            if self.is_overriden(deprecated_end_hook, model=model):
                # TODO: remove in v1.0.0
                eval_results = getattr(model, deprecated_end_hook)(outputs)
                warnings.warn(f'Method `{deprecated_end_hook}` was deprecated in 0.7.0 and will be removed 1.0.0.'
                              f' Use `{epoch_end_hook}` instead.', DeprecationWarning)

            elif self.is_overriden(epoch_end_hook, model=model):
                eval_results = getattr(model, epoch_end_hook)(outputs)
        finally:
            # always hand the model back in train mode with autograd re-enabled,
            # even if an evaluation step or hook raised
            model.train()
            torch.set_grad_enabled(True)

        return eval_results

    def run_evaluation(self, test_mode: bool = False):
        """Run one full validation (default) or test pass.

        Fires the start callback and model hooks, selects the dataloaders,
        creates the progress bar, runs `_evaluate`, then processes and logs the
        resulting metrics.

        Args:
            test_mode: if True evaluate on the test dataloaders, otherwise on
                the validation dataloaders

        Raises:
            MisconfigurationException: when ``test_mode`` is set but the model
                does not override ``test_step``.
        """
        # when testing make sure user defined a test step
        if test_mode and not self.is_overriden('test_step'):
            raise MisconfigurationException(
                "You called `.test()` without defining model's `.test_step()`."
                " Please define and try again")

        # Validation/Test begin callbacks
        if test_mode:
            self.on_test_start()
        else:
            self.on_validation_start()

        # hook
        model = self.get_model()
        model.on_pre_performance_check()

        # select dataloaders
        if test_mode:
            if self.test_dataloaders is None:
                self.reset_test_dataloader(model)

            dataloaders = self.test_dataloaders
            max_batches = self.num_test_batches
        else:
            # val
            if self.val_dataloaders is None:
                self.reset_val_dataloader(model)

            dataloaders = self.val_dataloaders
            max_batches = self.num_val_batches

        # cap max batches to 1 when using fast_dev_run
        if self.fast_dev_run:
            max_batches = 1

        # init validation or test progress bar
        # main progress bar will already be closed when testing so initial position is free
        position = 2 * self.process_position + (not test_mode)
        desc = 'Testing' if test_mode else 'Validating'
        # an infinite batch count has no meaningful total to display
        total = max_batches if max_batches != float('inf') else None
        pbar = tqdm(desc=desc, total=total, leave=test_mode, position=position,
                    disable=not self.progress_bar_refresh_rate, dynamic_ncols=True, file=sys.stdout)
        setattr(self, f'{"test" if test_mode else "val"}_progress_bar', pbar)

        # run evaluation
        eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)
        _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
            eval_results)

        # add metrics to prog bar
        self.add_tqdm_metrics(prog_bar_metrics)

        # log results of test (only on the rank-0 process)
        if test_mode and self.proc_rank == 0:
            print('-' * 100)
            print('TEST RESULTS')
            pprint(prog_bar_metrics)
            print('-' * 100)

        # log metrics
        self.log_metrics(log_metrics, {})

        # track metrics for callbacks
        self.callback_metrics.update(callback_metrics)

        # hook
        model.on_post_performance_check()

        # add model specific metrics
        if not test_mode:
            self.main_progress_bar.set_postfix(**self.training_tqdm_dict)

        # close progress bar
        if test_mode:
            self.test_progress_bar.close()
        else:
            self.val_progress_bar.close()

        # eventual dataset reloading
        if self.reload_dataloaders_every_epoch:
            if test_mode:
                self.reset_test_dataloader(model)
            else:
                self.reset_val_dataloader(model)

        # Validation/Test end callbacks
        # NOTE(review): `on_validation_end` is declared on this class but is not
        # fired here - presumably it is triggered elsewhere; confirm against callers.
        if test_mode:
            self.on_test_end()

    def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
        """Run a single test/validation step on one batch.

        Args:
            model: the module to call; it is invoked directly under
                dp / ddp / ddp2, otherwise its ``test_step`` or
                ``validation_step`` is called
            batch: the batch produced by the dataloader
            batch_idx: index of the batch within its dataloader
            dataloader_idx: index of the dataloader; only forwarded to the step
                when more than one dataloader is in use
            test_mode: if True use the test dataloaders and ``test_step``

        Returns:
            The output of the step (or of the wrapped model's forward call).
        """
        # make dataloader_idx arg in validation_step optional
        args = [batch, batch_idx]

        active_dataloaders = self.test_dataloaders if test_mode else self.val_dataloaders
        if len(active_dataloaders) > 1:
            args.append(dataloader_idx)

        # handle DP, DDP forward
        if self.use_ddp or self.use_dp or self.use_ddp2:
            return model(*args)

        # single GPU data transfer
        if self.single_gpu:
            # for single GPU put inputs on gpu manually
            root_gpu = 0
            if isinstance(self.data_parallel_device_ids, list):
                root_gpu = self.data_parallel_device_ids[0]
            batch = self.transfer_batch_to_gpu(batch, root_gpu)
            args[0] = batch

        # TPU data transfer
        if self.use_tpu:
            batch = self.transfer_batch_to_tpu(batch)
            args[0] = batch

        # CPU, TPU or gpu step
        if test_mode:
            return model.test_step(*args)
        return model.validation_step(*args)
|