fixed multiprocessing import
This commit is contained in:
parent
f2134a4ddd
commit
0a03042bf7
|
@ -41,7 +41,6 @@ def main(hparams):
|
|||
model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version)
|
||||
checkpoint = ModelCheckpoint(
|
||||
filepath=model_save_path,
|
||||
save_function=None,
|
||||
save_best_only=True,
|
||||
verbose=True,
|
||||
monitor='val_acc',
|
||||
|
|
|
@ -17,7 +17,7 @@ np.random.seed(SEED)
|
|||
# ---------------------
|
||||
# DEFINE MODEL HERE
|
||||
# ---------------------
|
||||
from examples.new_project_templates.lightning_module_template import LightningTemplateModel
|
||||
from lightning_module_template import LightningTemplateModel
|
||||
# ---------------------
|
||||
|
||||
AVAILABLE_MODELS = {
|
||||
|
@ -58,9 +58,7 @@ def main(hparams, cluster, results_dict):
|
|||
log_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
exp = Experiment(
|
||||
name='test_tube_exp',
|
||||
debug=True,
|
||||
save_dir=log_dir,
|
||||
version=0,
|
||||
autosave=False,
|
||||
description='test demo'
|
||||
)
|
||||
|
@ -84,7 +82,6 @@ def main(hparams, cluster, results_dict):
|
|||
model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version)
|
||||
checkpoint = ModelCheckpoint(
|
||||
filepath=model_save_path,
|
||||
save_function=None,
|
||||
save_best_only=True,
|
||||
verbose=True,
|
||||
monitor=hparams.model_save_monitor_value,
|
||||
|
@ -102,7 +99,7 @@ def main(hparams, cluster, results_dict):
|
|||
cluster=cluster,
|
||||
checkpoint_callback=checkpoint,
|
||||
early_stop_callback=early_stop,
|
||||
gpus=gpu_list
|
||||
gpus=gpu_list,
|
||||
)
|
||||
|
||||
# train model
|
||||
|
|
|
@ -128,13 +128,15 @@ class Trainer(TrainerIO):
|
|||
def __tng_tqdm_dic(self):
|
||||
tqdm_dic = {
|
||||
'tng_loss': '{0:.3f}'.format(self.avg_loss),
|
||||
'gpu': '{}'.format(self.current_gpu_name),
|
||||
'v_nb': '{}'.format(self.experiment.version),
|
||||
'epoch': '{}'.format(self.current_epoch),
|
||||
'batch_nb':'{}'.format(self.batch_nb),
|
||||
}
|
||||
tqdm_dic.update(self.tqdm_metrics)
|
||||
|
||||
if self.on_gpu:
|
||||
tqdm_dic['gpu'] = '{}'.format(self.current_gpu_name)
|
||||
|
||||
return tqdm_dic
|
||||
|
||||
def __layout_bookeeping(self, model):
|
||||
|
@ -371,7 +373,8 @@ class Trainer(TrainerIO):
|
|||
metrics.update(grad_norm_dic)
|
||||
|
||||
# log metrics
|
||||
self.experiment.log(metrics)
|
||||
scalar_metrics = self.__metrics_to_scalars(metrics)
|
||||
self.experiment.log(scalar_metrics, global_step=self.global_step)
|
||||
self.experiment.save()
|
||||
|
||||
# hook
|
||||
|
@ -398,6 +401,19 @@ class Trainer(TrainerIO):
|
|||
if stop:
|
||||
return
|
||||
|
||||
def __metrics_to_scalars(self, metrics):
|
||||
new_metrics = {}
|
||||
for k, v in metrics.items():
|
||||
if type(v) is torch.Tensor:
|
||||
v = v.item()
|
||||
|
||||
if type(v) is dict:
|
||||
v = self.__metrics_to_scalars(v)
|
||||
|
||||
new_metrics[k] = float(v)
|
||||
|
||||
return new_metrics
|
||||
|
||||
|
||||
def __run_tng_batch(self, data_batch, batch_nb):
|
||||
if data_batch is None:
|
||||
|
|
|
@ -1,27 +1,44 @@
|
|||
atomicwrites==1.2.1
|
||||
attrs==18.2.0
|
||||
certifi==2018.11.29
|
||||
cffi==1.11.5
|
||||
absl-py==0.7.1
|
||||
astor==0.8.0
|
||||
bleach==3.1.0
|
||||
certifi==2019.6.16
|
||||
cffi==1.12.3
|
||||
chardet==3.0.4
|
||||
docutils==0.14
|
||||
gast==0.2.2
|
||||
google-pasta==0.1.7
|
||||
grpcio==1.21.1
|
||||
h5py==2.9.0
|
||||
imageio==2.4.1
|
||||
mkl-fft==1.0.6
|
||||
idna==2.8
|
||||
imageio==2.5.0
|
||||
Keras-Applications==1.0.8
|
||||
Keras-Preprocessing==1.1.0
|
||||
Markdown==3.1.1
|
||||
mkl-fft==1.0.12
|
||||
mkl-random==1.0.2
|
||||
more-itertools==5.0.0
|
||||
numpy==1.15.4
|
||||
numpy==1.16.4
|
||||
olefile==0.46
|
||||
pandas==0.23.4
|
||||
Pillow==5.3.0
|
||||
pluggy==0.8.0
|
||||
py==1.7.0
|
||||
pandas==0.24.2
|
||||
Pillow==6.0.0
|
||||
pkginfo==1.5.0.1
|
||||
protobuf==3.8.0
|
||||
pycparser==2.19
|
||||
pytest==4.0.2
|
||||
python-dateutil==2.7.5
|
||||
pytz==2018.7
|
||||
scikit-learn==0.20.2
|
||||
scipy==1.2.0
|
||||
Pygments==2.4.1
|
||||
python-dateutil==2.8.0
|
||||
pytz==2019.1
|
||||
readme-renderer==24.0
|
||||
requests==2.22.0
|
||||
requests-toolbelt==0.9.1
|
||||
six==1.12.0
|
||||
sklearn==0.0
|
||||
test-tube==0.6282
|
||||
torch==1.0.0
|
||||
torchvision==0.2.1
|
||||
tqdm==4.28.1
|
||||
tensorboard==1.14.0
|
||||
tensorboardX==1.7
|
||||
tensorflow==1.14.0
|
||||
tensorflow-estimator==1.14.0
|
||||
termcolor==1.1.0
|
||||
test-tube==0.643
|
||||
tqdm==4.32.1
|
||||
twine==1.13.0
|
||||
urllib3==1.25.3
|
||||
webencodings==0.5.1
|
||||
Werkzeug==0.15.4
|
||||
wrapt==1.11.2
|
||||
|
|
Loading…
Reference in New Issue