diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 9e9f641409..657b35e9fe 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -317,7 +317,7 @@ class LightningDataModule(object, metaclass=_DataModuleWrapper): # pragma: no c depr_arg_names = blacklist + added_args depr_arg_names = set(depr_arg_names) - allowed_types = (str, float, int, bool) + allowed_types = (str, int, float, bool) # TODO: get "help" from docstring :) for arg, arg_types, arg_default in ( diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8b45d10986..7fc8b93b4d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -736,7 +736,7 @@ class Trainer( blacklist = ['kwargs'] depr_arg_names = cls.get_deprecated_arg_names() + blacklist - allowed_types = (str, float, int, bool) + allowed_types = (str, int, float, bool) # TODO: get "help" from docstring :) for arg, arg_types, arg_default in (