2020-04-21 12:26:47 +00:00
|
|
|
def recursive_detach(in_dict: dict) -> dict:
    """Detach all tensors in `in_dict`.

    May operate recursively if some of the values in `in_dict` are dictionaries
    which contain instances of `torch.Tensor`. Other types in `in_dict` are
    not affected by this utility function.

    Any value exposing a callable ``detach`` attribute (for example a
    ``torch.Tensor``) is replaced by the result of calling ``value.detach()``.
    Values that are ``dict`` instances (including subclasses) are processed
    recursively and returned as plain ``dict`` objects. All other values are
    placed into the output unchanged (same object, no copy).

    Args:
        in_dict: Dictionary whose values may be tensors, nested dictionaries,
            or arbitrary other objects. It is not modified.

    Return:
        out_dict: A new dictionary with the same keys as ``in_dict``, where
            detachable values have been detached and nested dictionaries
            have been processed the same way.
    """
    out_dict = {}
    for k, v in in_dict.items():
        if isinstance(v, dict):
            out_dict[k] = recursive_detach(v)
        # Duck-typed check: anything with a callable ``detach`` is treated as
        # tensor-like. A non-callable ``detach`` attribute is left alone.
        elif callable(getattr(v, 'detach', None)):
            out_dict[k] = v.detach()
        else:
            out_dict[k] = v
    return out_dict
|