Fix mypy checks for double precision plugin (#7151)

This commit is contained in:
ananthsub 2021-04-22 03:29:38 -07:00 committed by GitHub
parent d6470bf193
commit 3f1a08ab00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 32 additions and 7 deletions

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from contextlib import contextmanager
from functools import wraps
from typing import Any, Generator, List, Tuple
@ -92,8 +92,8 @@ class DoublePrecisionPlugin(PrecisionPlugin):
while len(self.patches) > 0:
self.patches.pop().teardown()
@contextlib.contextmanager
def tensor_type_context(self) -> Generator:
@contextmanager
def train_step_context(self) -> Generator[None, None, None]:
"""
A context manager to change the default tensor type.
See: :meth:`torch.set_default_tensor_type`
@ -102,7 +102,32 @@ class DoublePrecisionPlugin(PrecisionPlugin):
yield
torch.set_default_tensor_type(torch.FloatTensor)
train_step_context = tensor_type_context
val_step_context = tensor_type_context
test_step_context = tensor_type_context
predict_step_context = tensor_type_context
@contextmanager
def val_step_context(self) -> Generator[None, None, None]:
"""
A context manager to change the default tensor type.
See: :meth:`torch.set_default_tensor_type`
"""
torch.set_default_tensor_type(torch.DoubleTensor)
yield
torch.set_default_tensor_type(torch.FloatTensor)
@contextmanager
def test_step_context(self) -> Generator[None, None, None]:
"""
A context manager to change the default tensor type.
See: :meth:`torch.set_default_tensor_type`
"""
torch.set_default_tensor_type(torch.DoubleTensor)
yield
torch.set_default_tensor_type(torch.FloatTensor)
@contextmanager
def predict_step_context(self) -> Generator[None, None, None]:
"""
A context manager to change the default tensor type.
See: :meth:`torch.set_default_tensor_type`
"""
torch.set_default_tensor_type(torch.DoubleTensor)
yield
torch.set_default_tensor_type(torch.FloatTensor)