fix warning (#3800)
This commit is contained in:
parent
0c12065efd
commit
22efce8f40
|
@ -206,26 +206,16 @@ class LightningDistributedDataParallel(DistributedDataParallel):
|
||||||
self.reducer.prepare_for_backward([])
|
self.reducer.prepare_for_backward([])
|
||||||
|
|
||||||
if output is None:
|
if output is None:
|
||||||
warn_missing_output(fx_called)
|
warn_missing_output(f'{fx_called} returned None. Did you forget to re')
|
||||||
|
|
||||||
m = f'{fx_called} returned None. Did you forget to re'
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def warn_missing_output(fx_called):
|
def warn_missing_output(fx_called):
|
||||||
if fx_called == 'training_step':
|
if fx_called == 'training_step':
|
||||||
m = """
|
warning_cache.warn("Your training_step returned None. You should instead do:\n"
|
||||||
Your training_step returned None. You should instead do:
|
"`return loss`\n or\n `return TrainResult`")
|
||||||
return loss
|
|
||||||
or
|
|
||||||
return TrainResult
|
|
||||||
"""
|
|
||||||
elif fx_called in ['validation_step', 'test_step']:
|
elif fx_called in ['validation_step', 'test_step']:
|
||||||
m = f"""
|
warning_cache.warn(f"Your {fx_called} returned None. You should instead do:\n `return EvalResult")
|
||||||
Your {fx_called} returned None. You should instead do:
|
|
||||||
return EvalResult
|
|
||||||
"""
|
|
||||||
warning_cache.warn(m)
|
|
||||||
|
|
||||||
|
|
||||||
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): # pragma: no-cover
|
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): # pragma: no-cover
|
||||||
|
|
Loading…
Reference in New Issue