rename accelerator_backend -> accelerator (#6034)

* rename accelerator backend

* rename new additions from master

* add proper deprecation

* pep8

* warning match

* add missing warning type
This commit is contained in:
Adrian Wälchli 2021-02-18 16:54:12 +01:00 committed by GitHub
parent 02ac4b0b6a
commit 6cc1a06078
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 101 additions and 92 deletions

View File

@ -240,6 +240,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `.get_model()` with explicit `.lightning_module` property ([#6035](https://github.com/PyTorchLightning/pytorch-lightning/pull/6035))
- Deprecated Trainer attribute `accelerator_backend` in favor of `accelerator` ([#6034](https://github.com/PyTorchLightning/pytorch-lightning/pull/6034))
### Removed
- Removed deprecated checkpoint argument `filepath` ([#5321](https://github.com/PyTorchLightning/pytorch-lightning/pull/5321))

View File

@ -222,6 +222,6 @@ if __name__ == "__main__":
trainer.fit(model, cifar10_dm)
trainer.test(model, datamodule=cifar10_dm)
if trainer.accelerator_backend.rpc_enabled:
if trainer.accelerator.rpc_enabled:
# Called at the end of trainer to ensure all processes are killed
trainer.training_type_plugin.exit_rpc_process()

View File

@ -41,7 +41,6 @@ from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
@ -448,11 +447,7 @@ class LightningModule(
the output will also be a collection with tensors of this shape.
"""
group = group if group is not None else torch.distributed.group.WORLD
if self.trainer.accelerator_backend is not None:
all_gather = self.trainer.accelerator_backend.all_gather
else:
all_gather = all_gather_ddp_if_available
all_gather = self.trainer.accelerator.all_gather
data = convert_to_tensors(data, device=self.device)
all_gather = partial(all_gather, group=group, sync_grads=sync_grads)
return apply_to_collection(data, torch.Tensor, all_gather)

View File

@ -132,7 +132,7 @@ class LightningOptimizer:
model = trainer.lightning_module
with trainer.profiler.profile(profiler_name):
trainer.accelerator_backend.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
if self._trainer.train_loop.automatic_optimization:
trainer.train_loop.on_before_zero_grad(optimizer)

View File

@ -185,7 +185,7 @@ class DeepSpeedPlugin(DDPPlugin):
self._format_config()
self._config_initialized = True
precision = self.lightning_module.trainer.accelerator_backend.precision
precision = self.lightning_module.trainer.accelerator.precision
model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)
if self.lightning_module.trainer.training:

View File

@ -90,7 +90,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
trainer.progress_bar_callback.disable()
self.model_to_device()
trainer.accelerator_backend.setup_optimizers(trainer)
trainer.accelerator.setup_optimizers(trainer)
trainer.precision_plugin.connect(self._model, None, None)
# replace trainer save_checkpoint to use `xm.save`

View File

@ -219,8 +219,7 @@ class CheckpointConnector:
model.on_hpc_save(checkpoint)
if self.trainer.accelerator_backend:
checkpoint = self.trainer.accelerator_backend.on_save(checkpoint)
checkpoint = self.trainer.accelerator.on_save(checkpoint)
# do the actual save
# TODO: fix for anything with multiprocess DP, DDP, DDP2
@ -286,7 +285,7 @@ class CheckpointConnector:
optimizer_states = []
for i, optimizer in enumerate(self.trainer.optimizers):
# Rely on accelerator to dump optimizer state
optimizer_state = self.trainer.accelerator_backend.optimizer_state(optimizer)
optimizer_state = self.trainer.accelerator.optimizer_state(optimizer)
optimizer_states.append(optimizer_state)
checkpoint['optimizer_states'] = optimizer_states

View File

@ -51,7 +51,7 @@ class TrainerDataLoadingMixin(ABC):
limit_val_batches: Union[int, float]
limit_test_batches: Union[int, float]
replace_sampler_ddp: bool
accelerator_backend: Accelerator
accelerator: Accelerator
num_nodes: int
num_processes: int
distributed_backend: Optional[str]
@ -398,8 +398,7 @@ class TrainerDataLoadingMixin(ABC):
dataloader = dataloader_fx()
dataloader = self._flatten_dl_only(dataloader)
if self.accelerator_backend is not None:
self.accelerator_backend.barrier('get_dataloaders')
self.accelerator.barrier('get_dataloaders')
return dataloader
def _flatten_dl_only(self, dataloaders):

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.states import RunningStage
@ -133,10 +134,19 @@ class DeprecatedDistDeviceAttributes:
self.accelerator_connector._device_type = DeviceType.GPU
class DeprecatedModelAttributes:
class DeprecatedTrainerAttributes:
accelerator: Accelerator
lightning_module = LightningModule
@property
def accelerator_backend(self) -> Accelerator:
rank_zero_warn(
"The `Trainer.accelerator_backend` attribute is deprecated in favor of `Trainer.accelerator`"
" since 1.2 and will be removed in v1.4.", DeprecationWarning
)
return self.accelerator
def get_model(self) -> LightningModule:
rank_zero_warn(
"The use of `Trainer.get_model()` is deprecated in favor of `Trainer.lightning_module`"

View File

@ -157,11 +157,11 @@ class EvaluationLoop(object):
if self.testing:
model_ref._current_fx_name = "test_step"
with self.trainer.profiler.profile("test_step"):
output = self.trainer.accelerator_backend.test_step(args)
output = self.trainer.accelerator.test_step(args)
else:
model_ref._current_fx_name = "validation_step"
with self.trainer.profiler.profile("validation_step"):
output = self.trainer.accelerator_backend.validation_step(args)
output = self.trainer.accelerator.validation_step(args)
# capture any logged information
self.trainer.logger_connector.cache_logged_metrics()

View File

@ -74,7 +74,7 @@ class PredictLoop(object):
model_ref = self.trainer.lightning_module
model_ref._current_fx_name = "predict"
predictions = self.trainer.accelerator_backend.predict(args)
predictions = self.trainer.accelerator.predict(args)
self._predictions[dataloader_idx].append(predictions)
self.trainer._progress_bar_callback.on_predict_batch_end(
self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx

View File

@ -62,11 +62,6 @@ class TrainerProperties(ABC):
def accelerator(self) -> Accelerator:
return self.accelerator_connector.accelerator
@property
def accelerator_backend(self) -> Accelerator:
# for backward compatibility
return self.accelerator
@property
def distributed_backend(self) -> Optional[str]:
# for backward compatibility
@ -138,7 +133,7 @@ class TrainerProperties(ABC):
else:
dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir')
dirpath = self.accelerator_backend.broadcast(dirpath)
dirpath = self.accelerator.broadcast(dirpath)
return dirpath
@property
@ -360,7 +355,7 @@ class TrainerProperties(ABC):
@property
def lightning_module(self) -> LightningModule:
return self.accelerator_backend.lightning_module
return self.accelerator.lightning_module
@property
def optimizers(self) -> Optional[List[Optimizer]]:

View File

@ -45,7 +45,7 @@ from pytorch_lightning.trainer.connectors.profiler_connector import ProfilerConn
from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes, DeprecatedModelAttributes
from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes, DeprecatedTrainerAttributes
from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@ -80,7 +80,7 @@ class Trainer(
TrainerTrainingTricksMixin,
TrainerDataLoadingMixin,
DeprecatedDistDeviceAttributes,
DeprecatedModelAttributes,
DeprecatedTrainerAttributes,
):
@overwrite_by_env_vars
@ -470,7 +470,7 @@ class Trainer(
# ----------------------------
self.call_setup_hook(model)
self.call_hook("on_before_accelerator_backend_setup", model)
self.accelerator_backend.setup(self, model)
self.accelerator.setup(self, model)
self.setup_trainer(model)
# ----------------------------
@ -533,24 +533,24 @@ class Trainer(
self._set_running_stage(None, model)
return self.accelerator_backend.results or 1
return self.accelerator.results or 1
def pre_dispatch(self):
self.accelerator_backend.pre_dispatch()
self.accelerator.pre_dispatch()
def post_dispatch(self):
self.accelerator_backend.post_dispatch()
self.accelerator_backend.teardown()
self.accelerator.post_dispatch()
self.accelerator.teardown()
def dispatch(self):
if self.testing:
self.accelerator_backend.start_testing(self)
self.accelerator.start_testing(self)
elif self.predicting:
self.accelerator_backend.start_predicting(self)
self.accelerator.start_predicting(self)
else:
self.accelerator_backend.start_training(self)
self.accelerator.start_training(self)
def train_or_test_or_predict(self):
if self.testing:
@ -949,7 +949,7 @@ class Trainer(
)
return {}
if not self._device_type == DeviceType.TPU:
self.accelerator_backend.barrier()
self.accelerator.barrier()
ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt['state_dict'])
@ -1109,8 +1109,8 @@ class Trainer(
# if the PL module doesn't have the hook then call the accelerator
# used to auto-reduce things for the user with Results obj
elif hasattr(self.accelerator_backend, hook_name):
accelerator_hook = getattr(self.accelerator_backend, hook_name)
elif hasattr(self.accelerator, hook_name):
accelerator_hook = getattr(self.accelerator, hook_name)
output = accelerator_hook(*args, **kwargs)
if not skip:

View File

@ -290,8 +290,8 @@ class TrainLoop:
model_ref._current_fx_name = 'training_step'
model_ref._results = Result()
with self.trainer.profiler.profile("training_step"):
training_step_output = self.trainer.accelerator_backend.training_step(args)
self.trainer.accelerator_backend.post_training_step()
training_step_output = self.trainer.accelerator.training_step(args)
self.trainer.accelerator.post_training_step()
self.trainer.logger_connector.cache_logged_metrics()
@ -438,14 +438,14 @@ class TrainLoop:
self.trainer.call_hook('on_before_zero_grad', optimizer)
def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
self.trainer.accelerator_backend.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
def track_and_norm_grad(self, optimizer):
# track gradient norms
grad_norm_dic = self._track_gradient_norm()
# clip gradients
self.trainer.accelerator_backend.clip_gradients(optimizer, self.trainer.gradient_clip_val)
self.trainer.accelerator.clip_gradients(optimizer, self.trainer.gradient_clip_val)
self._cur_grad_norm_dict = grad_norm_dic
def _track_gradient_norm(self):
@ -769,9 +769,9 @@ class TrainLoop:
# backward can be called manually in the training loop
if isinstance(result, torch.Tensor):
self.trainer.accelerator_backend.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs)
self.trainer.accelerator.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs)
else:
result.closure_loss = self.trainer.accelerator_backend.backward(
result.closure_loss = self.trainer.accelerator.backward(
result.closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs
)

View File

@ -33,7 +33,7 @@ def test_accelerator_choice_cpu(tmpdir):
default_root_dir=tmpdir,
fast_dev_run=True,
)
assert isinstance(trainer.accelerator_backend, CPUAccelerator)
assert isinstance(trainer.accelerator, CPUAccelerator)
assert isinstance(trainer.training_type_plugin, SingleDevicePlugin)
@ -42,7 +42,7 @@ def test_accelerator_choice_ddp_cpu(tmpdir):
fast_dev_run=True,
accelerator='ddp_cpu',
)
assert isinstance(trainer.accelerator_backend, CPUAccelerator)
assert isinstance(trainer.accelerator, CPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment)
@ -56,7 +56,7 @@ def test_accelerator_choice_ddp(cuda_available_mock, device_count_mock):
accelerator='ddp',
gpus=1,
)
assert isinstance(trainer.accelerator_backend, GPUAccelerator)
assert isinstance(trainer.accelerator, GPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDPPlugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment)
@ -70,7 +70,7 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock):
accelerator='ddp_spawn',
gpus=1,
)
assert isinstance(trainer.accelerator_backend, GPUAccelerator)
assert isinstance(trainer.accelerator, GPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment)
@ -92,7 +92,7 @@ def test_accelerator_choice_ddp_slurm():
def on_fit_start(self, trainer, pl_module):
assert trainer.use_ddp
assert trainer.accelerator_connector.is_slurm_managing_tasks
assert isinstance(trainer.accelerator_backend, GPUAccelerator)
assert isinstance(trainer.accelerator, GPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDPPlugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
assert trainer.training_type_plugin.cluster_environment.local_rank() == 10
@ -130,7 +130,7 @@ def test_accelerator_choice_ddp2_slurm(device_count_mock):
def on_fit_start(self, trainer, pl_module):
assert trainer.use_ddp2
assert trainer.accelerator_connector.is_slurm_managing_tasks
assert isinstance(trainer.accelerator_backend, GPUAccelerator)
assert isinstance(trainer.accelerator, GPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDP2Plugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
assert trainer.training_type_plugin.cluster_environment.local_rank() == 10
@ -158,7 +158,7 @@ def test_accelerator_choice_ddp_te(device_count_mock):
def on_fit_start(self, trainer, pl_module):
assert trainer.use_ddp
assert isinstance(trainer.accelerator_backend, GPUAccelerator)
assert isinstance(trainer.accelerator, GPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDPPlugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment)
assert trainer.training_type_plugin.cluster_environment.local_rank() == 10
@ -186,7 +186,7 @@ def test_accelerator_choice_ddp2_te(device_count_mock):
def on_fit_start(self, trainer, pl_module):
assert trainer.use_ddp2
assert isinstance(trainer.accelerator_backend, GPUAccelerator)
assert isinstance(trainer.accelerator, GPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDP2Plugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment)
assert trainer.training_type_plugin.cluster_environment.local_rank() == 10
@ -217,7 +217,7 @@ def test_accelerator_choice_ddp_cpu_te(device_count_mock):
def on_fit_start(self, trainer, pl_module):
assert trainer.use_ddp
assert isinstance(trainer.accelerator_backend, CPUAccelerator)
assert isinstance(trainer.accelerator, CPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDPPlugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment)
assert trainer.training_type_plugin.cluster_environment.local_rank() == 10
@ -253,7 +253,7 @@ def test_accelerator_choice_ddp_cpu_slurm(device_count_mock):
def on_fit_start(self, trainer, pl_module):
assert trainer.use_ddp
assert trainer.accelerator_connector.is_slurm_managing_tasks
assert isinstance(trainer.accelerator_backend, CPUAccelerator)
assert isinstance(trainer.accelerator, CPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDPPlugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
assert trainer.training_type_plugin.task_idx == 0
@ -295,7 +295,7 @@ def test_accelerator_choice_ddp_cpu_custom_cluster(device_count_mock):
def on_fit_start(self, trainer, pl_module):
assert trainer.use_ddp
assert isinstance(trainer.accelerator_backend, CPUAccelerator)
assert isinstance(trainer.accelerator, CPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDPPlugin)
assert isinstance(trainer.training_type_plugin.cluster_environment, CustomCluster)
raise SystemExit()
@ -343,7 +343,7 @@ def test_custom_accelerator(device_count_mock):
fast_dev_run=True,
num_processes=2,
)
assert isinstance(trainer.accelerator_backend, Accel)
assert isinstance(trainer.accelerator, Accel)
assert isinstance(trainer.training_type_plugin, TrainTypePlugin)
assert isinstance(trainer.precision_plugin, Prec)
@ -363,7 +363,7 @@ def test_dist_backend_accelerator_mapping(device_count_mock):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert isinstance(trainer.accelerator_backend, CPUAccelerator)
assert isinstance(trainer.accelerator, CPUAccelerator)
assert isinstance(trainer.training_type_plugin, DDPPlugin)
assert trainer.training_type_plugin.task_idx == 0
raise SystemExit()

View File

@ -473,7 +473,7 @@ def test_dm_apply_batch_transfer_handler(get_module_mock):
model.transfer_batch_to_device = dm.transfer_batch_to_device
model.on_after_batch_transfer = dm.on_after_batch_transfer
batch_gpu = trainer.accelerator_backend.batch_to_device(batch, expected_device)
batch_gpu = trainer.accelerator.batch_to_device(batch, expected_device)
assert dm.on_before_batch_transfer_hook_rank == 0
assert dm.transfer_batch_to_device_hook_rank == 1

View File

@ -30,6 +30,13 @@ from tests.deprecated_api import _soft_unimport_module
from tests.helpers import BoringModel
def test_v1_4_0_deprecated_trainer_attributes():
with pytest.deprecated_call(match="will be removed in v1.4."):
trainer = Trainer()
_ = trainer.accelerator_backend
assert trainer.accelerator == trainer.accelerator_backend
def test_v1_4_0_deprecated_trainer_methods():
with pytest.deprecated_call(match='will be removed in v1.4'):
trainer = Trainer()

View File

@ -219,35 +219,35 @@ def test_single_gpu_batch_parse():
# non-transferrable types
primitive_objects = [None, {}, [], 1.0, "x", [None, 2], {"x": (1, 2), "y": None}]
for batch in primitive_objects:
data = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
data = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0'))
assert data == batch
# batch is just a tensor
batch = torch.rand(2, 3)
batch = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0'))
assert batch.device.index == 0 and batch.type() == 'torch.cuda.FloatTensor'
# tensor list
batch = [torch.rand(2, 3), torch.rand(2, 3)]
batch = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0'))
assert batch[0].device.index == 0 and batch[0].type() == 'torch.cuda.FloatTensor'
assert batch[1].device.index == 0 and batch[1].type() == 'torch.cuda.FloatTensor'
# tensor list of lists
batch = [[torch.rand(2, 3), torch.rand(2, 3)]]
batch = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0'))
assert batch[0][0].device.index == 0 and batch[0][0].type() == 'torch.cuda.FloatTensor'
assert batch[0][1].device.index == 0 and batch[0][1].type() == 'torch.cuda.FloatTensor'
# tensor dict
batch = [{'a': torch.rand(2, 3), 'b': torch.rand(2, 3)}]
batch = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0'))
assert batch[0]['a'].device.index == 0 and batch[0]['a'].type() == 'torch.cuda.FloatTensor'
assert batch[0]['b'].device.index == 0 and batch[0]['b'].type() == 'torch.cuda.FloatTensor'
# tuple of tensor list and list of tensor dict
batch = ([torch.rand(2, 3) for _ in range(2)], [{'a': torch.rand(2, 3), 'b': torch.rand(2, 3)} for _ in range(2)])
batch = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0'))
assert batch[0][0].device.index == 0 and batch[0][0].type() == 'torch.cuda.FloatTensor'
assert batch[1][0]['a'].device.index == 0
@ -259,7 +259,7 @@ def test_single_gpu_batch_parse():
# namedtuple of tensor
BatchType = namedtuple('BatchType', ['a', 'b'])
batch = [BatchType(a=torch.rand(2, 3), b=torch.rand(2, 3)) for _ in range(2)]
batch = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0'))
assert batch[0].a.device.index == 0
assert batch[0].a.type() == 'torch.cuda.FloatTensor'
@ -273,7 +273,7 @@ def test_single_gpu_batch_parse():
self.a = self.a.to(*args, **kwargs)
return self
batch = trainer.accelerator_backend.batch_to_device(CustomBatchType(), torch.device('cuda:0'))
batch = trainer.accelerator.batch_to_device(CustomBatchType(), torch.device('cuda:0'))
assert batch.a.type() == 'torch.cuda.FloatTensor'
# torchtext.data.Batch
@ -297,7 +297,7 @@ def test_single_gpu_batch_parse():
label_field.build_vocab(dataset)
batch = Batch(data=examples, dataset=dataset)
batch = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0'))
assert batch.text.type() == 'torch.cuda.LongTensor'
assert batch.label.type() == 'torch.cuda.LongTensor'
@ -310,7 +310,7 @@ def test_non_blocking():
batch = torch.zeros(2, 3)
with patch.object(batch, 'to', wraps=batch.to) as mocked:
batch = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0'))
mocked.assert_called_with(torch.device('cuda', 0), non_blocking=True)
class BatchObject(object):
@ -320,5 +320,5 @@ def test_non_blocking():
batch = BatchObject()
with patch.object(batch, 'to', wraps=batch.to) as mocked:
batch = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
batch = trainer.accelerator.batch_to_device(batch, torch.device('cuda:0'))
mocked.assert_called_with(torch.device('cuda', 0))

View File

@ -187,7 +187,7 @@ def test_apply_batch_transfer_handler(model_getter_mock):
# running .fit() would require us to implement custom data loaders, we mock the model reference instead
model_getter_mock.return_value = model
batch_gpu = trainer.accelerator_backend.batch_to_device(batch, expected_device)
batch_gpu = trainer.accelerator.batch_to_device(batch, expected_device)
assert model.on_before_batch_transfer_hook_rank == 0
assert model.transfer_batch_to_device_hook_rank == 1

View File

@ -303,7 +303,7 @@ def test_accuracy_metric_horovod():
accelerator='horovod',
)
assert isinstance(trainer.accelerator_backend, CPUAccelerator)
assert isinstance(trainer.accelerator, CPUAccelerator)
# TODO: test that we selected the correct training_type_plugin based on horovod flags
metric = Accuracy(

View File

@ -271,7 +271,7 @@ def test_broadcast_on_tpu():
def test_broadcast(rank):
trainer = Trainer(tpu_cores=8)
assert isinstance(trainer.accelerator_backend, TPUAccelerator)
assert isinstance(trainer.accelerator, TPUAccelerator)
assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin)
obj = ("ver_0.5", "logger_name", rank)
result = trainer.training_type_plugin.broadcast(obj)

View File

@ -46,8 +46,8 @@ def test_deepspeed_plugin_string(tmpdir):
plugins='deepspeed',
)
assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin)
assert trainer.accelerator_backend.training_type_plugin.parallel_devices == [torch.device('cpu')]
assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin)
assert trainer.accelerator.training_type_plugin.parallel_devices == [torch.device('cpu')]
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
@ -62,8 +62,8 @@ def test_deepspeed_plugin(tmpdir):
plugins=[DeepSpeedPlugin()],
)
assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin)
assert trainer.accelerator_backend.training_type_plugin.parallel_devices == [torch.device('cpu')]
assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin)
assert trainer.accelerator.training_type_plugin.parallel_devices == [torch.device('cpu')]
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
@ -82,7 +82,7 @@ def test_deepspeed_plugin_env(tmpdir, monkeypatch, deepspeed_config):
plugins='deepspeed',
)
plugin = trainer.accelerator_backend.training_type_plugin
plugin = trainer.accelerator.training_type_plugin
assert isinstance(plugin, DeepSpeedPlugin)
assert plugin.parallel_devices == [torch.device('cpu')]
assert plugin.config == deepspeed_config
@ -106,9 +106,9 @@ def test_deepspeed_precision_choice(amp_backend, tmpdir):
fast_dev_run=True, default_root_dir=tmpdir, plugins='deepspeed', amp_backend=amp_backend, precision=16
)
assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin)
assert isinstance(trainer.accelerator_backend.precision_plugin, DeepSpeedPrecisionPlugin)
assert trainer.accelerator_backend.precision_plugin.precision == 16
assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin)
assert isinstance(trainer.accelerator.precision_plugin, DeepSpeedPrecisionPlugin)
assert trainer.accelerator.precision_plugin.precision == 16
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")

View File

@ -50,9 +50,9 @@ def test_rpc_sequential_plugin_manual(tmpdir, args=None):
if torch_distrib.is_initialized() and torch_distrib.get_rank() == 0:
assert len(trainer.dev_debugger.pbar_added_metrics) > 0
if trainer.accelerator_backend.rpc_enabled:
if trainer.accelerator.rpc_enabled:
# Called at the end of trainer to ensure all processes are killed
trainer.accelerator_backend.training_type_plugin.exit_rpc_process()
trainer.accelerator.training_type_plugin.exit_rpc_process()
@pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@ -104,9 +104,9 @@ def test_rpc_sequential_plugin_automatic(tmpdir, args=None):
if torch_distrib.is_initialized() and torch_distrib.get_rank() == 0:
assert len(trainer.dev_debugger.pbar_added_metrics) > 0
if trainer.accelerator_backend.rpc_enabled:
if trainer.accelerator.rpc_enabled:
# Called at the end of trainer to ensure all processes are killed
trainer.accelerator_backend.training_type_plugin.exit_rpc_process()
trainer.accelerator.training_type_plugin.exit_rpc_process()
@pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@ -132,9 +132,9 @@ def test_rpc_sequential_plugin_with_wrong_balance(tmpdir, args=None):
):
trainer.fit(model)
if trainer.accelerator_backend.rpc_enabled:
if trainer.accelerator.rpc_enabled:
# Called at the end of trainer to ensure all processes are killed
trainer.accelerator_backend.training_type_plugin.exit_rpc_process()
trainer.accelerator.training_type_plugin.exit_rpc_process()
class SequentialModelRPCManual(LightningModule):

View File

@ -23,9 +23,9 @@ def test_sharded_ddp_choice(tmpdir, accelerator):
def on_fit_start(self, trainer, pl_module):
if accelerator == 'ddp_sharded':
assert isinstance(trainer.accelerator_backend.training_type_plugin, DDPShardedPlugin)
assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
elif accelerator == 'ddp_sharded_spawn':
assert isinstance(trainer.accelerator_backend.training_type_plugin, DDPSpawnShardedPlugin)
assert isinstance(trainer.accelerator.training_type_plugin, DDPSpawnShardedPlugin)
raise SystemExit()
model = BoringModel()
@ -71,9 +71,9 @@ def test_ddp_choice_sharded_amp(tmpdir, accelerator):
def on_fit_start(self, trainer, pl_module):
if accelerator == 'ddp_sharded':
assert isinstance(trainer.accelerator_backend.training_type_plugin, DDPShardedPlugin)
assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
elif accelerator == 'ddp_sharded_spawn':
assert isinstance(trainer.accelerator_backend.training_type_plugin, DDPSpawnShardedPlugin)
assert isinstance(trainer.accelerator.training_type_plugin, DDPSpawnShardedPlugin)
raise SystemExit()
model = BoringModel()