2020-08-24 11:02:06 +00:00
|
|
|
import torch
|
|
|
|
from typing import Any
|
|
|
|
from pytorch_lightning.utilities.apply_func import move_data_to_device
|
|
|
|
|
|
|
|
|
|
|
|
class Accelerator:
    """Accelerator whose hooks are no-ops or pass-throughs by default.

    It stores the trainer it is given; ``batch_to_device`` is the only
    method here that reads it.
    """

    def __init__(self, trainer) -> None:
        """Store ``trainer`` on the instance.

        Args:
            trainer: Object exposing ``get_model()``, which
                ``batch_to_device`` calls.
        """
        self.trainer = trainer

    def setup(self, model) -> None:
        """Do nothing; ``model`` is accepted but unused."""

    def teardown(self) -> None:
        """Do nothing."""

    def batch_to_device(self, batch: Any, device: torch.device) -> Any:
        """Move ``batch`` to ``device``.

        When ``trainer.get_model()`` returns a model, the transfer is
        delegated to that model's ``transfer_batch_to_device`` hook.
        Otherwise the module-level ``move_data_to_device`` helper is used.

        Args:
            batch: The data to move.
            device: The target device.

        Returns:
            Whatever the selected transfer function returns.
        """
        model = self.trainer.get_model()
        if model is None:
            # No model available, so fall back to the generic helper.
            return move_data_to_device(batch, device)
        return model.transfer_batch_to_device(batch, device)

    def training_step_end(self, output):
        """Return ``output`` unchanged."""
        return output

    def test_step_end(self, output):
        """Return ``output`` unchanged."""
        return output

    def validation_step_end(self, output):
        """Return ``output`` unchanged."""
        return output

    def process_dataloader(self, dataloader):
        """Return ``dataloader`` unchanged."""
        return dataloader
|