Rename ttp -> strategy (#11312)
This commit is contained in:
parent
33c3490685
commit
5ac129e95a
|
@ -368,9 +368,9 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
accumulation_done = self._accumulated_batches_reached()
|
||||
# Lightning steps on the final batch
|
||||
is_final_batch = self._num_ready_batches_reached()
|
||||
# but the TTP might not
|
||||
ttp_accumulates_on_final_batch = self.trainer.strategy.handles_gradient_accumulation or not is_final_batch
|
||||
return not accumulation_done and ttp_accumulates_on_final_batch
|
||||
# but the strategy might not
|
||||
strategy_accumulates_on_final_batch = self.trainer.strategy.handles_gradient_accumulation or not is_final_batch
|
||||
return not accumulation_done and strategy_accumulates_on_final_batch
|
||||
|
||||
@staticmethod
|
||||
def _prepare_outputs_training_batch_end(
|
||||
|
|
|
@ -1504,11 +1504,11 @@ class Trainer(
|
|||
# todo: move this data parallel logic into the data parallel strategy
|
||||
output = accelerator_output if output is None else output
|
||||
|
||||
# call the ttp hook
|
||||
# call the strategy hook
|
||||
if hook_name not in ("setup", "teardown", "on_train_start") and hasattr(self.strategy, hook_name):
|
||||
ttp_hook = getattr(self.strategy, hook_name)
|
||||
ttp_output = ttp_hook(*args, **kwargs)
|
||||
output = ttp_output if output is None else output
|
||||
strategy_hook = getattr(self.strategy, hook_name)
|
||||
strategy_output = strategy_hook(*args, **kwargs)
|
||||
output = strategy_output if output is None else output
|
||||
|
||||
if pl_module:
|
||||
# restore current_fx when nested context
|
||||
|
@ -1791,8 +1791,10 @@ class Trainer(
|
|||
rank_zero_deprecation(
|
||||
"`Trainer.should_rank_save_checkpoint` is deprecated in v1.6 and will be removed in v1.8.", stacklevel=5
|
||||
)
|
||||
ttp = self.strategy
|
||||
return isinstance(ttp, pl.strategies.TPUSpawnStrategy) and ttp.local_rank == 0 or ttp.is_global_zero
|
||||
strategy = self.strategy
|
||||
return (
|
||||
isinstance(strategy, pl.strategies.TPUSpawnStrategy) and strategy.local_rank == 0 or strategy.is_global_zero
|
||||
)
|
||||
|
||||
@property
|
||||
def _distrib_type(self) -> _StrategyType:
|
||||
|
|
|
@ -341,25 +341,25 @@ def test_custom_accelerator(device_count_mock, setup_distributed_mock):
|
|||
class Prec(PrecisionPlugin):
|
||||
pass
|
||||
|
||||
class TrainTypePlugin(SingleDeviceStrategy):
|
||||
class Strat(SingleDeviceStrategy):
|
||||
pass
|
||||
|
||||
ttp = TrainTypePlugin(device=torch.device("cpu"), accelerator=Accel(), precision_plugin=Prec())
|
||||
trainer = Trainer(strategy=ttp, fast_dev_run=True, num_processes=2)
|
||||
strategy = Strat(device=torch.device("cpu"), accelerator=Accel(), precision_plugin=Prec())
|
||||
trainer = Trainer(strategy=strategy, fast_dev_run=True, num_processes=2)
|
||||
assert isinstance(trainer.accelerator, Accel)
|
||||
assert isinstance(trainer.strategy, TrainTypePlugin)
|
||||
assert isinstance(trainer.strategy, Strat)
|
||||
assert isinstance(trainer.precision_plugin, Prec)
|
||||
assert trainer._accelerator_connector.strategy is ttp
|
||||
assert trainer._accelerator_connector.strategy is strategy
|
||||
|
||||
class DistributedPlugin(DDPStrategy):
|
||||
class Strat(DDPStrategy):
|
||||
pass
|
||||
|
||||
ttp = DistributedPlugin(accelerator=Accel(), precision_plugin=Prec())
|
||||
trainer = Trainer(strategy=ttp, fast_dev_run=True, num_processes=2)
|
||||
strategy = Strat(accelerator=Accel(), precision_plugin=Prec())
|
||||
trainer = Trainer(strategy=strategy, fast_dev_run=True, num_processes=2)
|
||||
assert isinstance(trainer.accelerator, Accel)
|
||||
assert isinstance(trainer.strategy, DistributedPlugin)
|
||||
assert isinstance(trainer.strategy, Strat)
|
||||
assert isinstance(trainer.precision_plugin, Prec)
|
||||
assert trainer._accelerator_connector.strategy is ttp
|
||||
assert trainer._accelerator_connector.strategy is strategy
|
||||
|
||||
|
||||
@mock.patch.dict(
|
||||
|
|
Loading…
Reference in New Issue