diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 2b547f7b72..03bd34a4a0 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -368,9 +368,9 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
         accumulation_done = self._accumulated_batches_reached()
         # Lightning steps on the final batch
         is_final_batch = self._num_ready_batches_reached()
-        # but the TTP might not
-        ttp_accumulates_on_final_batch = self.trainer.strategy.handles_gradient_accumulation or not is_final_batch
-        return not accumulation_done and ttp_accumulates_on_final_batch
+        # but the strategy might not
+        strategy_accumulates_on_final_batch = self.trainer.strategy.handles_gradient_accumulation or not is_final_batch
+        return not accumulation_done and strategy_accumulates_on_final_batch
 
     @staticmethod
     def _prepare_outputs_training_batch_end(
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index d89ad75411..97020b8f48 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1504,11 +1504,11 @@ class Trainer(
             # todo: move this data parallel logic into the data parallel strategy
             output = accelerator_output if output is None else output
 
-            # call the ttp hook
+            # call the strategy hook
             if hook_name not in ("setup", "teardown", "on_train_start") and hasattr(self.strategy, hook_name):
-                ttp_hook = getattr(self.strategy, hook_name)
-                ttp_output = ttp_hook(*args, **kwargs)
-                output = ttp_output if output is None else output
+                strategy_hook = getattr(self.strategy, hook_name)
+                strategy_output = strategy_hook(*args, **kwargs)
+                output = strategy_output if output is None else output
 
         if pl_module:
             # restore current_fx when nested context
@@ -1791,8 +1791,10 @@ class Trainer(
         rank_zero_deprecation(
             "`Trainer.should_rank_save_checkpoint` is deprecated in v1.6 and will be removed in v1.8.", stacklevel=5
         )
-        ttp = self.strategy
-        return isinstance(ttp, pl.strategies.TPUSpawnStrategy) and ttp.local_rank == 0 or ttp.is_global_zero
+        strategy = self.strategy
+        return (
+            isinstance(strategy, pl.strategies.TPUSpawnStrategy) and strategy.local_rank == 0 or strategy.is_global_zero
+        )
 
     @property
     def _distrib_type(self) -> _StrategyType:
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index bc8807c3da..9b47e9c73b 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -341,25 +341,25 @@ def test_custom_accelerator(device_count_mock, setup_distributed_mock):
     class Prec(PrecisionPlugin):
         pass
 
-    class TrainTypePlugin(SingleDeviceStrategy):
+    class Strat(SingleDeviceStrategy):
         pass
 
-    ttp = TrainTypePlugin(device=torch.device("cpu"), accelerator=Accel(), precision_plugin=Prec())
-    trainer = Trainer(strategy=ttp, fast_dev_run=True, num_processes=2)
+    strategy = Strat(device=torch.device("cpu"), accelerator=Accel(), precision_plugin=Prec())
+    trainer = Trainer(strategy=strategy, fast_dev_run=True, num_processes=2)
     assert isinstance(trainer.accelerator, Accel)
-    assert isinstance(trainer.strategy, TrainTypePlugin)
+    assert isinstance(trainer.strategy, Strat)
     assert isinstance(trainer.precision_plugin, Prec)
-    assert trainer._accelerator_connector.strategy is ttp
+    assert trainer._accelerator_connector.strategy is strategy
 
-    class DistributedPlugin(DDPStrategy):
+    class Strat(DDPStrategy):
         pass
 
-    ttp = DistributedPlugin(accelerator=Accel(), precision_plugin=Prec())
-    trainer = Trainer(strategy=ttp, fast_dev_run=True, num_processes=2)
+    strategy = Strat(accelerator=Accel(), precision_plugin=Prec())
+    trainer = Trainer(strategy=strategy, fast_dev_run=True, num_processes=2)
     assert isinstance(trainer.accelerator, Accel)
-    assert isinstance(trainer.strategy, DistributedPlugin)
+    assert isinstance(trainer.strategy, Strat)
     assert isinstance(trainer.precision_plugin, Prec)
-    assert trainer._accelerator_connector.strategy is ttp
+    assert trainer._accelerator_connector.strategy is strategy
 
 
 @mock.patch.dict(
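
For readers following the rename: after this change a custom strategy is passed to the Trainer through the `strategy` argument, which is exactly what the updated test exercises. Below is a minimal sketch, assuming pytorch_lightning ~1.6 (the version that ships `pytorch_lightning.strategies`); the subclass name `MyStrategy` is hypothetical, chosen only to mirror the test's `Strat`:

    import torch
    from pytorch_lightning import Trainer
    from pytorch_lightning.strategies import SingleDeviceStrategy

    # Hypothetical subclass for illustration; behaves exactly like SingleDeviceStrategy.
    class MyStrategy(SingleDeviceStrategy):
        pass

    strategy = MyStrategy(device=torch.device("cpu"))
    trainer = Trainer(strategy=strategy, fast_dev_run=True)
    assert trainer.strategy is strategy  # the Trainer keeps the instance it was given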