Rename ttp -> strategy (#11312)

Carlos Mocholí 2022-01-05 12:12:25 +01:00 committed by GitHub
parent 33c3490685
commit 5ac129e95a
3 changed files with 21 additions and 19 deletions


@@ -368,9 +368,9 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
         accumulation_done = self._accumulated_batches_reached()
         # Lightning steps on the final batch
         is_final_batch = self._num_ready_batches_reached()
-        # but the TTP might not
-        ttp_accumulates_on_final_batch = self.trainer.strategy.handles_gradient_accumulation or not is_final_batch
-        return not accumulation_done and ttp_accumulates_on_final_batch
+        # but the strategy might not
+        strategy_accumulates_on_final_batch = self.trainer.strategy.handles_gradient_accumulation or not is_final_batch
+        return not accumulation_done and strategy_accumulates_on_final_batch
 
     @staticmethod
     def _prepare_outputs_training_batch_end(
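
For context on the hunk above: the renamed check decides whether the loop keeps accumulating gradients. The snippet below is a minimal standalone sketch of that boolean logic with illustrative inputs, not Lightning source; the helper name and argument names are hypothetical.

# Minimal sketch (hypothetical helper, not Lightning source) of the accumulation decision above.
def should_accumulate(accumulation_done: bool, is_final_batch: bool, handles_gradient_accumulation: bool) -> bool:
    # Lightning steps on the final batch, but the strategy might not
    strategy_accumulates_on_final_batch = handles_gradient_accumulation or not is_final_batch
    return not accumulation_done and strategy_accumulates_on_final_batch

assert should_accumulate(False, False, False)      # mid-epoch: keep accumulating
assert not should_accumulate(False, True, False)   # final batch, strategy steps: stop accumulating
assert should_accumulate(False, True, True)        # final batch, strategy accumulates itself
assert not should_accumulate(True, False, False)   # accumulation window complete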


@@ -1504,11 +1504,11 @@ class Trainer(
                 # todo: move this data parallel logic into the data parallel strategy
                 output = accelerator_output if output is None else output
 
-            # call the ttp hook
+            # call the strategy hook
            if hook_name not in ("setup", "teardown", "on_train_start") and hasattr(self.strategy, hook_name):
-                ttp_hook = getattr(self.strategy, hook_name)
-                ttp_output = ttp_hook(*args, **kwargs)
-                output = ttp_output if output is None else output
+                strategy_hook = getattr(self.strategy, hook_name)
+                strategy_output = strategy_hook(*args, **kwargs)
+                output = strategy_output if output is None else output
 
         if pl_module:
             # restore current_fx when nested context
@@ -1791,8 +1791,10 @@ class Trainer(
         rank_zero_deprecation(
             "`Trainer.should_rank_save_checkpoint` is deprecated in v1.6 and will be removed in v1.8.", stacklevel=5
         )
-        ttp = self.strategy
-        return isinstance(ttp, pl.strategies.TPUSpawnStrategy) and ttp.local_rank == 0 or ttp.is_global_zero
+        strategy = self.strategy
+        return (
+            isinstance(strategy, pl.strategies.TPUSpawnStrategy) and strategy.local_rank == 0 or strategy.is_global_zero
+        )
 
     @property
     def _distrib_type(self) -> _StrategyType:
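
Note on the hunk above: the added parentheses only wrap the return expression for line length; Python's `and` binds tighter than `or`, so the meaning is unchanged. A tiny illustrative check, with made-up boolean names:

# Illustration only: `A and B or C` parses as `(A and B) or C`.
is_tpu_spawn, local_rank_zero, is_global_zero = True, False, True
assert (is_tpu_spawn and local_rank_zero or is_global_zero) == ((is_tpu_spawn and local_rank_zero) or is_global_zero)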


@@ -341,25 +341,25 @@ def test_custom_accelerator(device_count_mock, setup_distributed_mock):
     class Prec(PrecisionPlugin):
         pass
 
-    class TrainTypePlugin(SingleDeviceStrategy):
+    class Strat(SingleDeviceStrategy):
         pass
 
-    ttp = TrainTypePlugin(device=torch.device("cpu"), accelerator=Accel(), precision_plugin=Prec())
-    trainer = Trainer(strategy=ttp, fast_dev_run=True, num_processes=2)
+    strategy = Strat(device=torch.device("cpu"), accelerator=Accel(), precision_plugin=Prec())
+    trainer = Trainer(strategy=strategy, fast_dev_run=True, num_processes=2)
     assert isinstance(trainer.accelerator, Accel)
-    assert isinstance(trainer.strategy, TrainTypePlugin)
+    assert isinstance(trainer.strategy, Strat)
     assert isinstance(trainer.precision_plugin, Prec)
-    assert trainer._accelerator_connector.strategy is ttp
+    assert trainer._accelerator_connector.strategy is strategy
 
-    class DistributedPlugin(DDPStrategy):
+    class Strat(DDPStrategy):
         pass
 
-    ttp = DistributedPlugin(accelerator=Accel(), precision_plugin=Prec())
-    trainer = Trainer(strategy=ttp, fast_dev_run=True, num_processes=2)
+    strategy = Strat(accelerator=Accel(), precision_plugin=Prec())
+    trainer = Trainer(strategy=strategy, fast_dev_run=True, num_processes=2)
     assert isinstance(trainer.accelerator, Accel)
-    assert isinstance(trainer.strategy, DistributedPlugin)
+    assert isinstance(trainer.strategy, Strat)
     assert isinstance(trainer.precision_plugin, Prec)
-    assert trainer._accelerator_connector.strategy is ttp
+    assert trainer._accelerator_connector.strategy is strategy
 
 
 @mock.patch.dict(
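
The test above reflects the user-facing pattern this rename touches: a custom strategy instance passed to the Trainer. Below is a hedged sketch of that usage, with an illustrative class name and without the test's mocked accelerator/precision plugins; exact constructor arguments may differ between Lightning versions.

# Sketch based on the test above (Lightning ~1.6-era API); MyStrategy is an illustrative name.
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import SingleDeviceStrategy

class MyStrategy(SingleDeviceStrategy):
    pass

strategy = MyStrategy(device=torch.device("cpu"))
trainer = Trainer(strategy=strategy, fast_dev_run=True)
assert isinstance(trainer.strategy, MyStrategy)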