from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class TPUAccelerator(Accelerator):
    """Accelerator that validates its plugins before running the base setup.

    ``setup`` rejects a ``MixedPrecisionPlugin`` precision plugin and any
    training-type plugin that is neither ``SingleTPUPlugin`` nor
    ``TPUSpawnPlugin``; otherwise it defers to ``Accelerator.setup``.
    """

    def setup(self, trainer, model):
        """Check the configured plugins, then delegate to the parent ``setup``.

        Args:
            trainer: Forwarded unchanged to ``Accelerator.setup``.
            model: Forwarded unchanged to ``Accelerator.setup``.

        Returns:
            Whatever ``Accelerator.setup(trainer, model)`` returns.

        Raises:
            MisconfigurationException: If the precision plugin is a
                ``MixedPrecisionPlugin``, or if the training-type plugin is
                not a ``SingleTPUPlugin`` / ``TPUSpawnPlugin``.
        """
        # The precision check comes first so that it wins when both the
        # precision plugin and the training-type plugin are invalid.
        if isinstance(self.precision_plugin, MixedPrecisionPlugin):
            amp_message = (
                "amp + tpu is not supported. "
                "Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin"
            )
            raise MisconfigurationException(amp_message)

        allowed_training_plugins = (SingleTPUPlugin, TPUSpawnPlugin)
        if isinstance(self.training_type_plugin, allowed_training_plugins):
            return super().setup(trainer, model)

        raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")