diff --git a/pytorch_lightning/plugins/precision/double.py b/pytorch_lightning/plugins/precision/double.py index 388d249c73..6d985a0f4e 100644 --- a/pytorch_lightning/plugins/precision/double.py +++ b/pytorch_lightning/plugins/precision/double.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import contextlib +from contextlib import contextmanager from functools import wraps from typing import Any, Generator, List, Tuple @@ -92,8 +92,8 @@ class DoublePrecisionPlugin(PrecisionPlugin): while len(self.patches) > 0: self.patches.pop().teardown() - @contextlib.contextmanager - def tensor_type_context(self) -> Generator: + @contextmanager + def train_step_context(self) -> Generator[None, None, None]: """ A context manager to change the default tensor type. See: :meth:`torch.set_default_tensor_type` @@ -102,7 +102,32 @@ class DoublePrecisionPlugin(PrecisionPlugin): yield torch.set_default_tensor_type(torch.FloatTensor) - train_step_context = tensor_type_context - val_step_context = tensor_type_context - test_step_context = tensor_type_context - predict_step_context = tensor_type_context + @contextmanager + def val_step_context(self) -> Generator[None, None, None]: + """ + A context manager to change the default tensor type. + See: :meth:`torch.set_default_tensor_type` + """ + torch.set_default_tensor_type(torch.DoubleTensor) + yield + torch.set_default_tensor_type(torch.FloatTensor) + + @contextmanager + def test_step_context(self) -> Generator[None, None, None]: + """ + A context manager to change the default tensor type. + See: :meth:`torch.set_default_tensor_type` + """ + torch.set_default_tensor_type(torch.DoubleTensor) + yield + torch.set_default_tensor_type(torch.FloatTensor) + + @contextmanager + def predict_step_context(self) -> Generator[None, None, None]: + """ + A context manager to change the default tensor type. + See: :meth:`torch.set_default_tensor_type` + """ + torch.set_default_tensor_type(torch.DoubleTensor) + yield + torch.set_default_tensor_type(torch.FloatTensor)