Fix mypy checks for double precision plugin (#7151)

This commit is contained in:
ananthsub 2021-04-22 03:29:38 -07:00 committed by GitHub
parent d6470bf193
commit 3f1a08ab00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 32 additions and 7 deletions

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib from contextlib import contextmanager
from functools import wraps from functools import wraps
from typing import Any, Generator, List, Tuple from typing import Any, Generator, List, Tuple
@ -92,8 +92,8 @@ class DoublePrecisionPlugin(PrecisionPlugin):
while len(self.patches) > 0: while len(self.patches) > 0:
self.patches.pop().teardown() self.patches.pop().teardown()
@contextlib.contextmanager @contextmanager
def tensor_type_context(self) -> Generator: def train_step_context(self) -> Generator[None, None, None]:
""" """
A context manager to change the default tensor type. A context manager to change the default tensor type.
See: :meth:`torch.set_default_tensor_type` See: :meth:`torch.set_default_tensor_type`
@ -102,7 +102,32 @@ class DoublePrecisionPlugin(PrecisionPlugin):
yield yield
torch.set_default_tensor_type(torch.FloatTensor) torch.set_default_tensor_type(torch.FloatTensor)
train_step_context = tensor_type_context @contextmanager
val_step_context = tensor_type_context def val_step_context(self) -> Generator[None, None, None]:
test_step_context = tensor_type_context """
predict_step_context = tensor_type_context A context manager to change the default tensor type.
See: :meth:`torch.set_default_tensor_type`
"""
torch.set_default_tensor_type(torch.DoubleTensor)
yield
torch.set_default_tensor_type(torch.FloatTensor)
@contextmanager
def test_step_context(self) -> Generator[None, None, None]:
"""
A context manager to change the default tensor type.
See: :meth:`torch.set_default_tensor_type`
"""
torch.set_default_tensor_type(torch.DoubleTensor)
yield
torch.set_default_tensor_type(torch.FloatTensor)
@contextmanager
def predict_step_context(self) -> Generator[None, None, None]:
"""
A context manager to change the default tensor type.
See: :meth:`torch.set_default_tensor_type`
"""
torch.set_default_tensor_type(torch.DoubleTensor)
yield
torch.set_default_tensor_type(torch.FloatTensor)