Update accelerator.py (#7318)
This commit is contained in:
parent
badd0bba30
commit
39274273a4
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
import contextlib
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
|
||||
from typing import Any, Callable, DefaultDict, Dict, Generator, Iterable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
@ -114,7 +114,7 @@ class Accelerator:
|
|||
def _move_optimizer_state(self) -> None:
|
||||
""" Moves the state of the optimizers to the GPU if needed. """
|
||||
for opt in self.optimizers:
|
||||
state = defaultdict(dict)
|
||||
state: DefaultDict = defaultdict(dict)
|
||||
for p, v in opt.state.items():
|
||||
state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
|
||||
opt.state = state
|
||||
|
|
Loading…
Reference in New Issue