diff --git a/CHANGELOG.md b/CHANGELOG.md
index 35404e85a5..f0b3d5f125 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -69,6 +69,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `DataLoaderIterDataFetcher` ([#9020](https://github.com/PyTorchLightning/pytorch-lightning/pull/9020))
 
 
+- Added bfloat16 support for Lightning Trainer ([#9049](https://github.com/PyTorchLightning/pytorch-lightning/pull/9049))
+
+
 - Added `DataFetcher` within `Fit / Evaluation` Loop ([#9047](https://github.com/PyTorchLightning/pytorch-lightning/pull/9047))
 
 
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index e67756e9cf..ae9f261085 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -12,32 +12,54 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Generator
+from typing import Any, Callable, Dict, Generator, Union
 
 import torch
 from torch.optim import LBFGS, Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
-from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, AMPType
+from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, _TORCH_GREATER_EQUAL_1_10, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
 class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
-    """Plugin for native mixed precision training with :mod:`torch.cuda.amp`."""
+    """
+    Plugin for native mixed precision training with :mod:`torch.cuda.amp`.
 
-    def __init__(self) -> None:
+    Args:
+        precision: Whether to use torch.float16 (16) or torch.bfloat16 (bf16).
+    """
+
+    def __init__(self, precision: Union[int, str] = 16) -> None:
         super().__init__()
+
         if not _NATIVE_AMP_AVAILABLE:
             raise MisconfigurationException(
                 "You have asked for native AMP but your PyTorch version does not support it."
                 " Consider upgrading with `pip install torch>=1.6`."
             )
-
+        self._fast_dtype = self._select_precision_dtype(precision)
         self.backend = AMPType.NATIVE
-        self.scaler = torch.cuda.amp.GradScaler()
+        if not self.is_bfloat16:
+            self.scaler = torch.cuda.amp.GradScaler()
+
+    def _select_precision_dtype(self, precision: Union[int, str] = 16) -> torch.dtype:
+        if precision == "bf16":
+            if not _TORCH_GREATER_EQUAL_1_10:
+                raise MisconfigurationException(
+                    "To use bfloat16 with native amp you must install torch greater or equal to 1.10."
+                )
+            return torch.bfloat16
+        return torch.float16
+
+    @property
+    def is_bfloat16(self) -> bool:
+        return self._fast_dtype == torch.bfloat16
 
     def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor) -> torch.Tensor:
+        if self.is_bfloat16:
+            return super().pre_backward(model, closure_loss)
         closure_loss = self.scaler.scale(closure_loss)
         return super().pre_backward(model, closure_loss)
 
@@ -49,6 +71,9 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
         lambda_closure: Callable,
         **kwargs: Any,
     ) -> bool:
+        if self.is_bfloat16:
+            # skip scaler logic, as bfloat16 does not require scaler
+            return super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
         if isinstance(optimizer, LBFGS):
             raise MisconfigurationException(
                 f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
@@ -65,33 +90,39 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
         self.scaler.update()
         return False
 
+    def autocast_context_manager(self) -> torch.cuda.amp.autocast:
+        if self.is_bfloat16:
+            return torch.cuda.amp.autocast(fast_dtype=self._fast_dtype)
+        return torch.cuda.amp.autocast()
+
     @contextmanager
     def train_step_context(self) -> Generator[None, None, None]:
         """Enable autocast context"""
-        with torch.cuda.amp.autocast():
+        with self.autocast_context_manager():
             yield
 
     @contextmanager
     def val_step_context(self) -> Generator[None, None, None]:
         """Enable autocast context"""
-        with torch.cuda.amp.autocast():
+        with self.autocast_context_manager():
             yield
 
     @contextmanager
     def test_step_context(self) -> Generator[None, None, None]:
         """Enable autocast context"""
-        with torch.cuda.amp.autocast():
+        with self.autocast_context_manager():
             yield
 
     @contextmanager
     def predict_step_context(self) -> Generator[None, None, None]:
         """Enable autocast context"""
-        with torch.cuda.amp.autocast():
+        with self.autocast_context_manager():
             yield
 
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        if "native_amp_scaling_state" in checkpoint:
+        if "native_amp_scaling_state" in checkpoint and not self.is_bfloat16:
             self.scaler.load_state_dict(checkpoint["native_amp_scaling_state"])
 
     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        checkpoint["native_amp_scaling_state"] = self.scaler.state_dict()
+        if not self.is_bfloat16:
+            checkpoint["native_amp_scaling_state"] = self.scaler.state_dict()
diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py
index e338b0af09..861e5e1363 100644
--- a/pytorch_lightning/plugins/precision/sharded_native_amp.py
+++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py
@@ -24,8 +24,8 @@ if _NATIVE_AMP_AVAILABLE and _FAIRSCALE_AVAILABLE:
 class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
     """Mixed Precision for Sharded Training"""
 
-    def __init__(self) -> None:
-        super().__init__()
+    def __init__(self, precision: Union[int, str] = 16) -> None:
+        super().__init__(precision)
         self.scaler = ShardedGradScaler()
 
     def clip_grad_by_norm(
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index ec3b56489e..b0aabddf2d 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -560,7 +560,7 @@ class AcceleratorConnector:
             return PrecisionPlugin()
         if self.precision == 64:
             return DoublePrecisionPlugin()
-        if self.precision == 16:
+        if self.precision in (16, "bf16"):
            if self.use_tpu:
                 return TPUHalfPrecisionPlugin()
 
@@ -581,12 +581,12 @@ class AcceleratorConnector:
                     else:
                         raise MisconfigurationException(msg)
                 else:
-                    log.info("Using native 16bit precision.")
+                    log.info(f"Using native {self.precision} bit Automatic Mixed Precision")
                     if self._is_sharded_training_type:
-                        return ShardedNativeMixedPrecisionPlugin()
+                        return ShardedNativeMixedPrecisionPlugin(self.precision)
                     if self._is_fully_sharded_training_type:
-                        return FullyShardedNativeMixedPrecisionPlugin()
-                    return NativeMixedPrecisionPlugin()
+                        return FullyShardedNativeMixedPrecisionPlugin(self.precision)
+                    return NativeMixedPrecisionPlugin(self.precision)
 
         if self.amp_type == AMPType.APEX:
             if not _APEX_AVAILABLE:
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 739e34aa24..6b39bb5159 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -138,7 +138,7 @@ class Trainer(
         log_every_n_steps: int = 50,
         accelerator: Optional[Union[str, Accelerator]] = None,
         sync_batchnorm: bool = False,
-        precision: int = 32,
+        precision: Union[int, str] = 32,
         weights_summary: Optional[str] = "top",
         weights_save_path: Optional[str] = None,
         num_sanity_val_steps: int = 2,
@@ -260,8 +260,8 @@ class Trainer(
             plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
 
-            precision: Double precision (64), full precision (32) or half precision (16). Can be used on CPU, GPU or
-                TPUs.
+            precision: Double precision (64), full precision (32), half precision (16) or bfloat16 precision (bf16).
+                Can be used on CPU, GPU or TPUs.
 
             max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
                 If both max_epochs and max_steps are not specified, defaults to ``max_epochs`` = 1000.
 
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index 7d163d3a73..747c0be617 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -47,6 +47,7 @@ from pytorch_lightning.utilities.imports import (  # noqa: F401
     _TORCH_GREATER_EQUAL_1_7,
     _TORCH_GREATER_EQUAL_1_8,
     _TORCH_GREATER_EQUAL_1_9,
+    _TORCH_GREATER_EQUAL_1_10,
     _TORCH_QUANTIZE_AVAILABLE,
     _TORCH_SHARDED_TENSOR_AVAILABLE,
     _TORCHTEXT_AVAILABLE,
diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py
index bed2461395..b9757715a3 100644
--- a/pytorch_lightning/utilities/argparse.py
+++ b/pytorch_lightning/utilities/argparse.py
@@ -253,6 +253,10 @@ def add_argparse_args(
         if arg == "track_grad_norm":
             use_type = float
 
+        # hack for precision
+        if arg == "precision":
+            use_type = _precision_allowed_type
+
         parser.add_argument(
             f"--{arg}", dest=arg, default=arg_default, type=use_type, help=args_help.get(arg), **arg_kwargs
         )
@@ -302,3 +306,16 @@ def _int_or_float_type(x: Union[int, float, str]) -> Union[int, float]:
     if "." in str(x):
         return float(x)
     return int(x)
+
+
+def _precision_allowed_type(x: Union[int, str]) -> Union[int, str]:
+    """
+    >>> _precision_allowed_type("32")
+    32
+    >>> _precision_allowed_type("bf16")
+    'bf16'
+    """
+    try:
+        return int(x)
+    except ValueError:
+        return x
diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py
index 426359f597..fa6598f884 100644
--- a/pytorch_lightning/utilities/imports.py
+++ b/pytorch_lightning/utilities/imports.py
@@ -68,6 +68,8 @@ _TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
 _TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0")
 _TORCH_GREATER_EQUAL_1_8_1 = _compare_version("torch", operator.ge, "1.8.1")
 _TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0")
+_TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0")
+
 
 _APEX_AVAILABLE = _module_available("apex.amp")
 _BOLTS_AVAILABLE = _module_available("pl_bolts")
diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index ea02403f6a..79c0cf7c12 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -22,6 +22,7 @@ from torch.utils.data import DataLoader
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins.environments import SLURMEnvironment
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
@@ -31,7 +32,8 @@ class AMPTestModel(BoringModel):
     def _step(self, batch, batch_idx):
         assert torch.is_autocast_enabled()
         output = self(batch)
-        assert output.dtype == torch.float16
+        bfloat16 = self.trainer.precision_plugin.is_bfloat16
+        assert output.dtype == (torch.bfloat16 if bfloat16 else torch.float16)
         loss = self.loss(batch, output)
         return loss
 
@@ -50,67 +52,42 @@ class AMPTestModel(BoringModel):
     def predict(self, batch, batch_idx, dataloader_idx=None):
         assert torch.is_autocast_enabled()
         output = self(batch)
-        assert output.dtype == torch.float16
+        bfloat16 = self.trainer.precision_plugin.is_bfloat16
+        assert output.dtype == (torch.bfloat16 if bfloat16 else torch.float16)
         return output
 
 
-@pytest.mark.skip(reason="dp + amp not supported currently")  # TODO
-@RunIf(min_gpus=1)
-def test_amp_single_gpu_dp(tmpdir):
-    """Make sure DP/DDP + AMP work."""
-    tutils.reset_seed()
-
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=1, accelerator="dp", precision=16)
-
-    model = AMPTestModel()
-    # tutils.run_model_test(trainer_options, model)
-    trainer.fit(model)
-    trainer.test(model)
-    trainer.predict(model, DataLoader(RandomDataset(32, 64)))
-
-    assert trainer.state.finished, f"Training failed with {trainer.state}"
-
-
-@RunIf(min_gpus=1)
-def test_amp_single_gpu_ddp_spawn(tmpdir):
-    """Make sure DP/DDP + AMP work."""
-    tutils.reset_seed()
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=1, accelerator="ddp_spawn", precision=16)
-
-    model = AMPTestModel()
-    # tutils.run_model_test(trainer_options, model)
-    trainer.fit(model)
-    trainer.test(model)
-    trainer.predict(model, DataLoader(RandomDataset(32, 64)))
-    assert trainer.state.finished, f"Training failed with {trainer.state}"
-
-
-@pytest.mark.skip(reason="dp + amp not supported currently")  # TODO
-@RunIf(min_gpus=1)
-def test_amp_multi_gpu_dp(tmpdir):
-    """Make sure DP/DDP + AMP work."""
-    tutils.reset_seed()
-
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=2, accelerator="dp", precision=16)
-
-    model = AMPTestModel()
-    # tutils.run_model_test(trainer_options, model)
-    trainer.fit(model)
-
-    assert trainer.state.finished, f"Training failed with {trainer.state}"
-
-
 @RunIf(min_gpus=2)
-def test_amp_multi_gpu_ddp_spawn(tmpdir):
-    """Make sure DP/DDP + AMP work."""
+@pytest.mark.parametrize(
+    "accelerator",
+    [
+        pytest.param("dp", marks=pytest.mark.skip("dp + amp not supported currently")),  # TODO
+        "ddp_spawn",
+    ],
+)
+@pytest.mark.parametrize(
+    "precision",
+    [
+        16,
+        pytest.param(
+            "bf16",
+            marks=pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_10, reason="torch.bfloat16 not available"),
+        ),
+    ],
+)
+@pytest.mark.parametrize("gpus", [1, 2])
+def test_amp_gpus(tmpdir, accelerator, precision, gpus):
+    """Make sure combinations of AMP and training types work if supported."""
     tutils.reset_seed()
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=2, accelerator="ddp_spawn", precision=16)
+
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=gpus, accelerator=accelerator, precision=precision)
 
     model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
     trainer.test(model)
     trainer.predict(model, DataLoader(RandomDataset(32, 64)))
+
     assert trainer.state.finished, f"Training failed with {trainer.state}"
 
 
diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py
index d5862635c7..15ec43973b 100644
--- a/tests/plugins/test_amp_plugins.py
+++ b/tests/plugins/test_amp_plugins.py
@@ -21,6 +21,7 @@ import torch
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
@@ -69,6 +70,8 @@ def test_amp_apex_ddp(
         plugins=[plugin_cls()] if custom_plugin else None,
     )
     assert isinstance(trainer.precision_plugin, plugin_cls)
+    if amp == "native":
+        assert not trainer.precision_plugin.is_bfloat16
 
 
 class GradientUnscaleBoringModel(BoringModel):
@@ -174,3 +177,16 @@ def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):
     assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
     model = BoringModel()
     trainer.fit(model)
+
+
+@RunIf(min_gpus=1, amp_native=True, max_torch="1.9")
+def test_amp_precision_16_bfloat_throws_error(tmpdir):
+    with pytest.raises(
+        MisconfigurationException,
+        match="To use bfloat16 with native amp you must install torch greater or equal to 1.10",
+    ):
+        Trainer(
+            default_root_dir=tmpdir,
+            precision="bf16",
+            gpus=1,
+        )
diff --git a/tests/utilities/test_argparse.py b/tests/utilities/test_argparse.py
index 979cb4f835..8672795ea2 100644
--- a/tests/utilities/test_argparse.py
+++ b/tests/utilities/test_argparse.py
@@ -11,6 +11,7 @@ from pytorch_lightning.utilities.argparse import (
     _gpus_allowed_type,
     _int_or_float_type,
     _parse_args_from_docstring,
+    _precision_allowed_type,
     add_argparse_args,
     from_argparse_args,
     parse_argparser,
@@ -215,3 +216,20 @@ def test_gpus_allowed_type():
 def test_int_or_float_type():
     assert isinstance(_int_or_float_type("0.0"), float)
     assert isinstance(_int_or_float_type("0"), int)
+
+
+@pytest.mark.parametrize(["arg", "expected"], [["--precision=16", 16], ["--precision=bf16", "bf16"]])
+def test_precision_parsed_correctly(arg, expected):
+    """
+    Test to ensure that the precision flag is passed correctly when adding argparse args.
+    """
+    parser = ArgumentParser()
+    parser = Trainer.add_argparse_args(parser)
+    fake_argv = [arg]
+    args = parser.parse_args(fake_argv)
+    assert args.precision == expected
+
+
+def test_precision_type():
+    assert _precision_allowed_type("bf16") == "bf16"
+    assert _precision_allowed_type("16") == 16
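For reference, a minimal usage sketch of the flag this patch introduces, not part of the diff itself: it assumes a CUDA GPU and torch >= 1.10, and the `LitModel` / `RandomDataset` classes here are placeholders defined only for illustration. Passing `precision="bf16"` selects `NativeMixedPrecisionPlugin` with `torch.bfloat16` and skips the `GradScaler`, mirroring the tests above.

```python
import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomDataset(Dataset):
    """Placeholder dataset: `length` samples of `size` random features."""

    def __init__(self, size: int = 32, length: int = 64):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class LitModel(pl.LightningModule):
    """Placeholder module: a single linear layer trained on a dummy loss."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        # with precision="bf16", autocast is active here and the linear layer
        # produces torch.bfloat16 outputs, as asserted by AMPTestModel above
        return self(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    # bf16 runs autocast without a GradScaler, so no scaler state is saved to checkpoints
    trainer = pl.Trainer(max_epochs=1, gpus=1, precision="bf16")
    trainer.fit(LitModel(), DataLoader(RandomDataset(), batch_size=8))
```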