diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index a7a8229762..66e92ae006 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -242,7 +242,7 @@ class Accelerator: args[0] = batch - with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context(): + with self.precision_plugin.predict_step_context(), self.training_type_plugin.predict_step_context(): return self.training_type_plugin.predict_step(*args) def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT: diff --git a/pytorch_lightning/plugins/base_plugin.py b/pytorch_lightning/plugins/base_plugin.py index 64f741a9a2..ed45d0bc68 100644 --- a/pytorch_lightning/plugins/base_plugin.py +++ b/pytorch_lightning/plugins/base_plugin.py @@ -41,6 +41,6 @@ class Plugin(ABC): yield @contextlib.contextmanager - def predict_context(self) -> Generator: + def predict_step_context(self) -> Generator: """A contextmanager for the predict step""" yield diff --git a/pytorch_lightning/plugins/precision/double.py b/pytorch_lightning/plugins/precision/double.py index 268b7480d4..388d249c73 100644 --- a/pytorch_lightning/plugins/precision/double.py +++ b/pytorch_lightning/plugins/precision/double.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib from functools import wraps -from typing import Any, List, Tuple +from typing import Any, Generator, List, Tuple import torch import torch.nn as nn @@ -90,3 +91,18 @@ class DoublePrecisionPlugin(PrecisionPlugin): def post_dispatch(self) -> None: while len(self.patches) > 0: self.patches.pop().teardown() + + @contextlib.contextmanager + def tensor_type_context(self) -> Generator: + """ + A context manager to change the default tensor type. + See: :meth:`torch.set_default_tensor_type` + """ + torch.set_default_tensor_type(torch.DoubleTensor) + yield + torch.set_default_tensor_type(torch.FloatTensor) + + train_step_context = tensor_type_context + val_step_context = tensor_type_context + test_step_context = tensor_type_context + predict_step_context = tensor_type_context diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 82c55191a2..994b7f2613 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -115,7 +115,7 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin): yield @contextmanager - def predict_context(self) -> Generator[None, None, None]: + def predict_step_context(self) -> Generator[None, None, None]: """Enable autocast context""" with torch.cuda.amp.autocast(): yield diff --git a/tests/plugins/test_double_plugin.py b/tests/plugins/test_double_plugin.py index 175ca5ecab..96ff2d182b 100644 --- a/tests/plugins/test_double_plugin.py +++ b/tests/plugins/test_double_plugin.py @@ -37,25 +37,37 @@ class DoublePrecisionBoringModel(BoringModel): def training_step(self, batch, batch_idx): float_data, int_data = batch + assert torch.tensor([0.]).dtype == torch.float64 + assert torch.tensor([0.], dtype=torch.float16).dtype == torch.float16 assert float_data.dtype == torch.float64 output = self(float_data) loss = self.loss(batch, output) return {"loss": loss} + def training_epoch_end(self, outputs) -> None: + assert torch.tensor([0.]).dtype == torch.float32 + return super().training_epoch_end(outputs) + def validation_step(self, batch, batch_idx): assert batch.dtype == torch.float64 + assert torch.tensor([0.]).dtype == torch.float64 + assert torch.tensor([0.], dtype=torch.float16).dtype == torch.float16 output = self(batch) loss = self.loss(batch, output) return {"x": loss} def test_step(self, batch, batch_idx): assert batch.dtype == torch.float64 + assert torch.tensor([0.]).dtype == torch.float64 + assert torch.tensor([0.], dtype=torch.float16).dtype == torch.float16 output = self(batch) loss = self.loss(batch, output) return {"y": loss} def predict_step(self, batch, batch_idx, dataloader_idx=None): assert batch.dtype == torch.float64 + assert torch.tensor([0.]).dtype == torch.float64 + assert torch.tensor([0.], dtype=torch.float16).dtype == torch.float16 return self(batch) def on_fit_start(self):