From 8531f33549b8a3301dd59436c1ca2c0772682fce Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 14 May 2019 06:36:26 -0400 Subject: [PATCH] tng and val steps now have batch nbs --- pytorch_lightning/models/trainer.py | 10 +++++----- setup.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index dd3aed2b6e..6c71a94b6f 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -163,7 +163,7 @@ class Trainer(TrainerIO): outputs = [] # run training - for i, data_batch in enumerate(dataloader): + for batch_i, data_batch in enumerate(dataloader): if data_batch is None: continue @@ -175,7 +175,7 @@ class Trainer(TrainerIO): # ----------------- # RUN VALIDATION STEP # ----------------- - output = model.validation_step(data_batch) + output = model.validation_step(data_batch, batch_i) outputs.append(output) # batch done @@ -290,7 +290,7 @@ class Trainer(TrainerIO): # --------------- # RUN TRAIN STEP # --------------- - batch_result = self.__run_tng_batch(data_batch) + batch_result = self.__run_tng_batch(data_batch, batch_nb) early_stop_epoch = batch_result == -1 # --------------- @@ -348,7 +348,7 @@ class Trainer(TrainerIO): return - def __run_tng_batch(self, data_batch): + def __run_tng_batch(self, data_batch, batch_nb): if data_batch is None: return 0 @@ -363,7 +363,7 @@ class Trainer(TrainerIO): # forward pass # return a scalar value and a dic with tqdm metrics - loss, model_specific_tqdm_metrics_dic = self.model.training_step(data_batch) + loss, model_specific_tqdm_metrics_dic = self.model.training_step(data_batch, batch_nb) self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic) # backward pass diff --git a/setup.py b/setup.py index b9c723f180..94b71c20c5 100755 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ from setuptools import setup, find_packages # http://blog.ionelmc.ro/2014/05/25/python-packaging/ setup( name="pytorch-lightning", - version='0.1.dev1821', + version='0.1.dev1822', description="The Keras for ML researchers using PyTorch", author="William Falcon", author_email="waf2107@columbia.edu",