Minor miscellaneous fixes (#18068)
This commit is contained in:
parent e38c71b828
commit 3a55f0c0a1

@@ -344,7 +344,7 @@ class MyCustomTrainer:
         Args:
             model: The LightningModule to train
             scheduler_cfg: The learning rate scheduler configuration.
-                Have a look at :meth:`lightning.pytorch.LightninModule.configure_optimizers` for supported values.
+                Have a look at :meth:`lightning.pytorch.LightningModule.configure_optimizers` for supported values.
             level: whether we are trying to step on epoch- or step-level
             current_value: Holds the current_epoch if ``level==epoch``, else holds the ``global_step``
         """

@@ -46,7 +46,7 @@ class DoublePrecision(Precision):
     def forward_context(self) -> Generator[None, None, None]:
         """A context manager to change the default tensor type.

-        See: :meth:`torch.set_default_tensor_type`
+        See: :meth:`torch.set_default_dtype`
         """
         default_dtype = torch.get_default_dtype()
         torch.set_default_dtype(torch.float64)

@@ -106,7 +106,7 @@ class FSDPPrecision(Precision):
     def init_context(self) -> Generator[None, None, None]:
         """A context manager to change the default tensor type when initializing module parameters or tensors.

-        See: :meth:`torch.set_default_tensor_type`
+        See: :meth:`torch.set_default_dtype`
         """
         default_dtype = torch.get_default_dtype()
         torch.set_default_dtype(self.mixed_precision_config.param_dtype)

@@ -43,7 +43,7 @@ class HalfPrecision(Precision):
     def init_context(self) -> Generator[None, None, None]:
         """A context manager to change the default tensor type when initializing module parameters or tensors.

-        See: :meth:`torch.set_default_tensor_type`
+        See: :meth:`torch.set_default_dtype`
         """
         default_dtype = torch.get_default_dtype()
         torch.set_default_dtype(self._desired_input_dtype)

@@ -55,7 +55,7 @@ class HalfPrecision(Precision):
         """A context manager to change the default tensor type when tensors get created during the module's
         forward.

-        See: :meth:`torch.set_default_tensor_type`
+        See: :meth:`torch.set_default_dtype`
         """
         default_dtype = torch.get_default_dtype()
         torch.set_default_dtype(self._desired_input_dtype)

@@ -345,7 +345,6 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
             raise NotImplementedError(
                 f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled."
             )
-        empty_init = empty_init and not self.zero_stage_3
         base_context = super().module_init_context(empty_init=empty_init) if not self.zero_stage_3 else nullcontext()
         with base_context, self.module_sharded_context():
             yield

@@ -46,9 +46,9 @@ warning_cache = WarningCache()
 class ModelCheckpoint(Checkpoint):
     r"""
     Save the model periodically by monitoring a quantity. Every metric logged with
-    :meth:`~lightning.pytorch.core.module.log` or :meth:`~lightning.pytorch.core.module.log_dict` in
-    LightningModule is a candidate for the monitor key. For more information, see
-    :ref:`checkpointing`.
+    :meth:`~lightning.pytorch.core.module.LightningModule.log` or
+    :meth:`~lightning.pytorch.core.module.LightningModule.log_dict` is a candidate for the monitor key.
+    For more information, see :ref:`checkpointing`.

     After training finishes, use :attr:`best_model_path` to retrieve the path to the
     best checkpoint file and :attr:`best_model_score` to retrieve its score.

@@ -46,7 +46,7 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
             If unsupported ``precision`` is provided.
     """

-    def __init__(self, precision: Literal["32-true", "16-mixed", "bf16-mixed"]) -> None:
+    def __init__(self, precision: _PRECISION_INPUT) -> None:
         supported_precision = get_args(_PRECISION_INPUT)
         if precision not in supported_precision:
             raise ValueError(
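
Note: the widened annotation keeps the runtime check and the type alias in sync. Below is a minimal standalone sketch of the `Literal` + `get_args` validation pattern used here; the alias members are taken from the removed annotation, and the helper name `_check_precision` is hypothetical.

from typing import Literal, get_args

_PRECISION_INPUT = Literal["32-true", "16-mixed", "bf16-mixed"]


def _check_precision(precision: str) -> None:
    # get_args() recovers the allowed strings from the Literal alias, so a new
    # precision value only needs to be added in one place.
    supported = get_args(_PRECISION_INPUT)
    if precision not in supported:
        raise ValueError(f"`precision={precision!r}` is not supported. It must be one of {supported}.")


_check_precision("16-mixed")  # passes silently; "64-true" would raise ValueError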

@@ -17,7 +17,7 @@ from typing import Any, cast, Generator, List, Literal, Tuple
 import torch
 import torch.nn as nn
 from lightning_utilities.core.apply_func import apply_to_collection
-from torch import FloatTensor, Tensor
+from torch import Tensor
 from torch.optim import Optimizer

 import lightning.pytorch as pl

@@ -91,8 +91,9 @@ class DoublePrecisionPlugin(PrecisionPlugin):
     def forward_context(self) -> Generator[None, None, None]:
         """A context manager to change the default tensor type.

-        See: :meth:`torch.set_default_tensor_type`
+        See: :meth:`torch.set_default_dtype`
         """
-        torch.set_default_tensor_type(torch.DoubleTensor)
+        default_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(torch.float64)
         yield
-        torch.set_default_tensor_type(FloatTensor)
+        torch.set_default_dtype(default_dtype)
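
Note: the precision hunks above converge on one pattern: replace the deprecated `torch.set_default_tensor_type` with `torch.set_default_dtype`, saving the previous default before the `yield` and restoring it after. A minimal standalone sketch follows; the helper name `_default_dtype` is hypothetical, and the try/finally is extra hardening not present in the plugins themselves.

from contextlib import contextmanager
from typing import Generator

import torch


@contextmanager
def _default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
    # Save the process-wide default, switch it, and restore it on exit.
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)


with _default_dtype(torch.float64):
    layer = torch.nn.Linear(2, 2)  # parameters are created as float64
assert layer.weight.dtype == torch.float64
assert torch.get_default_dtype() == torch.float32  # stock default, restored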

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Any, Generator, Literal, Optional
+from typing import Any, Generator, Literal, Optional, TYPE_CHECKING

 import torch


@@ -20,12 +20,9 @@ from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12
 from lightning.pytorch.plugins.precision.amp import MixedPrecisionPlugin
 from lightning.pytorch.utilities.exceptions import MisconfigurationException

-if _TORCH_GREATER_EQUAL_1_12 and torch.distributed.is_available():
-    from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
+if TYPE_CHECKING:
+    from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
     from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-else:
-    MixedPrecision = None  # type: ignore[misc,assignment]
-    ShardedGradScaler = None  # type: ignore[misc,assignment]


 class FSDPMixedPrecisionPlugin(MixedPrecisionPlugin):

@@ -35,10 +32,12 @@ class FSDPMixedPrecisionPlugin(MixedPrecisionPlugin):
     """

     def __init__(
-        self, precision: Literal["16-mixed", "bf16-mixed"], device: str, scaler: Optional[ShardedGradScaler] = None
+        self, precision: Literal["16-mixed", "bf16-mixed"], device: str, scaler: Optional["ShardedGradScaler"] = None
     ) -> None:
         if not _TORCH_GREATER_EQUAL_1_12:
             raise MisconfigurationException("`FSDPMixedPrecisionPlugin` is supported from PyTorch v1.12.0 onwards.")
+        from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
         super().__init__(
             precision, device, scaler=(ShardedGradScaler() if scaler is None and str(precision) == "16-mixed" else None)
         )
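
Note: the `TYPE_CHECKING` refactor in the two hunks above lets annotations reference FSDP types without importing `torch.distributed` at module-import time. A condensed sketch of the pattern follows; the names mirror the diff, but `_default_scaler` is a hypothetical, simplified stand-in for the plugin's `__init__` logic.

from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers; never executed at runtime, so merely
    # importing this module works even where torch.distributed is unavailable.
    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


def _default_scaler(precision: str, scaler: Optional["ShardedGradScaler"] = None) -> Optional["ShardedGradScaler"]:
    # The real import is deferred until this code path actually runs.
    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

    return ShardedGradScaler() if scaler is None and precision == "16-mixed" else scaler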

@@ -54,8 +53,8 @@ class FSDPMixedPrecisionPlugin(MixedPrecisionPlugin):
         )

     @property
-    def mixed_precision_config(self) -> Optional[MixedPrecision]:
-        assert MixedPrecision is not None
+    def mixed_precision_config(self) -> "TorchMixedPrecision":
+        from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision

         if self.precision == "16-mixed":
             param_dtype = torch.float32

@@ -70,7 +69,7 @@ class FSDPMixedPrecisionPlugin(MixedPrecisionPlugin):
         else:
             raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")

-        return MixedPrecision(
+        return TorchMixedPrecision(
             param_dtype=param_dtype,
             reduce_dtype=reduce_dtype,
             buffer_dtype=buffer_dtype,

@@ -325,6 +325,11 @@ class DeepSpeedStrategy(DDPStrategy):
         return config

     def setup_distributed(self) -> None:
+        if not isinstance(self.accelerator, CUDAAccelerator):
+            raise RuntimeError(
+                f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`"
+                " is used."
+            )
         assert self.parallel_devices is not None
         _validate_device_index_selection(self.parallel_devices)
         reset_seed()

@@ -438,11 +443,6 @@ class DeepSpeedStrategy(DDPStrategy):
         if self.lightning_module.trainer.gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
             raise MisconfigurationException("DeepSpeed does not support clipping gradients by value.")

-        if not isinstance(self.accelerator, CUDAAccelerator):
-            raise MisconfigurationException(
-                f"DeepSpeed strategy is only supported on GPU but `{self.accelerator.__class__.__name__}` is used."
-            )
-
         assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         if self.lightning_module.trainer and self.lightning_module.trainer.training:
             self._initialize_deepspeed_train(self.model)

@@ -78,7 +78,7 @@ def test_selected_dtype(precision, expected_dtype):
         ("16-true", torch.float16),
     ],
 )
-def test_module_init_context(precision, expected_dtype):
+def test_init_context(precision, expected_dtype):
     plugin = DeepSpeedPrecision(precision=precision)
     with plugin.init_context():
         model = torch.nn.Linear(2, 2)

@@ -37,7 +37,7 @@ def test_selected_dtype(precision, expected_dtype):
         ("16-true", torch.half),
     ],
 )
-def test_module_init_context(precision, expected_dtype):
+def test_init_context(precision, expected_dtype):
     plugin = HalfPrecision(precision=precision)
     with plugin.init_context():
         model = torch.nn.Linear(2, 2)

@@ -257,9 +257,7 @@ class HookedModel(BoringModel):
         return self._manual_train_batch(*args, **kwargs)

     @staticmethod
-    def _auto_train_batch(
-        trainer, model, batches, device=torch.device("cpu"), current_epoch=0, current_batch=0, **kwargs
-    ):
+    def _auto_train_batch(trainer, model, batches, device, current_epoch=0, current_batch=0, **kwargs):
         using_deepspeed = kwargs.get("strategy") == "deepspeed"
         out = []
         for i in range(current_batch, batches):

@@ -312,7 +310,7 @@ class HookedModel(BoringModel):
         return out

     @staticmethod
-    def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **kwargs):
+    def _manual_train_batch(trainer, model, batches, device, **kwargs):
         using_deepspeed = kwargs.get("strategy") == "deepspeed"
         out = []
         for i in range(batches):

@@ -343,7 +341,7 @@ class HookedModel(BoringModel):
         return out

     @staticmethod
-    def _eval_epoch(fn, trainer, model, batches, key, device=torch.device("cpu")):
+    def _eval_epoch(fn, trainer, model, batches, key, device):
         return [
             {"name": f"Callback.on_{fn}_epoch_start", "args": (trainer, model)},
             {"name": f"on_{fn}_epoch_start"},

@@ -353,7 +351,7 @@ class HookedModel(BoringModel):
         ]

     @staticmethod
-    def _eval_batch(fn, trainer, model, batches, key, device=torch.device("cpu")):
+    def _eval_batch(fn, trainer, model, batches, key, device):
         out = []
         outputs = {key: ANY}
         for i in range(batches):

@@ -373,13 +371,13 @@ class HookedModel(BoringModel):
         return out

     @staticmethod
-    def _predict_batch(trainer, model, batches):
+    def _predict_batch(trainer, model, batches, device):
         out = []
         for i in range(batches):
             out.extend(
                 [
                     {"name": "on_before_batch_transfer", "args": (ANY, 0)},
-                    {"name": "transfer_batch_to_device", "args": (ANY, torch.device("cpu"), 0)},
+                    {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
                     {"name": "on_after_batch_transfer", "args": (ANY, 0)},
                     {"name": "Callback.on_predict_batch_start", "args": (trainer, model, ANY, i)},
                     {"name": "on_predict_batch_start", "args": (ANY, i)},

@@ -451,7 +449,7 @@ def test_trainer_model_hook_system_fit(tmpdir, kwargs, automatic_optimization):
     using_deepspeed = kwargs.get("strategy") == "deepspeed"
     if kwargs.get("precision") == "16-mixed" and not using_deepspeed:
         saved_ckpt[trainer.precision_plugin.__class__.__qualname__] = ANY
-    device = torch.device("cuda:0" if "accelerator" in kwargs and kwargs["accelerator"] == "gpu" else "cpu")
+    device = trainer.strategy.root_device
     expected = [
         {"name": "configure_callbacks"},
         {"name": "prepare_data"},
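
Note: the hook-test hunks stop reconstructing the device from the launch kwargs and instead ask the strategy that actually placed the model. A minimal sketch of the idea, assuming a CPU machine (constructing a Trainer just to read the device; no fitting required):

import torch
from lightning.pytorch import Trainer

trainer = Trainer(accelerator="cpu", fast_dev_run=True)
# root_device is the single source of truth for where the strategy places the
# model, however the Trainer was parametrized.
assert trainer.strategy.root_device == torch.device("cpu")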

@@ -570,7 +568,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir):
         {"name": "on_train_start"},
         {"name": "Callback.on_train_epoch_start", "args": (trainer, model)},
         {"name": "on_train_epoch_start"},
-        *model._train_batch(trainer, model, 2, current_epoch=1, current_batch=0),
+        *model._train_batch(trainer, model, 2, trainer.strategy.root_device, current_epoch=1, current_batch=0),
         {"name": "Callback.on_train_epoch_end", "args": (trainer, model)},
         {"name": "on_train_epoch_end"},  # before ModelCheckpoint because it's a "monitoring callback"
         # `ModelCheckpoint.save_checkpoint` is called here

@@ -648,7 +646,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmpdir):
         {"name": "on_train_start"},
         {"name": "Callback.on_train_epoch_start", "args": (trainer, model)},
         {"name": "on_train_epoch_start"},
-        *model._train_batch(trainer, model, steps_after_reload, current_batch=1),
+        *model._train_batch(trainer, model, steps_after_reload, trainer.strategy.root_device, current_batch=1),
         {"name": "Callback.on_train_epoch_end", "args": (trainer, model)},
         {"name": "on_train_epoch_end"},  # before ModelCheckpoint because it's a "monitoring callback"
         # `ModelCheckpoint.save_checkpoint` is called here

@@ -691,7 +689,7 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader,
         {"name": "zero_grad"},
         {"name": f"Callback.on_{noun}_start", "args": (trainer, model)},
         {"name": f"on_{noun}_start"},
-        *model._eval_epoch(noun, trainer, model, batches, key),
+        *model._eval_epoch(noun, trainer, model, batches, key, trainer.strategy.root_device),
         {"name": f"Callback.on_{noun}_end", "args": (trainer, model)},
         {"name": f"on_{noun}_end"},
         {"name": "train", "args": (True,)},

@@ -733,7 +731,7 @@ def test_trainer_model_hook_system_predict(tmpdir):
         {"name": "on_predict_start"},
         {"name": "Callback.on_predict_epoch_start", "args": (trainer, model)},
         {"name": "on_predict_epoch_start"},
-        *model._predict_batch(trainer, model, batches),
+        *model._predict_batch(trainer, model, batches, trainer.strategy.root_device),
         {"name": "Callback.on_predict_epoch_end", "args": (trainer, model)},
         {"name": "on_predict_epoch_end"},
         {"name": "Callback.on_predict_end", "args": (trainer, model)},

@@ -43,6 +43,6 @@ def test_strategy_lightning_restore_optimizer_and_schedulers(tmpdir, restore_opt

     model = BoringModel()
     strategy = TestStrategy(torch.device("cpu"))
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy=strategy)
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy=strategy, accelerator="cpu")
     trainer.fit(model, ckpt_path=checkpoint_path)
     assert strategy.load_optimizer_state_dict_called == restore_optimizer_and_schedulers

@@ -109,7 +109,10 @@ def test_deepspeed_strategy_string(tmpdir, strategy):
     set."""

     trainer = Trainer(
-        fast_dev_run=True, default_root_dir=tmpdir, strategy=strategy if isinstance(strategy, str) else strategy()
+        accelerator="cpu",
+        fast_dev_run=True,
+        default_root_dir=tmpdir,
+        strategy=strategy if isinstance(strategy, str) else strategy(),
     )

     assert isinstance(trainer.strategy, DeepSpeedStrategy)

@@ -124,7 +127,7 @@ def test_deepspeed_strategy_env(tmpdir, monkeypatch, deepspeed_config):
         f.write(json.dumps(deepspeed_config))
     monkeypatch.setenv("PL_DEEPSPEED_CONFIG_PATH", config_path)

-    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, strategy="deepspeed")
+    trainer = Trainer(accelerator="cpu", fast_dev_run=True, default_root_dir=tmpdir, strategy="deepspeed")

     strategy = trainer.strategy
     assert isinstance(strategy, DeepSpeedStrategy)

@@ -1225,7 +1228,7 @@ def test_error_with_invalid_accelerator(tmpdir):
         fast_dev_run=True,
     )
     model = BoringModel()
-    with pytest.raises(MisconfigurationException, match="DeepSpeed strategy is only supported on GPU"):
+    with pytest.raises(RuntimeError, match="DeepSpeed strategy is only supported on CUDA"):
         trainer.fit(model)
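
Note: the test now expects the `RuntimeError` raised by the new `setup_distributed` check rather than the old `MisconfigurationException`. For reference, `pytest.raises(..., match=...)` treats `match` as a regular expression searched against the stringified exception; a self-contained sketch with a hypothetical stand-in for `trainer.fit(model)` on a CPU-only machine:

import pytest


def _fail_like_setup_distributed() -> None:
    # Hypothetical stand-in for trainer.fit(model) without a CUDA accelerator.
    raise RuntimeError("The DeepSpeed strategy is only supported on CUDA GPUs but `CPUAccelerator` is used.")


def test_match_is_a_regex_search() -> None:
    # The pattern only needs to occur somewhere in the exception message.
    with pytest.raises(RuntimeError, match="DeepSpeed strategy is only supported on CUDA"):
        _fail_like_setup_distributed()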

@@ -203,7 +203,7 @@ def test_invalid_on_cpu(tmpdir):
         MisconfigurationException,
         match=f"You selected strategy to be `{FSDPStrategy.strategy_name}`, but GPU accelerator is not used.",
     ):
-        trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp")
+        trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp")
         assert isinstance(trainer.strategy, FSDPStrategy)
         trainer.strategy.setup_environment()

@@ -28,7 +28,7 @@ from tests_pytorch.helpers.runif import RunIf

 def test_single_cpu():
     """Tests if device is set correctly for single CPU strategy."""
-    trainer = Trainer()
+    trainer = Trainer(accelerator="cpu")
     assert isinstance(trainer.strategy, SingleDeviceStrategy)
     assert trainer.strategy.root_device == torch.device("cpu")
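
Note: the test updates throughout this commit pin `accelerator="cpu"` because `Trainer()` defaults to `accelerator="auto"`, which selects a GPU on CUDA machines and would change `root_device`. A minimal sketch mirroring the updated test above:

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import SingleDeviceStrategy

# Pinning the accelerator keeps the assertions deterministic across machines.
trainer = Trainer(accelerator="cpu")
assert isinstance(trainer.strategy, SingleDeviceStrategy)
assert trainer.strategy.root_device == torch.device("cpu")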