25 lines
677 B
Python
25 lines
677 B
Python
|
def recursive_detach(in_dict: dict) -> dict:
    """Detach all tensors in `in_dict`.

    May operate recursively if some of the values in `in_dict` are dictionaries
    which contain instances of `torch.Tensor`. Other types in `in_dict` are
    not affected by this utility function.

    Values are classified as follows:

    * a ``dict`` value is processed recursively and rebuilt as a new plain
      ``dict``;
    * any other value that has a callable ``detach`` attribute is replaced by
      ``value.detach()``. This check is duck-typed, so it covers
      ``torch.Tensor`` without importing torch;
    * everything else is copied into the result by reference, unchanged.

    `in_dict` itself is only read, never modified. There is no cycle
    detection, so a dict that (directly or indirectly) contains itself
    recurses until ``RecursionError``.

    Parameters
    ----------
    in_dict : dict
        Mapping whose values may be tensors, nested dicts, or arbitrary
        objects.

    Returns
    -------
    out_dict : dict
        A new dict with the same keys, in the same iteration order, as
        `in_dict`.
    """
    out_dict = {}
    for k, v in in_dict.items():
        if isinstance(v, dict):
            # Nested dicts are rebuilt so their tensors get detached too.
            out_dict[k] = recursive_detach(v)
        elif callable(getattr(v, 'detach', None)):
            # Duck-typed tensor check: anything exposing a callable detach().
            out_dict[k] = v.detach()
        else:
            out_dict[k] = v
    return out_dict