Update accelerator.py (#7318)

This commit is contained in:
ananthsub 2021-05-03 08:17:26 -07:00 committed by GitHub
parent badd0bba30
commit 39274273a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -13,7 +13,7 @@
# limitations under the License.
import contextlib
from collections import defaultdict
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
from typing import Any, Callable, DefaultDict, Dict, Generator, Iterable, List, Optional, Union
import torch
from torch import Tensor
@ -114,7 +114,7 @@ class Accelerator:
def _move_optimizer_state(self) -> None:
""" Moves the state of the optimizers to the GPU if needed. """
for opt in self.optimizers:
state = defaultdict(dict)
state: DefaultDict = defaultdict(dict)
for p, v in opt.state.items():
state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
opt.state = state