Update accelerator.py (#7318)

This commit is contained in:
ananthsub 2021-05-03 08:17:26 -07:00 committed by GitHub
parent badd0bba30
commit 39274273a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import contextlib import contextlib
from collections import defaultdict from collections import defaultdict
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union from typing import Any, Callable, DefaultDict, Dict, Generator, Iterable, List, Optional, Union
import torch import torch
from torch import Tensor from torch import Tensor
@ -114,7 +114,7 @@ class Accelerator:
def _move_optimizer_state(self) -> None: def _move_optimizer_state(self) -> None:
""" Moves the state of the optimizers to the GPU if needed. """ """ Moves the state of the optimizers to the GPU if needed. """
for opt in self.optimizers: for opt in self.optimizers:
state = defaultdict(dict) state: DefaultDict = defaultdict(dict)
for p, v in opt.state.items(): for p, v in opt.state.items():
state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device) state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
opt.state = state opt.state = state