fixed multiprocessing import

This commit is contained in:
William Falcon 2019-06-29 17:33:10 -04:00
parent f2134a4ddd
commit 0a03042bf7
5 changed files with 60 additions and 31 deletions

View File

@ -41,7 +41,6 @@ def main(hparams):
model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version)
checkpoint = ModelCheckpoint(
filepath=model_save_path,
save_function=None,
save_best_only=True,
verbose=True,
monitor='val_acc',

View File

@ -17,7 +17,7 @@ np.random.seed(SEED)
# ---------------------
# DEFINE MODEL HERE
# ---------------------
from examples.new_project_templates.lightning_module_template import LightningTemplateModel
from lightning_module_template import LightningTemplateModel
# ---------------------
AVAILABLE_MODELS = {
@ -58,9 +58,7 @@ def main(hparams, cluster, results_dict):
log_dir = os.path.dirname(os.path.realpath(__file__))
exp = Experiment(
name='test_tube_exp',
debug=True,
save_dir=log_dir,
version=0,
autosave=False,
description='test demo'
)
@ -84,7 +82,6 @@ def main(hparams, cluster, results_dict):
model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version)
checkpoint = ModelCheckpoint(
filepath=model_save_path,
save_function=None,
save_best_only=True,
verbose=True,
monitor=hparams.model_save_monitor_value,
@ -102,7 +99,7 @@ def main(hparams, cluster, results_dict):
cluster=cluster,
checkpoint_callback=checkpoint,
early_stop_callback=early_stop,
gpus=gpu_list
gpus=gpu_list,
)
# train model

View File

@ -128,13 +128,15 @@ class Trainer(TrainerIO):
def __tng_tqdm_dic(self):
tqdm_dic = {
'tng_loss': '{0:.3f}'.format(self.avg_loss),
'gpu': '{}'.format(self.current_gpu_name),
'v_nb': '{}'.format(self.experiment.version),
'epoch': '{}'.format(self.current_epoch),
'batch_nb':'{}'.format(self.batch_nb),
}
tqdm_dic.update(self.tqdm_metrics)
if self.on_gpu:
tqdm_dic['gpu'] = '{}'.format(self.current_gpu_name)
return tqdm_dic
def __layout_bookeeping(self, model):
@ -371,7 +373,8 @@ class Trainer(TrainerIO):
metrics.update(grad_norm_dic)
# log metrics
self.experiment.log(metrics)
scalar_metrics = self.__metrics_to_scalars(metrics)
self.experiment.log(scalar_metrics, global_step=self.global_step)
self.experiment.save()
# hook
@ -398,6 +401,19 @@ class Trainer(TrainerIO):
if stop:
return
def __metrics_to_scalars(self, metrics):
    """Return a copy of ``metrics`` with every leaf value converted to ``float``.

    Args:
        metrics: dict mapping a metric name to its value. A value may be a
            ``torch.Tensor`` (unwrapped with ``.item()``), a nested dict of
            metrics (converted recursively), or anything ``float()`` accepts.

    Returns:
        A new dict with the same keys. Tensor and numeric values become
        Python floats; nested dicts stay dicts whose leaves are floats.
        ``metrics`` itself is not modified.

    Raises:
        Whatever ``Tensor.item()`` or ``float()`` raise for a value that
        cannot be reduced to a single number (e.g. a multi-element tensor
        or a non-numeric string).
    """
    def _convert(mapping):
        # Local recursive helper, so the recursion does not depend on
        # private-name lookup through ``self``.
        converted = {}
        for key, value in mapping.items():
            if isinstance(value, dict):
                # Keep the nested dict as a dict of floats. The previous code
                # recursed and then still called float() on the resulting
                # dict, which raised TypeError for any nested metrics.
                converted[key] = _convert(value)
                continue
            if isinstance(value, torch.Tensor):
                # isinstance (not ``type(...) is``) also covers Tensor
                # subclasses such as torch.nn.Parameter.
                value = value.item()
            converted[key] = float(value)
        return converted

    return _convert(metrics)
def __run_tng_batch(self, data_batch, batch_nb):
if data_batch is None:

View File

@ -1,27 +1,44 @@
atomicwrites==1.2.1
attrs==18.2.0
certifi==2018.11.29
cffi==1.11.5
absl-py==0.7.1
astor==0.8.0
bleach==3.1.0
certifi==2019.6.16
cffi==1.12.3
chardet==3.0.4
docutils==0.14
gast==0.2.2
google-pasta==0.1.7
grpcio==1.21.1
h5py==2.9.0
imageio==2.4.1
mkl-fft==1.0.6
idna==2.8
imageio==2.5.0
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
Markdown==3.1.1
mkl-fft==1.0.12
mkl-random==1.0.2
more-itertools==5.0.0
numpy==1.15.4
numpy==1.16.4
olefile==0.46
pandas==0.23.4
Pillow==5.3.0
pluggy==0.8.0
py==1.7.0
pandas==0.24.2
Pillow==6.0.0
pkginfo==1.5.0.1
protobuf==3.8.0
pycparser==2.19
pytest==4.0.2
python-dateutil==2.7.5
pytz==2018.7
scikit-learn==0.20.2
scipy==1.2.0
Pygments==2.4.1
python-dateutil==2.8.0
pytz==2019.1
readme-renderer==24.0
requests==2.22.0
requests-toolbelt==0.9.1
six==1.12.0
sklearn==0.0
test-tube==0.6282
torch==1.0.0
torchvision==0.2.1
tqdm==4.28.1
tensorboard==1.14.0
tensorboardX==1.7
tensorflow==1.14.0
tensorflow-estimator==1.14.0
termcolor==1.1.0
test-tube==0.643
tqdm==4.32.1
twine==1.13.0
urllib3==1.25.3
webencodings==0.5.1
Werkzeug==0.15.4
wrapt==1.11.2

View File

@ -19,7 +19,7 @@ setup(
install_requires=[
"torch>=1.0.0",
"tqdm",
"test-tube>=0.641",
"test-tube>=0.643",
"tensorflow>=1.14.0"
],
packages=find_packages(),