From ba2b532ee94721e9143036f41e120820e2c5cd21 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 9 Oct 2020 16:05:45 -0700 Subject: [PATCH] [rfc] Make Result.unpack_batch_size a static method (#4019) This could be a useful utility elsewhere in lightning for calculating the batch size --- pytorch_lightning/core/step_result.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index ad34261f1e..847cd71c8f 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -219,7 +219,7 @@ class Result(Dict): def track_batch_size(self, batch): try: - batch_size = self.unpack_batch_size(batch) + batch_size = Result.unpack_batch_size(batch) except RecursionError as re: batch_size = 1 @@ -276,7 +276,7 @@ class Result(Dict): result[k] = self[k] if k in self and not options['on_epoch'] and isinstance(self[k], Metric): - # compute metric on epoch anyway so state does not accumulate + # compute metric on epoch anyway so state does not accumulate self[k].compute() return result @@ -299,7 +299,7 @@ class Result(Dict): result[k] = self[k] if k in self and not options['on_epoch'] and isinstance(self[k], Metric): - # compute metric on epoch anyway so state does not accumulate + # compute metric on epoch anyway so state does not accumulate self[k].compute() return result @@ -353,7 +353,8 @@ class Result(Dict): newone[k] = copy(v) return newone - def unpack_batch_size(self, sample): + @staticmethod + def unpack_batch_size(sample): """ Recursively unpack sample to find a torch.Tensor. returns len(tensor) when found, or 1 when it hits an empty or non iterable. @@ -364,10 +365,10 @@ class Result(Dict): return len(sample) elif isinstance(sample, dict): sample = next(iter(sample.values()), 1) - size = self.unpack_batch_size(sample) + size = Result.unpack_batch_size(sample) elif isinstance(sample, Iterable): sample = next(iter(sample), 1) - size = self.unpack_batch_size(sample) + size = Result.unpack_batch_size(sample) else: size = 1 return size