Remove `Trainer(move_metrics_to_cpu=True)` (#16358)

Authored by Carlos Mocholí on 2023-01-16 16:32:29 +01:00, committed by Luca Antiga
parent 14933592f4
commit caa82d8600
9 changed files with 14 additions and 49 deletions


@@ -68,7 +68,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   * Removed the `FitLoop.split_idx` property
   * Removed the `LoggerConnector.on_train_split_start` method
+- Removed the `Trainer(move_metrics_to_cpu=True)` argument ([#16358](https://github.com/Lightning-AI/lightning/pull/16358))
 - Removed the `LightningModule.precision` attribute ([#16203](https://github.com/Lightning-AI/lightning/pull/16203))
 - Removed the automatic addition of a moving average of the `training_step` loss in the progress bar. Use `self.log("loss", ..., prog_bar=True)` instead. ([#16192](https://github.com/Lightning-AI/lightning/issues/16192))
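The last entry above replaces an implicit behavior with an explicit call. A minimal sketch of the new pattern, assuming a typical `LightningModule` (the model and loss below are illustrative):

```python
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        out = self.layer(batch)
        loss = F.mse_loss(out, torch.zeros_like(out))
        # The moving average of the loss is no longer added to the progress
        # bar automatically; log it explicitly to get it back.
        self.log("loss", loss, prog_bar=True)
        return loss
```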


@@ -154,11 +154,6 @@ class EvaluationEpochLoop(Loop):
         if self._should_track_batch_outputs_for_epoch_end() and output is not None:
             self._outputs.append(output)
 
-        if self.trainer.move_metrics_to_cpu:
-            # the evaluation step output is not moved as they are not considered "metrics"
-            assert self.trainer._results is not None
-            self.trainer._results.cpu()
-
         if not self.batch_progress.is_last_batch:
             # if fault tolerant is enabled and process has been notified, exit.
             self.trainer._exit_gracefully_on_signal()
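The comment on the removed lines leans on a distinction worth spelling out: values returned from a step are "outputs" and always stay on the model's device, while values passed to `self.log` are "metrics", and only the latter were moved when the flag was set. A minimal sketch of the two paths inside a `LightningModule` (the helper and metric name are illustrative):

```python
def validation_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)  # hypothetical helper standing in for the real loss
    self.log("val_loss", loss)  # a "metric": the removed flag used to move these to CPU
    return loss                 # an "output": never moved, stays on the step's device
```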


@@ -116,12 +116,6 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
         result = self.output_result_cls.from_training_step_output(training_step_output)
 
-        if self.trainer.move_metrics_to_cpu:
-            # training step output does not get moved because it is not considered a "metric"
-            # the user might need them on the correct device for an operation in `training_epoch_end`
-            assert self.trainer._results is not None
-            self.trainer._results.cpu()
-
         self._done = True
         self._output = result.asdict()


@@ -397,11 +397,6 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
             training_step_output, self.trainer.accumulate_grad_batches
         )
 
-        if self.trainer.move_metrics_to_cpu:
-            # training step output does not get moved because it is not considered a "metric"
-            assert self.trainer._results is not None
-            self.trainer._results.cpu()
-
         return result
 
     def _build_kwargs(self, kwargs: OrderedDict, opt_idx: int) -> OrderedDict:


@@ -42,11 +42,9 @@ class LoggerConnector:
         self,
         logger: Union[bool, Logger, Iterable[Logger]],
         log_every_n_steps: int,
-        move_metrics_to_cpu: bool,
     ) -> None:
         self.configure_logger(logger)
         self.trainer.log_every_n_steps = log_every_n_steps
-        self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
 
     @property
     def should_update_logs(self) -> bool:


@@ -157,7 +157,6 @@ class Trainer:
         detect_anomaly: bool = False,
         auto_scale_batch_size: Union[str, bool] = False,
         plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
-        move_metrics_to_cpu: bool = False,
         multiple_trainloader_mode: str = "max_size_cycle",
         inference_mode: bool = True,
     ) -> None:
@@ -334,10 +333,6 @@ class Trainer:
             enable_model_summary: Whether to enable model summarization by default.
                 Default: ``True``.
 
-            move_metrics_to_cpu: Whether to force internal logged metrics to be moved to cpu.
-                This can save some gpu memory, but can make training slower. Use with caution.
-                Default: ``False``.
-
             multiple_trainloader_mode: How to loop over the datasets when there are multiple train loaders.
                 In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed,
                 and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
@@ -451,7 +446,7 @@ class Trainer:
         # init logger flags
         self._loggers: List[Logger]
-        self._logger_connector.on_trainer_init(logger, log_every_n_steps, move_metrics_to_cpu)
+        self._logger_connector.on_trainer_init(logger, log_every_n_steps)
 
         # init debugging flags
         self.val_check_batch: Union[int, float]
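With the flag removed, logged metrics now stay on the training device. Code that used `move_metrics_to_cpu=True` to cap GPU memory can move the aggregated values itself after reading them back; a minimal sketch, assuming a fitted `trainer` (the comprehension is a user-side workaround, not a Lightning API):

```python
import torch

# `trainer.callback_metrics` now holds tensors on the training device;
# copy them to the CPU explicitly if that is where they are needed.
cpu_metrics = {
    name: value.cpu() if isinstance(value, torch.Tensor) else value
    for name, value in trainer.callback_metrics.items()
}
```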


@@ -177,7 +177,6 @@ def test_memory_consumption_validation(tmpdir):
         devices=1,
         default_root_dir=tmpdir,
         fast_dev_run=2,
-        move_metrics_to_cpu=True,
         enable_model_summary=False,
     )
     trainer.fit(BoringLargeBatchModel())


@@ -677,12 +677,16 @@ def test_multiple_dataloaders_reset(val_check_interval, tmpdir):
 @pytest.mark.parametrize(
     "accelerator",
     [
-        pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)),
+        pytest.param("cuda", marks=RunIf(min_cuda_gpus=1)),
         pytest.param("mps", marks=RunIf(mps=True)),
     ],
 )
-def test_evaluation_move_metrics_to_cpu_and_outputs(tmpdir, accelerator):
+def test_metrics_and_outputs_device(tmpdir, accelerator):
     class TestModel(BoringModel):
+        def on_before_backward(self, loss: Tensor) -> None:
+            # the loss should be on the correct device before backward
+            assert loss.device.type == accelerator
+
         def validation_step(self, *args):
             x = torch.tensor(2.0, requires_grad=True, device=self.device)
             y = x * 2
@@ -695,13 +699,12 @@ def test_evaluation_move_metrics_to_cpu_and_outputs(tmpdir, accelerator):
         def validation_epoch_end(self, outputs):
             # the step outputs were not moved
             assert all(o.device == self.device for o in outputs)
-            # but the logging results were
-            assert self.trainer.callback_metrics["foo"].device.type == "cpu"
+            # and the logged metrics aren't
+            assert self.trainer.callback_metrics["foo"].device.type == accelerator
 
     model = TestModel()
-    trainer = Trainer(
-        default_root_dir=tmpdir, limit_val_batches=2, move_metrics_to_cpu=True, accelerator=accelerator, devices=1
-    )
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator, devices=1)
     trainer.fit(model)
     trainer.validate(model, verbose=False)
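The parametrization switch from `"gpu"` to `"cuda"` is needed because the new assertions compare against `torch.device.type`, which reports the backend name (`"cuda"`, `"mps"`, `"cpu"`) rather than Lightning's accelerator flag. A quick illustration, assuming a CUDA-capable machine:

```python
import torch

t = torch.ones(1, device="cuda")
assert t.device.type == "cuda"  # device.type is the backend name, never "gpu"
```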


@@ -718,23 +718,6 @@ def test_sanity_metrics_are_reset(tmpdir):
     assert "val_loss" not in trainer.progress_bar_metrics
 
-@RunIf(min_cuda_gpus=1)
-def test_move_metrics_to_cpu(tmpdir):
-    class TestModel(BoringModel):
-        def on_before_backward(self, loss: Tensor) -> None:
-            assert loss.device.type == "cuda"
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        fast_dev_run=True,
-        precision=16,
-        move_metrics_to_cpu=True,
-        accelerator="gpu",
-        devices=1,
-    )
-    trainer.fit(TestModel())
-
-
 def test_on_epoch_logging_with_sum_and_on_batch_start(tmpdir):
     class TestModel(BoringModel):
         def on_train_epoch_end(self):