16 lines
453 B
Python
16 lines
453 B
Python
|
import torch
|
||
|
from typing import Any
|
||
|
from pytorch_lightning.utilities.apply_func import move_data_to_device
|
||
|
|
||
|
|
||
|
class Accelerator:
    """Base accelerator bound to a single trainer instance.

    The accelerator keeps a reference to the trainer it was created with and
    uses it to look up the current model when placing batches on a device.
    """

    def __init__(self, trainer):
        """Store ``trainer`` for later use by the accelerator's methods.

        Args:
            trainer: the owning trainer object; it must expose ``get_model()``
                (called by :meth:`batch_to_device`).
        """
        self.trainer = trainer

    def batch_to_device(self, batch: Any, device: torch.device):
        """Place ``batch`` on ``device`` and return the result.

        If the trainer currently reports a model, the model's own
        ``transfer_batch_to_device(batch, device)`` hook is used, so a model
        can customise how its batches are moved. When ``get_model()`` returns
        ``None``, the generic ``move_data_to_device(batch, device)`` utility is
        used instead.

        Args:
            batch: the data to transfer; its structure is not inspected here.
            device: the target device, passed through unchanged.

        Returns:
            Whatever the chosen transfer function returns.
        """
        lightning_module = self.trainer.get_model()
        # No model attached yet: fall back to the library-level helper.
        if lightning_module is None:
            return move_data_to_device(batch, device)
        # Let the model's hook decide how the batch is moved.
        return lightning_module.transfer_batch_to_device(batch, device)
|