From cee5eaf659b1e52b909f931d05b64d6f25dc63e8 Mon Sep 17 00:00:00 2001 From: Peter Yu <2057325+yukw777@users.noreply.github.com> Date: Thu, 20 Aug 2020 07:45:22 -0400 Subject: [PATCH] flake8 fixes (#3064) * flake8 fixes * fix pep8 * fix pep8 Co-authored-by: William Falcon --- .../callbacks/gpu_usage_logger.py | 39 +- pytorch_lightning/core/memory.py | 2 +- pytorch_lightning/core/step_result.py | 362 +++++++++--------- pytorch_lightning/trainer/evaluation_loop.py | 27 +- tests/base/model_test_steps.py | 38 +- 5 files changed, 241 insertions(+), 227 deletions(-) diff --git a/pytorch_lightning/callbacks/gpu_usage_logger.py b/pytorch_lightning/callbacks/gpu_usage_logger.py index 7f40d11651..a5d49998f3 100644 --- a/pytorch_lightning/callbacks/gpu_usage_logger.py +++ b/pytorch_lightning/callbacks/gpu_usage_logger.py @@ -80,9 +80,15 @@ class GpuUsageLogger(Callback): """ - def __init__(self, memory_utilisation: bool = True, gpu_utilisation: bool = True, - intra_step_time: bool = False, inter_step_time: bool = False, - fan_speed: bool = False, temperature: bool = False): + def __init__( + self, + memory_utilisation: bool = True, + gpu_utilisation: bool = True, + intra_step_time: bool = False, + inter_step_time: bool = False, + fan_speed: bool = False, + temperature: bool = False, + ): super(GpuUsageLogger).__init__() self.memory_utilisation = memory_utilisation self.gpu_utilisation = gpu_utilisation @@ -102,9 +108,10 @@ class GpuUsageLogger(Callback): if self.inter_step_time: # First log at beginning of second step if self.snap_inter_step_time: - trainer.logger.log_metrics({'Batch_Time/inter_step (ms)': - (time.time() - self.snap_inter_step_time) * 1000}, - step=trainer.global_step) + trainer.logger.log_metrics( + {'Batch_Time/inter_step (ms)': (time.time() - self.snap_inter_step_time) * 1000}, + step=trainer.global_step, + ) if self.intra_step_time: self.snap_intra_step_time = time.time() @@ -125,9 +132,10 @@ class GpuUsageLogger(Callback): if self.intra_step_time: if self.snap_intra_step_time: - trainer.logger.log_metrics({'Batch_Time/intra_step (ms)': - (time.time() - self.snap_intra_step_time) * 1000}, - step=trainer.global_step) + trainer.logger.log_metrics( + {'Batch_Time/intra_step (ms)': (time.time() - self.snap_intra_step_time) * 1000}, + step=trainer.global_step, + ) def on_train_epoch_start(self, trainer, pl_module): self.snap_intra_step_time = None @@ -135,10 +143,13 @@ class GpuUsageLogger(Callback): @staticmethod def _get_gpu_stat(pitem: str, unit: str): - result = subprocess.run(["/bin/nvidia-smi", f"--query-gpu={pitem}", "--format=csv,nounits,noheader"], - encoding="utf-8", stdout=subprocess.PIPE, - stderr=subprocess.PIPE, # for backward compatibility with python version 3.6 - check=True) + result = subprocess.run( + ["/bin/nvidia-smi", f"--query-gpu={pitem}", "--format=csv,nounits,noheader"], + encoding="utf-8", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, # for backward compatibility with python version 3.6 + check=True, + ) try: gpu_usage = [float(x) for x in result.stdout.strip().split(os.linesep)] except ValueError: @@ -152,4 +163,4 @@ class GpuUsageLogger(Callback): def _log_memory(self, trainer): trainer.logger.log_metrics(self._get_gpu_stat("memory.used", "MB"), step=trainer.global_step) trainer.logger.log_metrics(self._get_gpu_stat("memory.free", "MB"), step=trainer.global_step) - trainer.logger.log_metrics(self._get_gpu_stat("utilization.memory", "%"), step=trainer.global_step) \ No newline at end of file + trainer.logger.log_metrics(self._get_gpu_stat("utilization.memory", "%"), step=trainer.global_step) diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index d0a5ab8866..3228daae83 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -222,7 +222,7 @@ class ModelSummary(object): input_ = model.transfer_batch_to_device(input_, model.device) if trainer is not None and trainer.amp_backend == AMPType.NATIVE and not trainer.use_tpu: - model.forward = torch.cuda.amp.autocast()(model.forward) + model.forward = torch.cuda.amp.autocast()(model.forward) mode = model.training model.eval() diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 6ecf8eede2..612f8ee3b2 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -24,13 +24,12 @@ from pytorch_lightning.metrics.converters import _sync_ddp_if_available class Result(Dict): - def __init__( - self, - minimize: Optional[Tensor] = None, - early_stop_on: Optional[Tensor] = None, - checkpoint_on: Union[Tensor, bool, None] = None, - hiddens: Optional[Tensor] = None, + self, + minimize: Optional[Tensor] = None, + early_stop_on: Optional[Tensor] = None, + checkpoint_on: Union[Tensor, bool, None] = None, + hiddens: Optional[Tensor] = None, ): super().__init__() @@ -52,12 +51,7 @@ class Result(Dict): if minimize is not None and checkpoint_on is None: self.checkpoint_on = minimize.detach() - self['meta'] = { - '_internal': { - '_reduce_on_epoch': False, - 'batch_sizes': [] - } - } + self['meta'] = {'_internal': {'_reduce_on_epoch': False, 'batch_sizes': []}} def __getitem__(self, key: Union[str, Any]) -> Any: try: @@ -109,20 +103,20 @@ class Result(Dict): assert x.grad_fn is not None, m def log( - self, - name: str, - value: Any, - prog_bar: bool = False, - logger: bool = True, - on_step: bool = False, - on_epoch: bool = True, - reduce_fx: Callable = torch.mean, - tbptt_reduce_fx: Callable = torch.mean, - tbptt_pad_token: int = 0, - enable_graph: bool = False, - sync_dist: bool = False, - sync_dist_op: Union[Any, str] = 'mean', - sync_dist_group: Optional[Any] = None + self, + name: str, + value: Any, + prog_bar: bool = False, + logger: bool = True, + on_step: bool = False, + on_epoch: bool = True, + reduce_fx: Callable = torch.mean, + tbptt_reduce_fx: Callable = torch.mean, + tbptt_pad_token: int = 0, + enable_graph: bool = False, + sync_dist: bool = False, + sync_dist_op: Union[Any, str] = 'mean', + sync_dist_group: Optional[Any] = None, ): # no metrics should be logged with graphs if not enable_graph and isinstance(value, torch.Tensor): @@ -140,37 +134,60 @@ class Result(Dict): if on_step and on_epoch: # set step version step_name = f'step_{name}' - self.__set_meta(step_name, value, prog_bar, logger, - on_step=True, on_epoch=False, - reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token) + self.__set_meta( + step_name, + value, + prog_bar, + logger, + on_step=True, + on_epoch=False, + reduce_fx=reduce_fx, + tbptt_reduce_fx=tbptt_reduce_fx, + tbptt_pad_token=tbptt_pad_token, + ) self.__setitem__(step_name, value) # set epoch version epoch_name = f'epoch_{name}' - self.__set_meta(epoch_name, value, prog_bar, logger, on_step=False, on_epoch=True, - reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token) + self.__set_meta( + epoch_name, + value, + prog_bar, + logger, + on_step=False, + on_epoch=True, + reduce_fx=reduce_fx, + tbptt_reduce_fx=tbptt_reduce_fx, + tbptt_pad_token=tbptt_pad_token, + ) self.__setitem__(epoch_name, value) else: - self.__set_meta(name, value, - prog_bar, logger, - on_step, on_epoch, - reduce_fx, - tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token) + self.__set_meta( + name, + value, + prog_bar, + logger, + on_step, + on_epoch, + reduce_fx, + tbptt_reduce_fx=tbptt_reduce_fx, + tbptt_pad_token=tbptt_pad_token, + ) # set the value self.__setitem__(name, value) def __set_meta( - self, - name: str, - value: Any, - prog_bar: bool, - logger: bool, - on_step: bool, - on_epoch: bool, - reduce_fx: Callable, - tbptt_pad_token: int, - tbptt_reduce_fx: Callable + self, + name: str, + value: Any, + prog_bar: bool, + logger: bool, + on_step: bool, + on_epoch: bool, + reduce_fx: Callable, + tbptt_pad_token: int, + tbptt_reduce_fx: Callable, ): # set the meta for the item meta_value = value @@ -182,7 +199,7 @@ class Result(Dict): reduce_fx=reduce_fx, value=meta_value, tbptt_reduce_fx=tbptt_reduce_fx, - tbptt_pad_token=tbptt_pad_token + tbptt_pad_token=tbptt_pad_token, ) self['meta'][name] = meta @@ -200,10 +217,7 @@ class Result(Dict): return torch.tensor(meta['_internal']['batch_sizes']) def get_callback_metrics(self) -> dict: - result = { - 'early_stop_on': self.early_stop_on, - 'checkpoint_on': self.checkpoint_on - } + result = {'early_stop_on': self.early_stop_on, 'checkpoint_on': self.checkpoint_on} return result @@ -342,7 +356,6 @@ class Result(Dict): result = recursive_gather(outputs, result) recursive_stack(result) - for k, option in meta.items(): if k == '_internal': continue @@ -457,13 +470,12 @@ def collate_tensors(items: Union[List, Tuple]) -> Union[Tensor, List, Tuple]: class TrainResult(Result): - def __init__( - self, - minimize: Optional[Tensor] = None, - early_stop_on: Tensor = None, - checkpoint_on: Union[Tensor, bool] = None, - hiddens: Optional[Tensor] = None, + self, + minimize: Optional[Tensor] = None, + early_stop_on: Tensor = None, + checkpoint_on: Union[Tensor, bool] = None, + hiddens: Optional[Tensor] = None, ): """ Used in train loop to auto-log to a logger or progress bar without needing to define @@ -493,20 +505,20 @@ class TrainResult(Result): super().__init__(minimize, early_stop_on, checkpoint_on, hiddens) def log( - self, - name, - value, - prog_bar: bool = False, - logger: bool = True, - on_step: bool = True, - on_epoch: bool = False, - reduce_fx: Callable = torch.mean, - tbptt_reduce_fx: Callable = torch.mean, - tbptt_pad_token: int = 0, - enable_graph: bool = False, - sync_dist: bool = False, - sync_dist_op: Union[Any, str] = 'mean', - sync_dist_group: Optional[Any] = None + self, + name, + value, + prog_bar: bool = False, + logger: bool = True, + on_step: bool = True, + on_epoch: bool = False, + reduce_fx: Callable = torch.mean, + tbptt_reduce_fx: Callable = torch.mean, + tbptt_pad_token: int = 0, + enable_graph: bool = False, + sync_dist: bool = False, + sync_dist_op: Union[Any, str] = 'mean', + sync_dist_group: Optional[Any] = None, ): """ Log a key, value @@ -543,34 +555,36 @@ class TrainResult(Result): sync_dist_op: the op to sync across sync_dist_group: the ddp group """ - super().log(name=name, - value=value, - prog_bar=prog_bar, - logger=logger, - on_step=on_step, - on_epoch=on_epoch, - reduce_fx=reduce_fx, - enable_graph=enable_graph, - sync_dist=sync_dist, - sync_dist_group=sync_dist_group, - sync_dist_op=sync_dist_op, - tbptt_pad_token=tbptt_pad_token, - tbptt_reduce_fx=tbptt_reduce_fx) + super().log( + name=name, + value=value, + prog_bar=prog_bar, + logger=logger, + on_step=on_step, + on_epoch=on_epoch, + reduce_fx=reduce_fx, + enable_graph=enable_graph, + sync_dist=sync_dist, + sync_dist_group=sync_dist_group, + sync_dist_op=sync_dist_op, + tbptt_pad_token=tbptt_pad_token, + tbptt_reduce_fx=tbptt_reduce_fx, + ) def log_dict( - self, - dictionary: dict, - prog_bar: bool = False, - logger: bool = True, - on_step: bool = False, - on_epoch: bool = True, - reduce_fx: Callable = torch.mean, - tbptt_reduce_fx: Callable = torch.mean, - tbptt_pad_token: int = 0, - enable_graph: bool = False, - sync_dist: bool = False, - sync_dist_op: Union[Any, str] = 'mean', - sync_dist_group: Optional[Any] = None + self, + dictionary: dict, + prog_bar: bool = False, + logger: bool = True, + on_step: bool = False, + on_epoch: bool = True, + reduce_fx: Callable = torch.mean, + tbptt_reduce_fx: Callable = torch.mean, + tbptt_pad_token: int = 0, + enable_graph: bool = False, + sync_dist: bool = False, + sync_dist_op: Union[Any, str] = 'mean', + sync_dist_group: Optional[Any] = None, ): """ Log a dictonary of values at once @@ -595,28 +609,29 @@ class TrainResult(Result): sync_dist_group: the ddp group: """ for k, v in dictionary.items(): - self.log(name=k, - value=v, - prog_bar=prog_bar, - logger=logger, - on_step=on_step, - on_epoch=on_epoch, - reduce_fx=reduce_fx, - enable_graph=enable_graph, - sync_dist=sync_dist, - sync_dist_group=sync_dist_group, - sync_dist_op=sync_dist_op, - tbptt_pad_token=tbptt_pad_token, - tbptt_reduce_fx=tbptt_reduce_fx) + self.log( + name=k, + value=v, + prog_bar=prog_bar, + logger=logger, + on_step=on_step, + on_epoch=on_epoch, + reduce_fx=reduce_fx, + enable_graph=enable_graph, + sync_dist=sync_dist, + sync_dist_group=sync_dist_group, + sync_dist_op=sync_dist_op, + tbptt_pad_token=tbptt_pad_token, + tbptt_reduce_fx=tbptt_reduce_fx, + ) class EvalResult(Result): - def __init__( - self, - early_stop_on: Optional[Tensor] = None, - checkpoint_on: Optional[Tensor] = None, - hiddens: Optional[Tensor] = None, + self, + early_stop_on: Optional[Tensor] = None, + checkpoint_on: Optional[Tensor] = None, + hiddens: Optional[Tensor] = None, ): """ Used in val/train loop to auto-log to a logger or progress bar without needing to define @@ -645,20 +660,20 @@ class EvalResult(Result): super().__init__(None, early_stop_on, checkpoint_on, hiddens) def log( - self, - name, - value, - prog_bar: bool = False, - logger: bool = True, - on_step: bool = False, - on_epoch: bool = True, - reduce_fx: Callable = torch.mean, - tbptt_reduce_fx: Callable = torch.mean, - tbptt_pad_token: int = 0, - enable_graph: bool = False, - sync_dist: bool = False, - sync_dist_op: Union[Any, str] = 'mean', - sync_dist_group: Optional[Any] = None + self, + name, + value, + prog_bar: bool = False, + logger: bool = True, + on_step: bool = False, + on_epoch: bool = True, + reduce_fx: Callable = torch.mean, + tbptt_reduce_fx: Callable = torch.mean, + tbptt_pad_token: int = 0, + enable_graph: bool = False, + sync_dist: bool = False, + sync_dist_op: Union[Any, str] = 'mean', + sync_dist_group: Optional[Any] = None, ): """ Log a key, value @@ -694,34 +709,36 @@ class EvalResult(Result): sync_dist_op: the op to sync across sync_dist_group: the ddp group """ - super().log(name=name, - value=value, - prog_bar=prog_bar, - logger=logger, - on_step=on_step, - on_epoch=on_epoch, - reduce_fx=reduce_fx, - enable_graph=enable_graph, - sync_dist=sync_dist, - sync_dist_group=sync_dist_group, - sync_dist_op=sync_dist_op, - tbptt_pad_token=tbptt_pad_token, - tbptt_reduce_fx=tbptt_reduce_fx) + super().log( + name=name, + value=value, + prog_bar=prog_bar, + logger=logger, + on_step=on_step, + on_epoch=on_epoch, + reduce_fx=reduce_fx, + enable_graph=enable_graph, + sync_dist=sync_dist, + sync_dist_group=sync_dist_group, + sync_dist_op=sync_dist_op, + tbptt_pad_token=tbptt_pad_token, + tbptt_reduce_fx=tbptt_reduce_fx, + ) def log_dict( - self, - dictionary: dict, - prog_bar: bool = False, - logger: bool = True, - on_step: bool = False, - on_epoch: bool = True, - reduce_fx: Callable = torch.mean, - tbptt_reduce_fx: Callable = torch.mean, - tbptt_pad_token: int = 0, - enable_graph: bool = False, - sync_dist: bool = False, - sync_dist_op: Union[Any, str] = 'mean', - sync_dist_group: Optional[Any] = None + self, + dictionary: dict, + prog_bar: bool = False, + logger: bool = True, + on_step: bool = False, + on_epoch: bool = True, + reduce_fx: Callable = torch.mean, + tbptt_reduce_fx: Callable = torch.mean, + tbptt_pad_token: int = 0, + enable_graph: bool = False, + sync_dist: bool = False, + sync_dist_op: Union[Any, str] = 'mean', + sync_dist_group: Optional[Any] = None, ): """ Log a dictonary of values at once @@ -746,25 +763,24 @@ class EvalResult(Result): sync_dist_group: the ddp group """ for k, v in dictionary.items(): - self.log(name=k, - value=v, - prog_bar=prog_bar, - logger=logger, - on_step=on_step, - on_epoch=on_epoch, - reduce_fx=reduce_fx, - enable_graph=enable_graph, - sync_dist=sync_dist, - sync_dist_group=sync_dist_group, - sync_dist_op=sync_dist_op, - tbptt_pad_token=tbptt_pad_token, - tbptt_reduce_fx=tbptt_reduce_fx) + self.log( + name=k, + value=v, + prog_bar=prog_bar, + logger=logger, + on_step=on_step, + on_epoch=on_epoch, + reduce_fx=reduce_fx, + enable_graph=enable_graph, + sync_dist=sync_dist, + sync_dist_group=sync_dist_group, + sync_dist_op=sync_dist_op, + tbptt_pad_token=tbptt_pad_token, + tbptt_reduce_fx=tbptt_reduce_fx, + ) def get_callback_metrics(self) -> dict: - result = { - 'val_early_stop_on': self.early_stop_on, - 'val_checkpoint_on': self.checkpoint_on - } + result = {'val_early_stop_on': self.early_stop_on, 'val_checkpoint_on': self.checkpoint_on} return result diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 4ef1b34b3c..a05c8a68b1 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -255,7 +255,7 @@ class TrainerEvaluationLoopMixin(ABC): model: LightningModule, dataloaders: List[DataLoader], max_batches: Union[int, List[int]], - test_mode: bool = False + test_mode: bool = False, ): """Run evaluation code. @@ -316,13 +316,13 @@ class TrainerEvaluationLoopMixin(ABC): if self.is_overridden('on_test_batch_start'): model_ref = self.get_model() with self.profiler.profile('on_test_batch_start'): - model_ref.on_test_batch_start(output) + model_ref.on_test_batch_start(batch, batch_idx, dataloader_idx) else: self.on_validation_batch_start(batch, batch_idx, dataloader_idx) if self.is_overridden('on_validation_batch_start'): model_ref = self.get_model() with self.profiler.profile('on_validation_batch_start'): - model_ref.on_validation_batch_start(output) + model_ref.on_validation_batch_start(batch, batch_idx, dataloader_idx) # ----------------- # RUN EVALUATION STEP # ----------------- @@ -364,13 +364,13 @@ class TrainerEvaluationLoopMixin(ABC): if self.is_overridden('on_test_batch_end'): model_ref = self.get_model() with self.profiler.profile('on_test_batch_end'): - model_ref.on_test_batch_end(output) + model_ref.on_test_batch_end(batch, batch_idx, dataloader_idx) else: self.on_validation_batch_end(batch, batch_idx, dataloader_idx) if self.is_overridden('on_validation_batch_end'): model_ref = self.get_model() with self.profiler.profile('on_validation_batch_end'): - model_ref.on_validation_batch_end(output) + model_ref.on_validation_batch_end(batch, batch_idx, dataloader_idx) # track outputs for collation if output is not None: @@ -456,8 +456,11 @@ class TrainerEvaluationLoopMixin(ABC): eval_results = model.test_end(eval_results) user_reduced = True - rank_zero_warn('Method `test_end` was deprecated in v0.7 and will be removed in v1.0.' - ' Use `test_epoch_end` instead.', DeprecationWarning) + rank_zero_warn( + 'Method `test_end` was deprecated in v0.7 and will be removed in v1.0.' + ' Use `test_epoch_end` instead.', + DeprecationWarning, + ) elif self.is_overridden('test_epoch_end', model=model): if using_eval_result: @@ -474,8 +477,11 @@ class TrainerEvaluationLoopMixin(ABC): eval_results = model.validation_end(eval_results) user_reduced = True - rank_zero_warn('Method `validation_end` was deprecated in v0.7 and will be removed in v1.0.' - ' Use `validation_epoch_end` instead.', DeprecationWarning) + rank_zero_warn( + 'Method `validation_end` was deprecated in v0.7 and will be removed in v1.0.' + ' Use `validation_epoch_end` instead.', + DeprecationWarning, + ) elif self.is_overridden('validation_epoch_end', model=model): if using_eval_result: @@ -647,8 +653,7 @@ class TrainerEvaluationLoopMixin(ABC): # make dataloader_idx arg in validation_step optional args = [batch, batch_idx] - if (test_mode and len(self.test_dataloaders) > 1) \ - or (not test_mode and len(self.val_dataloaders) > 1): + if (test_mode and len(self.test_dataloaders) > 1) or (not test_mode and len(self.val_dataloaders) > 1): args.append(dataloader_idx) # handle DP, DDP forward diff --git a/tests/base/model_test_steps.py b/tests/base/model_test_steps.py index db5ad1ed33..3c96cc1731 100644 --- a/tests/base/model_test_steps.py +++ b/tests/base/model_test_steps.py @@ -33,20 +33,15 @@ class TestStepVariations(ABC): # alternate possible outputs to test if batch_idx % 1 == 0: - output = OrderedDict({ - 'test_loss': loss_test, - 'test_acc': test_acc, - }) + output = OrderedDict({'test_loss': loss_test, 'test_acc': test_acc}) return output if batch_idx % 2 == 0: return test_acc if batch_idx % 3 == 0: - output = OrderedDict({ - 'test_loss': loss_test, - 'test_acc': test_acc, - 'test_dic': {'test_loss_a': loss_test} - }) + output = OrderedDict({'test_loss': loss_test, + 'test_acc': test_acc, + 'test_dic': {'test_loss_a': loss_test}}) return output def test_step_result_obj(self, batch, batch_idx, *args, **kwargs): @@ -71,19 +66,13 @@ class TestStepVariations(ABC): result = EvalResult() # alternate possible outputs to test if batch_idx % 1 == 0: - result.log_dict({ - 'test_loss': loss_test, - 'test_acc': test_acc, - }) + result.log_dict({'test_loss': loss_test, 'test_acc': test_acc}) return result if batch_idx % 2 == 0: return test_acc if batch_idx % 3 == 0: - result.log_dict({ - 'test_loss': loss_test, - 'test_acc': test_acc, - }) + result.log_dict({'test_loss': loss_test, 'test_acc': test_acc}) result.test_dic = {'test_loss_a': loss_test} return result @@ -108,10 +97,7 @@ class TestStepVariations(ABC): # alternate possible outputs to test if batch_idx % 1 == 0: - output = OrderedDict({ - 'test_loss': loss_test, - 'test_acc': test_acc, - }) + output = OrderedDict({'test_loss': loss_test, 'test_acc': test_acc}) return output if batch_idx % 2 == 0: return test_acc @@ -124,16 +110,12 @@ class TestStepVariations(ABC): }) return output if batch_idx % 5 == 0: - output = OrderedDict({ - f'test_loss_{dataloader_idx}': loss_test, - f'test_acc_{dataloader_idx}': test_acc, - }) + output = OrderedDict({f'test_loss_{dataloader_idx}': loss_test, f'test_acc_{dataloader_idx}': test_acc}) return output def test_step__empty(self, batch, batch_idx, *args, **kwargs): return {} - def test_step_result_preds(self, batch, batch_idx, optimizer_idx=None): x, y = batch x = x.view(x.size(0), -1) @@ -174,13 +156,13 @@ class TestStepVariations(ABC): elif option == 1: result.write('idxs', torch.cat((lazy_ids, lazy_ids)), prediction_file) result.write('preds', labels_hat, prediction_file) - + # write multi-dimension elif option == 2: result.write('idxs', lazy_ids, prediction_file) result.write('preds', labels_hat, prediction_file) result.write('x', x, prediction_file) - + # write str list elif option == 3: result.write('idxs', lazy_ids, prediction_file)