From b0a0a47a0bd04d0ef7361a3f816f4259f9ec2c64 Mon Sep 17 00:00:00 2001 From: Alok Singh <8325708+alok@users.noreply.github.com> Date: Wed, 25 Sep 2019 16:05:06 -0700 Subject: [PATCH] Rename variables (#124) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - data_batch → batch - batch_i → batch_idx - dataloader_i → dataloader_idx - tng → training - training_dataloader → train_dataloader - add_log_row_interval → row_log_interval - gradient_clip → gradient_clip_val - prog → progress - tqdm_dic → tqdm_dict --- README.md | 20 +-- .../RequiredTrainerInterface.md | 78 +++++----- docs/LightningModule/properties.md | 2 +- docs/Trainer/Logging.md | 2 +- docs/Trainer/Training Loop.md | 4 +- docs/Trainer/hooks.md | 4 +- .../lightning_module_template.py | 22 +-- examples/templates/gan.py | 2 +- pytorch_lightning/root_module/hooks.py | 4 +- pytorch_lightning/root_module/root_module.py | 2 +- pytorch_lightning/testing/lm_test_module.py | 2 +- .../testing/lm_test_module_base.py | 12 +- .../testing/lm_test_module_mixins.py | 76 +++++----- pytorch_lightning/trainer/trainer.py | 134 +++++++++--------- pytorch_lightning/utilities/arg_parse.py | 14 +- tests/debug.py | 8 +- tests/test_models.py | 10 +- 17 files changed, 198 insertions(+), 198 deletions(-) diff --git a/README.md b/README.md index 33480dc4b4..2112aeaed2 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,7 @@ class CoolSystem(pl.LightningModule): return torch.optim.Adam(self.parameters(), lr=0.02) @pl.data_loader - def tng_dataloader(self): + def train_dataloader(self): # REQUIRED return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) @@ -177,16 +177,16 @@ You define the blue parts using the LightningModule interface: ```python # what to do in the training loop -def training_step(self, data_batch, batch_nb): +def training_step(self, batch, batch_nb): # what to do in the validation loop -def validation_step(self, data_batch, batch_nb): +def validation_step(self, batch, batch_nb): # how to aggregate validation_step outputs def validation_end(self, outputs): # and your dataloaders -def tng_dataloader(): +def train_dataloader(): def val_dataloader(): def test_dataloader(): ``` @@ -195,8 +195,8 @@ def test_dataloader(): ```python # define what happens for training here -def training_step(self, data_batch, batch_nb): - x, y = data_batch +def training_step(self, batch, batch_nb): + x, y = batch # define your own forward and loss calculation hidden_states = self.encoder(x) @@ -222,8 +222,8 @@ def training_step(self, data_batch, batch_nb): ```python # define what happens for validation here -def validation_step(self, data_batch, batch_nb): - x, y = data_batch +def validation_step(self, batch, batch_nb): + x, y = batch # or as basic as a CNN classification out = self.forward(x) @@ -248,8 +248,8 @@ def validation_end(self, outputs): val_loss_mean /= len(outputs) val_acc_mean /= len(outputs) - tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()} - return tqdm_dic + tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()} + return tqdm_dict ``` ## Tensorboard diff --git a/docs/LightningModule/RequiredTrainerInterface.md b/docs/LightningModule/RequiredTrainerInterface.md index c8ba29c0b7..d8ccae94f3 100644 --- a/docs/LightningModule/RequiredTrainerInterface.md +++ b/docs/LightningModule/RequiredTrainerInterface.md @@ -10,7 +10,7 @@ Otherwise, to Define a Lightning Module, implement the following methods: **Required**: - [training_step](RequiredTrainerInterface.md#training_step) -- [tng_dataloader](RequiredTrainerInterface.md#tng_dataloader) +- [train_dataloader](RequiredTrainerInterface.md#train_dataloader) - [configure_optimizers](RequiredTrainerInterface.md#configure_optimizers) **Optional**: @@ -23,7 +23,7 @@ Otherwise, to Define a Lightning Module, implement the following methods: - [test_dataloader](RequiredTrainerInterface.md#test_dataloader) - [on_save_checkpoint](RequiredTrainerInterface.md#on_save_checkpoint) - [on_load_checkpoint](RequiredTrainerInterface.md#on_load_checkpoint) -- [update_tng_log_metrics](RequiredTrainerInterface.md#update_tng_log_metrics) +- [update_training_log_metrics](RequiredTrainerInterface.md#update_training_log_metrics) - [add_model_specific_args](RequiredTrainerInterface.md#add_model_specific_args) --- @@ -81,7 +81,7 @@ class CoolModel(pl.LightningModule): return [torch.optim.Adam(self.parameters(), lr=0.02)] @pl.data_loader - def tng_dataloader(self): + def train_dataloader(self): return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) @pl.data_loader @@ -111,7 +111,7 @@ The LightningModule interface is on the right. Each method corresponds to a part ### training_step ``` {.python} -def training_step(self, data_batch, batch_nb) +def training_step(self, batch, batch_nb) ``` In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something specific to your model. @@ -120,7 +120,7 @@ In this step you'd normally do the forward pass and calculate the loss for a bat | Param | description | |---|---| -| data_batch | The output of your dataloader. A tensor, tuple or list | +| batch | The output of your dataloader. A tensor, tuple or list | | batch_nb | Integer displaying which batch this is | **Return** @@ -130,14 +130,14 @@ Dictionary or OrderedDict | key | value | is required | |---|---|---| | loss | tensor scalar | Y | -| prog | Dict for progress bar display. Must have only tensors | N | +| progress | Dict for progress bar display. Must have only tensors | N | **Example** ``` {.python} -def training_step(self, data_batch, batch_nb): - x, y, z = data_batch +def training_step(self, batch, batch_nb): + x, y, z = batch # implement your own out = self.forward(x) @@ -145,7 +145,7 @@ def training_step(self, data_batch, batch_nb): output = { 'loss': loss, # required - 'prog': {'tng_loss': loss, 'batch_nb': batch_nb} # optional + 'progress': {'training_loss': loss, 'batch_nb': batch_nb} # optional } # return a dict @@ -155,7 +155,7 @@ def training_step(self, data_batch, batch_nb): If you define multiple optimizers, this step will also be called with an additional ```optimizer_idx``` param. ``` {.python} # Multiple optimizers (ie: GANs) -def training_step(self, data_batch, batch_nb, optimizer_idx): +def training_step(self, batch, batch_nb, optimizer_idx): if optimizer_idx == 0: # do training_step with encoder if optimizer_idx == 1: @@ -163,11 +163,11 @@ def training_step(self, data_batch, batch_nb, optimizer_idx): ``` --- -### tng_dataloader +### train_dataloader ``` {.python} @pl.data_loader -def tng_dataloader(self) +def train_dataloader(self) ``` Called by lightning during training loop. Make sure to use the @pl.data_loader decorator, this ensures not calling this function until the data are needed. @@ -178,7 +178,7 @@ PyTorch DataLoader ``` {.python} @pl.data_loader -def tng_dataloader(self): +def train_dataloader(self): transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))]) dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform, download=True) loader = torch.utils.data.DataLoader( @@ -240,10 +240,10 @@ the [optimizer_step](https://williamfalcon.github.io/pytorch-lightning/Trainer/h ``` {.python} # if you have one val dataloader: -def validation_step(self, data_batch, batch_nb) +def validation_step(self, batch, batch_nb) # if you have multiple val dataloaders: -def validation_step(self, data_batch, batch_nb, dataloader_idx) +def validation_step(self, batch, batch_nb, dataloader_idxdx) ``` **OPTIONAL** If you don't need to validate you don't need to implement this method. In this step you'd normally generate examples or calculate anything of interest such as accuracy. @@ -256,9 +256,9 @@ The dict you return here will be available in the `validation_end` method. | Param | description | |---|---| -| data_batch | The output of your dataloader. A tensor, tuple or list | +| batch | The output of your dataloader. A tensor, tuple or list | | batch_nb | Integer displaying which batch this is | -| dataloader_i | Integer displaying which dataloader this is (only if multiple val datasets used) | +| dataloader_idx | Integer displaying which dataloader this is (only if multiple val datasets used) | **Return** @@ -270,8 +270,8 @@ The dict you return here will be available in the `validation_end` method. ``` {.python} # CASE 1: A single validation dataset -def validation_step(self, data_batch, batch_nb): - x, y = data_batch +def validation_step(self, batch, batch_nb): + x, y = batch # implement your own out = self.forward(x) @@ -302,7 +302,7 @@ If you pass in multiple validation datasets, validation_step will have an additi ```python # CASE 2: multiple validation datasets -def validation_step(self, data_batch, batch_nb, dataset_idx): +def validation_step(self, batch, batch_nb, dataset_idx): # dataset_idx tells you which dataset this is. ``` @@ -351,8 +351,8 @@ def validation_end(self, outputs): val_loss_mean /= len(outputs) val_acc_mean /= len(outputs) - tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()} - return tqdm_dic + tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()} + return tqdm_dict ``` With multiple dataloaders, `outputs` will be a list of lists. The outer list contains @@ -377,18 +377,18 @@ def validation_end(self, outputs): val_loss_mean /= i val_acc_mean /= i - tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()} - return tqdm_dic + tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()} + return tqdm_dict ``` ### test_step ``` {.python} # if you have one test dataloader: -def test_step(self, data_batch, batch_nb) +def test_step(self, batch, batch_nb) # if you have multiple test dataloaders: -def test_step(self, data_batch, batch_nb, dataloader_idx) +def test_step(self, batch, batch_nb, dataloader_idxdx) ``` **OPTIONAL** If you don't need to test you don't need to implement this method. In this step you'd normally generate examples or calculate anything of interest such as accuracy. @@ -403,9 +403,9 @@ This function is used when you execute `trainer.test()`. | Param | description | |---|---| -| data_batch | The output of your dataloader. A tensor, tuple or list | +| batch | The output of your dataloader. A tensor, tuple or list | | batch_nb | Integer displaying which batch this is | -| dataloader_i | Integer displaying which dataloader this is (only if multiple test datasets used) | +| dataloader_idx | Integer displaying which dataloader this is (only if multiple test datasets used) | **Return** @@ -417,8 +417,8 @@ This function is used when you execute `trainer.test()`. ``` {.python} # CASE 1: A single test dataset -def test_step(self, data_batch, batch_nb): - x, y = data_batch +def test_step(self, batch, batch_nb): + x, y = batch # implement your own out = self.forward(x) @@ -443,7 +443,7 @@ If you pass in multiple test datasets, test_step will have an additional argumen ```python # CASE 2: multiple test datasets -def test_step(self, data_batch, batch_nb, dataset_idx): +def test_step(self, batch, batch_nb, dataset_idx): # dataset_idx tells you which dataset this is. ``` @@ -490,8 +490,8 @@ def test_end(self, outputs): test_loss_mean /= len(outputs) test_acc_mean /= len(outputs) - tqdm_dic = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()} - return tqdm_dic + tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()} + return tqdm_dict ``` With multiple dataloaders, `outputs` will be a list of lists. The outer list contains @@ -516,8 +516,8 @@ def test_end(self, outputs): test_loss_mean /= i test_acc_mean /= i - tqdm_dic = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()} - return tqdm_dic + tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()} + return tqdm_dict ``` --- @@ -633,10 +633,10 @@ def test_dataloader(self): ``` --- -### update_tng_log_metrics +### update_training_log_metrics ``` {.python} -def update_tng_log_metrics(self, logs) +def update_training_log_metrics(self, logs) ``` Called by lightning right before it logs metrics for this batch. This is a chance to amend or add to the metrics about to be logged. @@ -647,7 +647,7 @@ Dict **Example** ``` {.python} -def update_tng_log_metrics(self, logs): +def update_training_log_metrics(self, logs): # modify or add to logs return logs ``` @@ -674,7 +674,7 @@ def add_model_specific_args(parent_parser, root_dir): parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser]) # param overwrites - # parser.set_defaults(gradient_clip=5.0) + # parser.set_defaults(gradient_clip_val=5.0) # network params parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False) diff --git a/docs/LightningModule/properties.md b/docs/LightningModule/properties.md index 54883f3fd6..f8730ac951 100644 --- a/docs/LightningModule/properties.md +++ b/docs/LightningModule/properties.md @@ -22,7 +22,7 @@ self.experiment.add_scalars(...) Total training batches seen across all epochs --- -#### gradient_clip +#### gradient_clip_val The current gradient clip value --- diff --git a/docs/Trainer/Logging.md b/docs/Trainer/Logging.md index 27d1e85d24..85cdfce159 100644 --- a/docs/Trainer/Logging.md +++ b/docs/Trainer/Logging.md @@ -13,7 +13,7 @@ trainer = Trainer(show_progress_bar=True) Every k batches lightning will make an entry in the metrics log ``` {.python} # DEFAULT (ie: save a .csv log file every 10 batches) -trainer = Trainer(add_log_row_interval=10) +trainer = Trainer(row_log_interval=10) ``` --- diff --git a/docs/Trainer/Training Loop.md b/docs/Trainer/Training Loop.md index 2779f869a4..2193e4dac5 100644 --- a/docs/Trainer/Training Loop.md +++ b/docs/Trainer/Training Loop.md @@ -52,10 +52,10 @@ Specifically, this will [clip the gradient norm computed over all model paramete ``` {.python} # DEFAULT (ie: don't clip) -trainer = Trainer(gradient_clip=0) +trainer = Trainer(gradient_clip_val=0) # clip gradients with norm above 0.5 -trainer = Trainer(gradient_clip=0.5) +trainer = Trainer(gradient_clip_val=0.5) ``` --- diff --git a/docs/Trainer/hooks.md b/docs/Trainer/hooks.md index a5a776d76e..e18230169c 100644 --- a/docs/Trainer/hooks.md +++ b/docs/Trainer/hooks.md @@ -58,12 +58,12 @@ def on_post_performance_check(self): ``` --- -#### on_tng_metrics +#### on_training_metrics Called in the training loop, right before metrics are logged. Although you can log at any time by using self.experiment, you can use this callback to modify what will be logged. ```python -def on_tng_metrics(self, metrics): +def on_training_metrics(self, metrics): # do something before validation end ``` diff --git a/examples/new_project_templates/lightning_module_template.py b/examples/new_project_templates/lightning_module_template.py index 6cdd7677ad..3be61cf370 100644 --- a/examples/new_project_templates/lightning_module_template.py +++ b/examples/new_project_templates/lightning_module_template.py @@ -79,14 +79,14 @@ class LightningTemplateModel(LightningModule): nll = F.nll_loss(logits, labels) return nll - def training_step(self, data_batch, batch_i): + def training_step(self, batch, batch_idx): """ Lightning calls this inside the training loop - :param data_batch: + :param batch: :return: """ # forward pass - x, y = data_batch + x, y = batch x = x.view(x.size(0), -1) y_hat = self.forward(x) @@ -105,13 +105,13 @@ class LightningTemplateModel(LightningModule): # can also return just a scalar instead of a dict (return loss_val) return output - def validation_step(self, data_batch, batch_i): + def validation_step(self, batch, batch_idx): """ Lightning calls this inside the validation loop - :param data_batch: + :param batch: :return: """ - x, y = data_batch + x, y = batch x = x.view(x.size(0), -1) y_hat = self.forward(x) @@ -167,8 +167,8 @@ class LightningTemplateModel(LightningModule): val_loss_mean /= len(outputs) val_acc_mean /= len(outputs) - tqdm_dic = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean} - return tqdm_dic + tqdm_dict = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean} + return tqdm_dict # --------------------- # TRAINING SETUP @@ -208,8 +208,8 @@ class LightningTemplateModel(LightningModule): return loader @pl.data_loader - def tng_dataloader(self): - print('tng data loader called') + def train_dataloader(self): + print('training data loader called') return self.__dataloader(train=True) @pl.data_loader @@ -233,7 +233,7 @@ class LightningTemplateModel(LightningModule): parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser]) # param overwrites - # parser.set_defaults(gradient_clip=5.0) + # parser.set_defaults(gradient_clip_val=5.0) # network params parser.add_argument('--in_features', default=28 * 28, type=int) diff --git a/examples/templates/gan.py b/examples/templates/gan.py index a7dea95663..d15bdb5b2d 100644 --- a/examples/templates/gan.py +++ b/examples/templates/gan.py @@ -146,7 +146,7 @@ class GAN(pl.LightningModule): return [opt_g, opt_d], [] @pl.data_loader - def tng_dataloader(self): + def train_dataloader(self): transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) dataset = MNIST(os.getcwd(), train=True, download=True, transform=transform) diff --git a/pytorch_lightning/root_module/hooks.py b/pytorch_lightning/root_module/hooks.py index 12a187c808..6820fd70b4 100644 --- a/pytorch_lightning/root_module/hooks.py +++ b/pytorch_lightning/root_module/hooks.py @@ -10,7 +10,7 @@ class ModelHooks(torch.nn.Module): """ pass - def on_batch_start(self, data_batch): + def on_batch_start(self, batch): pass def on_batch_end(self): @@ -28,7 +28,7 @@ class ModelHooks(torch.nn.Module): def on_post_performance_check(self): pass - def on_tng_metrics(self, metrics): + def on_training_metrics(self, metrics): pass def on_before_zero_grad(self, optimizer): diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py index 55a9897b91..0e0259635d 100644 --- a/pytorch_lightning/root_module/root_module.py +++ b/pytorch_lightning/root_module/root_module.py @@ -106,7 +106,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks): optimizer.zero_grad() @data_loader - def tng_dataloader(self): + def train_dataloader(self): """ Implement a PyTorch DataLoader :return: diff --git a/pytorch_lightning/testing/lm_test_module.py b/pytorch_lightning/testing/lm_test_module.py index c3b7d9b4cf..e89011632b 100644 --- a/pytorch_lightning/testing/lm_test_module.py +++ b/pytorch_lightning/testing/lm_test_module.py @@ -23,5 +23,5 @@ class LightningTestModel(LightningValidationMixin, LightningTestMixin, Lightning Most common test case. Validation and test dataloaders """ - def on_tng_metrics(self, logs): + def on_training_metrics(self, logs): logs['some_tensor_to_test'] = torch.rand(1) diff --git a/pytorch_lightning/testing/lm_test_module_base.py b/pytorch_lightning/testing/lm_test_module_base.py index 47350d0c05..ab7682411b 100644 --- a/pytorch_lightning/testing/lm_test_module_base.py +++ b/pytorch_lightning/testing/lm_test_module_base.py @@ -81,14 +81,14 @@ class LightningTestModelBase(LightningModule): nll = F.nll_loss(logits, labels) return nll - def training_step(self, data_batch, batch_i): + def training_step(self, batch, batch_idx): """ Lightning calls this inside the training loop - :param data_batch: + :param batch: :return: """ # forward pass - x, y = data_batch + x, y = batch x = x.view(x.size(0), -1) y_hat = self.forward(x) @@ -104,7 +104,7 @@ class LightningTestModelBase(LightningModule): if self.trainer.batch_nb % 1 == 0: output = OrderedDict({ 'loss': loss_val, - 'prog': {'some_val': loss_val * loss_val} + 'progress': {'some_val': loss_val * loss_val} }) return output if self.trainer.batch_nb % 2 == 0: @@ -153,7 +153,7 @@ class LightningTestModelBase(LightningModule): return loader @data_loader - def tng_dataloader(self): + def train_dataloader(self): return self._dataloader(train=True) @staticmethod @@ -167,7 +167,7 @@ class LightningTestModelBase(LightningModule): parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser]) # param overwrites - # parser.set_defaults(gradient_clip=5.0) + # parser.set_defaults(gradient_clip_val=5.0) # network params parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False) diff --git a/pytorch_lightning/testing/lm_test_module_mixins.py b/pytorch_lightning/testing/lm_test_module_mixins.py index 8c6bd020ba..3831300e24 100644 --- a/pytorch_lightning/testing/lm_test_module_mixins.py +++ b/pytorch_lightning/testing/lm_test_module_mixins.py @@ -25,13 +25,13 @@ class LightningValidationStepMixin: def val_dataloader(self): return self._dataloader(train=False) - def validation_step(self, data_batch, batch_i): + def validation_step(self, batch, batch_idx): """ Lightning calls this inside the validation loop - :param data_batch: + :param batch: :return: """ - x, y = data_batch + x, y = batch x = x.view(x.size(0), -1) y_hat = self.forward(x) @@ -51,16 +51,16 @@ class LightningValidationStepMixin: val_acc = val_acc.unsqueeze(0) # alternate possible outputs to test - if batch_i % 1 == 0: + if batch_idx % 1 == 0: output = OrderedDict({ 'val_loss': loss_val, 'val_acc': val_acc, }) return output - if batch_i % 2 == 0: + if batch_idx % 2 == 0: return val_acc - if batch_i % 3 == 0: + if batch_idx % 3 == 0: output = OrderedDict({ 'val_loss': loss_val, 'val_acc': val_acc, @@ -104,8 +104,8 @@ class LightningValidationMixin(LightningValidationStepMixin): val_loss_mean /= len(outputs) val_acc_mean /= len(outputs) - tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()} - return tqdm_dic + tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()} + return tqdm_dict class LightningValidationStepMultipleDataloadersMixin: @@ -118,13 +118,13 @@ class LightningValidationStepMultipleDataloadersMixin: def val_dataloader(self): return [self._dataloader(train=False), self._dataloader(train=False)] - def validation_step(self, data_batch, batch_i, dataloader_i): + def validation_step(self, batch, batch_idx, dataloader_idx): """ Lightning calls this inside the validation loop - :param data_batch: + :param batch: :return: """ - x, y = data_batch + x, y = batch x = x.view(x.size(0), -1) y_hat = self.forward(x) @@ -144,26 +144,26 @@ class LightningValidationStepMultipleDataloadersMixin: val_acc = val_acc.unsqueeze(0) # alternate possible outputs to test - if batch_i % 1 == 0: + if batch_idx % 1 == 0: output = OrderedDict({ 'val_loss': loss_val, 'val_acc': val_acc, }) return output - if batch_i % 2 == 0: + if batch_idx % 2 == 0: return val_acc - if batch_i % 3 == 0: + if batch_idx % 3 == 0: output = OrderedDict({ 'val_loss': loss_val, 'val_acc': val_acc, 'test_dic': {'val_loss_a': loss_val} }) return output - if batch_i % 5 == 0: + if batch_idx % 5 == 0: output = OrderedDict({ - f'val_loss_{dataloader_i}': loss_val, - f'val_acc_{dataloader_i}': val_acc, + f'val_loss_{dataloader_idx}': loss_val, + f'val_acc_{dataloader_idx}': val_acc, }) return output @@ -206,8 +206,8 @@ class LightningValidationMultipleDataloadersMixin(LightningValidationStepMultipl val_loss_mean /= i val_acc_mean /= i - tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()} - return tqdm_dic + tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()} + return tqdm_dict class LightningTestStepMixin: @@ -216,13 +216,13 @@ class LightningTestStepMixin: def test_dataloader(self): return self._dataloader(train=False) - def test_step(self, data_batch, batch_i): + def test_step(self, batch, batch_idx): """ Lightning calls this inside the validation loop - :param data_batch: + :param batch: :return: """ - x, y = data_batch + x, y = batch x = x.view(x.size(0), -1) y_hat = self.forward(x) @@ -242,16 +242,16 @@ class LightningTestStepMixin: test_acc = test_acc.unsqueeze(0) # alternate possible outputs to test - if batch_i % 1 == 0: + if batch_idx % 1 == 0: output = OrderedDict({ 'test_loss': loss_test, 'test_acc': test_acc, }) return output - if batch_i % 2 == 0: + if batch_idx % 2 == 0: return test_acc - if batch_i % 3 == 0: + if batch_idx % 3 == 0: output = OrderedDict({ 'test_loss': loss_test, 'test_acc': test_acc, @@ -290,8 +290,8 @@ class LightningTestMixin(LightningTestStepMixin): test_loss_mean /= len(outputs) test_acc_mean /= len(outputs) - tqdm_dic = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()} - return tqdm_dic + tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()} + return tqdm_dict class LightningTestStepMultipleDataloadersMixin: @@ -300,13 +300,13 @@ class LightningTestStepMultipleDataloadersMixin: def test_dataloader(self): return [self._dataloader(train=False), self._dataloader(train=False)] - def test_step(self, data_batch, batch_i, dataloader_i): + def test_step(self, batch, batch_idx, dataloader_idx): """ Lightning calls this inside the validation loop - :param data_batch: + :param batch: :return: """ - x, y = data_batch + x, y = batch x = x.view(x.size(0), -1) y_hat = self.forward(x) @@ -326,26 +326,26 @@ class LightningTestStepMultipleDataloadersMixin: test_acc = test_acc.unsqueeze(0) # alternate possible outputs to test - if batch_i % 1 == 0: + if batch_idx % 1 == 0: output = OrderedDict({ 'test_loss': loss_test, 'test_acc': test_acc, }) return output - if batch_i % 2 == 0: + if batch_idx % 2 == 0: return test_acc - if batch_i % 3 == 0: + if batch_idx % 3 == 0: output = OrderedDict({ 'test_loss': loss_test, 'test_acc': test_acc, 'test_dic': {'test_loss_a': loss_test} }) return output - if batch_i % 5 == 0: + if batch_idx % 5 == 0: output = OrderedDict({ - f'test_loss_{dataloader_i}': loss_test, - f'test_acc_{dataloader_i}': test_acc, + f'test_loss_{dataloader_idx}': loss_test, + f'test_acc_{dataloader_idx}': test_acc, }) return output @@ -383,5 +383,5 @@ class LightningTestMultipleDataloadersMixin(LightningTestStepMultipleDataloaders test_loss_mean /= i test_acc_mean /= i - tqdm_dic = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()} - return tqdm_dic + tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()} + return tqdm_dict diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index a01bae69b6..6f8bd6005f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -57,7 +57,7 @@ class Trainer(TrainerIO): experiment=None, early_stop_callback=None, checkpoint_callback=None, - gradient_clip=0, + gradient_clip_val=0, process_position=0, nb_gpu_nodes=1, gpus=None, @@ -75,7 +75,7 @@ class Trainer(TrainerIO): test_percent_check=1.0, val_check_interval=1.0, log_save_interval=100, - add_log_row_interval=10, + row_log_interval=10, distributed_backend=None, use_amp=False, print_nan_grads=False, @@ -88,7 +88,7 @@ class Trainer(TrainerIO): :param experiment: Test-tube experiment :param early_stop_callback: Callback for early stopping :param checkpoint_callback: Callback for checkpointing - :param gradient_clip: int. 0 means don't clip. + :param gradient_clip_val: int. 0 means don't clip. :param process_position: shown in the tqdm bar :param nb_gpu_nodes: number of GPU nodes :param gpus: int. (ie: 2 gpus) OR list to specify which GPUs [0, 1] or '0,1' @@ -106,7 +106,7 @@ class Trainer(TrainerIO): :param test_percent_check: int. How much of test set to check :param val_check_interval: int. Check val this frequently within a train epoch :param log_save_interval: int. Writes logs to disk this often - :param add_log_row_interval: int. How often to add logging rows + :param row_log_interval: int. How often to add logging rows :param distributed_backend: str. dp, or ddp. :param use_amp: Bool. If true uses apex for 16bit precision :param print_nan_grads: Bool. Prints nan gradients @@ -118,7 +118,7 @@ class Trainer(TrainerIO): # Transfer params self.nb_gpu_nodes = nb_gpu_nodes self.log_gpu_memory = log_gpu_memory - self.gradient_clip = gradient_clip + self.gradient_clip_val = gradient_clip_val self.check_val_every_n_epoch = check_val_every_n_epoch self.enable_early_stop = early_stop_callback is not None self.track_grad_norm = track_grad_norm @@ -138,9 +138,9 @@ class Trainer(TrainerIO): self.batch_nb = 0 self.tqdm_metrics = {} self.nb_val_batches = 0 - self.nb_tng_batches = 0 + self.nb_training_batches = 0 self.nb_test_batches = 0 - self.tng_dataloader = None + self.train_dataloader = None self.test_dataloader = None self.val_dataloader = None @@ -188,13 +188,13 @@ class Trainer(TrainerIO): self.__set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids) # can't init progress bar here because starting a new process - # means the prog_bar won't survive pickling + # means the progress_bar won't survive pickling self.show_progress_bar = show_progress_bar # logging self.log_save_interval = log_save_interval self.val_check_interval = val_check_interval - self.add_log_row_interval = add_log_row_interval + self.row_log_interval = row_log_interval # how much of the data to use self.__determine_data_use_amount(train_percent_check, val_percent_check, @@ -405,36 +405,36 @@ class Trainer(TrainerIO): return is_overriden @property - def __tng_tqdm_dic(self): - tqdm_dic = { + def __training_tqdm_dict(self): + tqdm_dict = { 'loss': '{0:.3f}'.format(self.avg_loss), 'epoch': '{}'.format(self.current_epoch), 'batch_nb': '{}'.format(self.batch_nb), } if self.experiment is not None: - tqdm_dic['v_nb'] = self.experiment.version + tqdm_dict['v_nb'] = self.experiment.version - tqdm_dic.update(self.tqdm_metrics) + tqdm_dict.update(self.tqdm_metrics) if self.on_gpu: - tqdm_dic['gpu'] = '{}'.format(torch.cuda.current_device()) + tqdm_dict['gpu'] = '{}'.format(torch.cuda.current_device()) - return tqdm_dic + return tqdm_dict @property - def tng_tqdm_dic(self): + def training_tqdm_dict(self): """ Read-only for tqdm metrics :return: """ - return self.__tng_tqdm_dic + return self.__training_tqdm_dict def __layout_bookeeping(self): # determine number of training batches - self.nb_tng_batches = len(self.tng_dataloader) - self.nb_tng_batches = int(self.nb_tng_batches * self.train_percent_check) + self.nb_training_batches = len(self.train_dataloader) + self.nb_training_batches = int(self.nb_training_batches * self.train_percent_check) # determine number of validation batches # val datasets could be none, 1 or 2+ @@ -450,7 +450,7 @@ class Trainer(TrainerIO): self.nb_test_batches = max(1, self.nb_test_batches) # determine when to check validation - self.val_check_batch = int(self.nb_tng_batches * self.val_check_interval) + self.val_check_batch = int(self.nb_training_batches * self.val_check_interval) self.val_check_batch = max(1, self.val_check_batch) def __add_tqdm_metrics(self, metrics): @@ -460,15 +460,15 @@ class Trainer(TrainerIO): self.tqdm_metrics[k] = v - def __evaluation_forward(self, model, data_batch, batch_i, dataloader_i, test=False): - # make dataloader_i arg in validation_step optional - args = [data_batch, batch_i] + def __evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False): + # make dataloader_idx arg in validation_step optional + args = [batch, batch_idx] if test and len(self.test_dataloader) > 1: - args.append(dataloader_i) + args.append(dataloader_idx) elif not test and len(self.val_dataloader) > 1: - args.append(dataloader_i) + args.append(dataloader_idx) # handle DP, DDP forward if self.use_ddp or self.use_dp: @@ -481,8 +481,8 @@ class Trainer(TrainerIO): root_gpu = 0 if type(self.data_parallel_device_ids) is list: root_gpu = self.data_parallel_device_ids[0] - data_batch = self.transfer_batch_to_gpu(data_batch, root_gpu) - args[0] = data_batch + batch = self.transfer_batch_to_gpu(batch, root_gpu) + args[0] = batch if test: output = model.test_step(*args) @@ -497,7 +497,7 @@ class Trainer(TrainerIO): :param model: PT model :param dataloaders: list of PT dataloaders :param max_batches: Scalar - :param dataloader_i: + :param dataloader_idx: :param test: boolean :return: """ @@ -512,21 +512,21 @@ class Trainer(TrainerIO): outputs = [] # run training - for dataloader_i, dl in enumerate(dataloaders): + for dataloader_idx, dl in enumerate(dataloaders): dl_outputs = [] - for batch_i, data_batch in enumerate(dl): + for batch_idx, batch in enumerate(dl): - if data_batch is None: # pragma: no cover + if batch is None: # pragma: no cover continue # stop short when on fast_dev_run (sets max_batch=1) - if batch_i >= max_batches: + if batch_idx >= max_batches: break # ----------------- # RUN EVALUATION STEP # ----------------- - output = self.__evaluation_forward(model, data_batch, batch_i, dataloader_i, + output = self.__evaluation_forward(model, batch, batch_idx, dataloader_idx, test) # track outputs for collation @@ -563,7 +563,7 @@ class Trainer(TrainerIO): :return: """ - self.tng_dataloader = model.tng_dataloader + self.train_dataloader = model.train_dataloader self.test_dataloader = model.test_dataloader self.val_dataloader = model.val_dataloader @@ -576,7 +576,7 @@ class Trainer(TrainerIO): if have_val_loaders and not isinstance(self.val_dataloader, list): self.val_dataloader = [self.val_dataloader] - if self.use_ddp and not isinstance(self.tng_dataloader.sampler, DistributedSampler): + if self.use_ddp and not isinstance(self.train_dataloader.sampler, DistributedSampler): msg = """ You're using multiple gpus and multiple nodes without using a DistributedSampler to assign a subset of your data to each process. To silence this warning, pass a @@ -767,7 +767,7 @@ class Trainer(TrainerIO): except Exception: self.node_rank = 0 - # show progbar only on prog_rank 0 + # show progressbar only on progress_rank 0 self.show_progress_bar = self.show_progress_bar and self.node_rank == 0 and gpu_nb == 0 # determine which process we are and world size @@ -929,7 +929,7 @@ class Trainer(TrainerIO): for epoch_nb in range(self.current_epoch, self.max_nb_epochs): # set seed for distributed sampler (enables shuffling for each epoch) if self.use_ddp: - self.tng_dataloader.sampler.set_epoch(epoch_nb) + self.train_dataloader.sampler.set_epoch(epoch_nb) # get model model = self.__get_model() @@ -937,7 +937,7 @@ class Trainer(TrainerIO): # update training progress in trainer and model model.current_epoch = epoch_nb self.current_epoch = epoch_nb - self.total_batches = self.nb_tng_batches + self.nb_val_batches + self.total_batches = self.nb_training_batches + self.nb_val_batches self.batch_loss_value = 0 # accumulated grads # init progress_bar when requested @@ -950,7 +950,7 @@ class Trainer(TrainerIO): # ----------------- # RUN TNG EPOCH # ----------------- - self.run_tng_epoch() + self.run_training_epoch() # update LR schedulers if self.lr_schedulers is not None: @@ -961,20 +961,20 @@ class Trainer(TrainerIO): met_min_epochs = epoch_nb > self.min_nb_epochs if self.enable_early_stop and met_min_epochs: should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch_nb, - logs=self.__tng_tqdm_dic) + logs=self.__training_tqdm_dict) # stop training stop = should_stop and met_min_epochs if stop: return - def run_tng_epoch(self): + def run_training_epoch(self): # before epoch hook if self.__is_function_implemented('on_epoch_start'): model = self.__get_model() model.on_epoch_start() # run epoch - for batch_nb, data_batch in enumerate(self.tng_dataloader): + for batch_nb, batch in enumerate(self.train_dataloader): self.batch_nb = batch_nb self.global_step += 1 @@ -984,14 +984,14 @@ class Trainer(TrainerIO): # stop when the flag is changed or we've gone past the amount # requested in the batches self.total_batch_nb += 1 - met_batch_limit = batch_nb > self.nb_tng_batches + met_batch_limit = batch_nb > self.nb_training_batches if met_batch_limit: break # --------------- # RUN TRAIN STEP # --------------- - batch_result = self.__run_tng_batch(data_batch, batch_nb) + batch_result = self.__run_training_batch(batch, batch_nb) early_stop_epoch = batch_result == -1 # --------------- @@ -1009,12 +1009,12 @@ class Trainer(TrainerIO): self.experiment.save() # when metrics should be logged - if batch_nb % self.add_log_row_interval == 0 or early_stop_epoch: + if batch_nb % self.row_log_interval == 0 or early_stop_epoch: # count items in memory # nb_params, nb_tensors = count_mem_items() model = self.__get_model() - metrics = self.__tng_tqdm_dic + metrics = self.__training_tqdm_dict # add gpu memory if self.on_gpu and self.log_gpu_memory: @@ -1027,8 +1027,8 @@ class Trainer(TrainerIO): grad_norm_dic = model.grad_norm(self.track_grad_norm) metrics.update(grad_norm_dic) - if self.__is_function_implemented('on_tng_metrics'): - model.on_tng_metrics(metrics) + if self.__is_function_implemented('on_training_metrics'): + model.on_training_metrics(metrics) # log metrics scalar_metrics = self.__metrics_to_scalars( @@ -1103,10 +1103,10 @@ class Trainer(TrainerIO): # nothing matches, return the value as is without transform return batch - def __tng_forward(self, data_batch, batch_nb, opt_idx): + def __training_forward(self, batch, batch_nb, opt_idx): """ Handle forward for each training case (distributed, single gpu, etc...) - :param data_batch: + :param batch: :param batch_nb: :return: """ @@ -1114,7 +1114,7 @@ class Trainer(TrainerIO): # FORWARD # --------------- # enable not needing to add opt_idx to training_step - args = [data_batch, batch_nb] + args = [batch, batch_nb] if len(self.optimizers) > 1: args.append(opt_idx) @@ -1126,8 +1126,8 @@ class Trainer(TrainerIO): gpu_id = 0 if type(self.data_parallel_device_ids) is list: gpu_id = self.data_parallel_device_ids[0] - data_batch = self.transfer_batch_to_gpu(data_batch, gpu_id) - args[0] = data_batch + batch = self.transfer_batch_to_gpu(batch, gpu_id) + args[0] = batch output = self.model.training_step(*args) else: @@ -1137,14 +1137,14 @@ class Trainer(TrainerIO): # TQDM metrics # --------------- try: - prog_output = output['prog'] + progress_output = output['progress'] - # reduce prog metrics for tqdm when using dp + # reduce progress metrics for tqdm when using dp if self.use_dp: nb_gpus = self.num_gpus - prog_output = reduce_distributed_output(prog_output, nb_gpus) + progress_output = reduce_distributed_output(progress_output, nb_gpus) - model_specific_tqdm_metrics_dic = prog_output + model_specific_tqdm_metrics_dic = progress_output except Exception: model_specific_tqdm_metrics_dic = {} @@ -1166,9 +1166,9 @@ class Trainer(TrainerIO): return loss, model_specific_tqdm_metrics_dic def __clip_gradients(self): - if self.gradient_clip > 0: + if self.gradient_clip_val > 0: model = self.__get_model() - torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip) + torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip_val) def __print_nan_grads(self): model = self.__get_model() @@ -1176,14 +1176,14 @@ class Trainer(TrainerIO): if torch.isnan(param.grad.float()).any(): print(param, param.grad) - def __run_tng_batch(self, data_batch, batch_nb): - if data_batch is None: + def __run_training_batch(self, batch, batch_nb): + if batch is None: return 0 # hook if self.__is_function_implemented('on_batch_start'): model_ref = self.__get_model() - response = model_ref.on_batch_start(data_batch) + response = model_ref.on_batch_start(batch) if response == -1: return -1 @@ -1195,7 +1195,7 @@ class Trainer(TrainerIO): for opt_idx, optimizer in enumerate(self.optimizers): # forward pass - loss, model_specific_tqdm_metrics = self.__tng_forward(data_batch, batch_nb, opt_idx) + loss, model_specific_tqdm_metrics = self.__training_forward(batch, batch_nb, opt_idx) # track metrics self.__add_tqdm_metrics(model_specific_tqdm_metrics) @@ -1238,10 +1238,10 @@ class Trainer(TrainerIO): self.batch_loss_value = 0 self.avg_loss = np.mean(self.running_loss[-100:]) - # update progbar + # update progressbar if self.show_progress_bar: # add model specific metrics - tqdm_metrics = self.__tng_tqdm_dic + tqdm_metrics = self.__training_tqdm_dict self.progress_bar.set_postfix(**tqdm_metrics) # activate batch end hook @@ -1296,11 +1296,11 @@ class Trainer(TrainerIO): if self.show_progress_bar: # add model specific metrics - tqdm_metrics = self.__tng_tqdm_dic + tqdm_metrics = self.__training_tqdm_dict self.progress_bar.set_postfix(**tqdm_metrics) # model checkpointing if self.proc_rank == 0 and self.checkpoint_callback is not None and not test: print('save callback...') self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch, - logs=self.__tng_tqdm_dic) + logs=self.__training_tqdm_dict) diff --git a/pytorch_lightning/utilities/arg_parse.py b/pytorch_lightning/utilities/arg_parse.py index 39d4ec81f9..fbf3ec2f19 100644 --- a/pytorch_lightning/utilities/arg_parse.py +++ b/pytorch_lightning/utilities/arg_parse.py @@ -8,7 +8,7 @@ import os def add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None): - # tng, test, val check intervals + # training, test, val check intervals parser.add_argument('--eval_test_set', dest='eval_test_set', action='store_true', help='true = run test set also') parser.add_argument('--check_val_every_n_epoch', default=1, type=int, @@ -19,7 +19,7 @@ def add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None parser.add_argument('--max_nb_epochs', default=200, type=int, help='cap epochs') parser.add_argument('--min_nb_epochs', default=2, type=int, help='min epochs') parser.add_argument('--train_percent_check', default=1.0, type=float, - help='how much of tng set to check') + help='how much of training set to check') parser.add_argument('--val_percent_check', default=1.0, type=float, help='how much of val set to check') parser.add_argument('--test_percent_check', default=1.0, type=float, @@ -29,7 +29,7 @@ def add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None help='how much within 1 epoch to check val') parser.add_argument('--log_save_interval', default=100, type=int, help='how many batches between log saves') - parser.add_argument('--add_log_row_interval', default=100, type=int, + parser.add_argument('--row_log_interval', default=100, type=int, help='add log every k batches') # early stopping @@ -40,7 +40,7 @@ def add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None help='number of epochs until stop') # gradient handling - parser.add_argument('--gradient_clip', default=-1, type=int) + parser.add_argument('--gradient_clip_val', default=-1, type=int) parser.add_argument('--track_grad_norm', default=-1, type=int, help='if > 0, will track this grad norm') @@ -78,9 +78,9 @@ def add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None # FAST training # use these settings to make sure network has no bugs without running a full dataset parser.add_argument('--fast_dev_run', dest='fast_dev_run', default=False, action='store_true', - help='runs validation after 1 tng step') + help='runs validation after 1 training step') parser.add_argument('--enable_tqdm', dest='enable_tqdm', default=False, action='store_true', - help='false removes the prog bar') + help='false removes the progress bar') parser.add_argument('--overfit', default=-1, type=float, help='% of dataset to use with this option. float, or -1 for none') @@ -93,7 +93,7 @@ def add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None parser.add_argument('--debug', dest='debug', action='store_true', help='enables/disables test tube') parser.add_argument('--local', dest='local', action='store_true', - help='enables local tng') + help='enables local training') # optimizer parser.add_argument('--lr_scheduler_milestones', default=None, type=str) diff --git a/tests/debug.py b/tests/debug.py index 6016dee3ba..db87f2ed6d 100644 --- a/tests/debug.py +++ b/tests/debug.py @@ -32,7 +32,7 @@ class CoolModel(pl.LightningModule): def training_step(self, batch, batch_nb): x, y = batch y_hat = self.forward(x) - return {'tng_loss': self.my_loss(y_hat, y)} + return {'training_loss': self.my_loss(y_hat, y)} def validation_step(self, batch, batch_nb): x, y = batch @@ -47,7 +47,7 @@ class CoolModel(pl.LightningModule): return [torch.optim.Adam(self.parameters(), lr=0.02)] @pl.data_loader - def tng_dataloader(self): + def train_dataloader(self): return DataLoader(MNIST('path/to/save', train=True), batch_size=32) @pl.data_loader @@ -182,13 +182,13 @@ def run_gpu_model_test(trainer_options, model, hparams, on_gpu=True): def assert_ok_val_acc(trainer): # this model should get 0.80+ acc - acc = trainer.tng_tqdm_dic['val_acc'] + acc = trainer.training_tqdm_dict['val_acc'] assert acc > 0.50, f'model failed to get expected 0.50 validation accuracy. Got: {acc}' def assert_ok_test_acc(trainer): # this model should get 0.80+ acc - acc = trainer.tng_tqdm_dic['test_acc'] + acc = trainer.training_tqdm_dict['test_acc'] assert acc > 0.50, f'model failed to get expected 0.50 validation accuracy. Got: {acc}' diff --git a/tests/test_models.py b/tests/test_models.py index a5c463dcc7..3c82492a4a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -511,7 +511,7 @@ def test_early_stopping_cpu_model(): stopping = EarlyStopping(monitor='val_loss') trainer_options = dict( early_stop_callback=stopping, - gradient_clip=1.0, + gradient_clip_val=1.0, overfit_pct=0.20, track_grad_norm=2, print_nan_grads=True, @@ -1098,7 +1098,7 @@ def test_all_features_cpu_model(): """ trainer_options = dict( - gradient_clip=1.0, + gradient_clip_val=1.0, overfit_pct=0.20, track_grad_norm=2, print_nan_grads=True, @@ -1256,7 +1256,7 @@ def test_multiple_val_dataloader(): trainer = Trainer(**trainer_options) result = trainer.fit(model) - # verify tng completed + # verify training completed assert result == 1 # verify there are 2 val loaders @@ -1450,13 +1450,13 @@ def run_prediction(dataloader, trained_model, dp=False): def assert_ok_val_acc(trainer): # this model should get 0.80+ acc - acc = trainer.tng_tqdm_dic['val_acc'] + acc = trainer.training_tqdm_dict['val_acc'] assert acc > 0.50, f'model failed to get expected 0.50 validation accuracy. Got: {acc}' def assert_ok_test_acc(trainer): # this model should get 0.80+ acc - acc = trainer.tng_tqdm_dic['test_acc'] + acc = trainer.training_tqdm_dict['test_acc'] assert acc > 0.50, f'model failed to get expected 0.50 validation accuracy. Got: {acc}'