[rfc] Make Result.unpack_batch_size a static method (#4019)

This could be a useful utility elsewhere in lightning for calculating the batch size
This commit is contained in:
ananthsub 2020-10-09 16:05:45 -07:00 committed by GitHub
parent d41ffafa87
commit ba2b532ee9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 6 deletions

View File

@ -219,7 +219,7 @@ class Result(Dict):
def track_batch_size(self, batch):
try:
batch_size = self.unpack_batch_size(batch)
batch_size = Result.unpack_batch_size(batch)
except RecursionError as re:
batch_size = 1
@ -276,7 +276,7 @@ class Result(Dict):
result[k] = self[k]
if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
# compute metric on epoch anyway so state does not accumulate
# compute metric on epoch anyway so state does not accumulate
self[k].compute()
return result
@ -299,7 +299,7 @@ class Result(Dict):
result[k] = self[k]
if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
# compute metric on epoch anyway so state does not accumulate
# compute metric on epoch anyway so state does not accumulate
self[k].compute()
return result
@ -353,7 +353,8 @@ class Result(Dict):
newone[k] = copy(v)
return newone
def unpack_batch_size(self, sample):
@staticmethod
def unpack_batch_size(sample):
"""
Recursively unpack sample to find a torch.Tensor.
returns len(tensor) when found, or 1 when it hits an empty or non iterable.
@ -364,10 +365,10 @@ class Result(Dict):
return len(sample)
elif isinstance(sample, dict):
sample = next(iter(sample.values()), 1)
size = self.unpack_batch_size(sample)
size = Result.unpack_batch_size(sample)
elif isinstance(sample, Iterable):
sample = next(iter(sample), 1)
size = self.unpack_batch_size(sample)
size = Result.unpack_batch_size(sample)
else:
size = 1
return size