tng and val steps now have batch nbs

This commit is contained in:
William Falcon 2019-05-14 06:36:26 -04:00
parent 98b26c5c7e
commit 8531f33549
2 changed files with 6 additions and 6 deletions

View File

@ -163,7 +163,7 @@ class Trainer(TrainerIO):
outputs = []
# run training
for i, data_batch in enumerate(dataloader):
for batch_i, data_batch in enumerate(dataloader):
if data_batch is None:
continue
@ -175,7 +175,7 @@ class Trainer(TrainerIO):
# -----------------
# RUN VALIDATION STEP
# -----------------
output = model.validation_step(data_batch)
output = model.validation_step(data_batch, batch_i)
outputs.append(output)
# batch done
@ -290,7 +290,7 @@ class Trainer(TrainerIO):
# ---------------
# RUN TRAIN STEP
# ---------------
batch_result = self.__run_tng_batch(data_batch)
batch_result = self.__run_tng_batch(data_batch, batch_nb)
early_stop_epoch = batch_result == -1
# ---------------
@ -348,7 +348,7 @@ class Trainer(TrainerIO):
return
def __run_tng_batch(self, data_batch):
def __run_tng_batch(self, data_batch, batch_nb):
if data_batch is None:
return 0
@ -363,7 +363,7 @@ class Trainer(TrainerIO):
# forward pass
# return a scalar value and a dic with tqdm metrics
loss, model_specific_tqdm_metrics_dic = self.model.training_step(data_batch)
loss, model_specific_tqdm_metrics_dic = self.model.training_step(data_batch, batch_nb)
self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)
# backward pass

View File

@ -7,7 +7,7 @@ from setuptools import setup, find_packages
# http://blog.ionelmc.ro/2014/05/25/python-packaging/
setup(
name="pytorch-lightning",
version='0.1.dev1821',
version='0.1.dev1822',
description="The Keras for ML researchers using PyTorch",
author="William Falcon",
author_email="waf2107@columbia.edu",