refactor _plugin_job() with an ExitStack

Summary:
The number of context managers in that function is getting annoying
because of all the nested with-statements, and we're about to add
another one in the next diff for tmp dirs. The contextlib module
provides an ExitStack class to help deal with lots of context managers.
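For anyone unfamiliar with ExitStack, here is a rough standalone sketch
(not part of this change; the temporary directories are just stand-ins)
of how enter_context() flattens nested with-statements into one block:

    import contextlib
    import tempfile

    # Each enter_context() call registers one context manager on the stack.
    # Everything registered is cleaned up, in reverse order, when the
    # with-block exits, whether normally or via an exception.
    with contextlib.ExitStack() as stack:
        first = stack.enter_context(tempfile.TemporaryDirectory())
        second = stack.enter_context(tempfile.TemporaryDirectory())
        print('working in', first, 'and', second)
    # Both temporary directories have been deleted at this point.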

Reviewers: sean

Reviewed By: sean

Differential Revision: https://phabricator.buildinspace.com/D152
Jack O'Connor 2014-12-22 13:44:30 -08:00
parent ee7b9fb406
commit 8c64db06d8
1 changed file with 56 additions and 35 deletions


@@ -41,8 +41,8 @@ def plugin_fetch(plugin_context, module_type, module_fields, dest,
 @asyncio.coroutine
 def plugin_get_reup_fields(plugin_context, module_type, module_fields,
                            display_handle):
-    with tempfile.TemporaryDirectory(dir=plugin_context.tmp_root) as tmp_dir:
-        output_path = os.path.join(tmp_dir, 'reup_output')
+    with tmp_dir(plugin_context) as output_file_dir:
+        output_path = os.path.join(output_file_dir, 'reup_output')
         env = {'PERU_REUP_OUTPUT': output_path}
         yield from _plugin_job(
             plugin_context, module_type, module_fields, 'reup', env,
@@ -64,29 +64,30 @@ def plugin_get_reup_fields(plugin_context, module_type, module_fields,
 
 @asyncio.coroutine
 def _plugin_job(plugin_context, module_type, module_fields, command, env,
                 display_handle):
-    global DEBUG_PARALLEL_COUNT, DEBUG_PARALLEL_MAX
-
-    definition = _get_plugin_definition(module_type, module_fields, command)
-    exe = _get_plugin_exe(definition, command)
-    # For Windows to run scripts with the right interpreter, we need to use run
-    # as a shell command, rather than exec.
-    shell_command_line = subprocess.list2cmdline([exe])
-
-    complete_env = _plugin_env(definition, module_fields, command)
-    complete_env.update({
-        'PERU_PLUGIN_CACHE': _plugin_cache_path(
-            plugin_context, definition, module_fields)})
-    complete_env.update(env)
-
-    # Use a lock to protect the plugin cache. It would be unsafe for two jobs
-    # to read/write to the same plugin cache dir at the same time. The lock
-    # (and the cache dir) are both keyed off the module's "cache fields" as
-    # defined by plugin.yaml. For plugins that don't define cacheable fields,
-    # there is no cache dir (it's set to /dev/null) and the cache lock is a
-    # no-op.
-    cache_lock = _plugin_cache_lock(plugin_context, definition, module_fields)
-    with (yield from cache_lock):
+    # We take several locks and other context managers in here. Using an
+    # ExitStack saves us from indentation hell.
+    with contextlib.ExitStack() as stack:
+        definition = _get_plugin_definition(module_type, module_fields,
+                                            command)
+        exe = _get_plugin_exe(definition, command)
+        # For Windows to run scripts with the right interpreter, we need to run
+        # as a shell command, rather than exec.
+        shell_command_line = subprocess.list2cmdline([exe])
+
+        complete_env = _plugin_env(
+            plugin_context, definition, module_fields, command)
+        complete_env.update(env)
+
+        # Use a lock to protect the plugin cache. It would be unsafe for two
+        # jobs to read/write to the same plugin cache dir at the same time. The
+        # lock (and the cache dir) are both keyed off the module's "cache
+        # fields" as defined by plugin.yaml. For plugins that don't define
+        # cacheable fields, there is no cache dir (it's set to /dev/null) and
+        # the cache lock is a no-op.
+        stack.enter_context((yield from _plugin_cache_lock(
+            plugin_context, definition, module_fields)))
+
         # Use a semaphore to limit the number of jobs that can run in parallel.
         # Most plugin fetches hit the network, and for performance reasons we
         # don't want to fire off too many network requests at once. See
@@ -94,19 +95,20 @@ def _plugin_job(plugin_context, module_type, module_fields, command, env,
         # parallelism with the --jobs flag. It's important that this is the
         # last lock taken before starting a job, otherwise we might waste a job
         # slot just waiting on other locks.
-        with (yield from plugin_context.parallelism_semaphore):
-            DEBUG_PARALLEL_COUNT += 1
-            DEBUG_PARALLEL_MAX = max(DEBUG_PARALLEL_COUNT, DEBUG_PARALLEL_MAX)
-
-            try:
-                yield from create_subprocess_with_handle(
-                    shell_command_line, display_handle, cwd=plugin_context.cwd,
-                    env=complete_env, shell=True)
-            except subprocess.CalledProcessError as e:
-                raise PluginRuntimeError(module_type, module_fields,
-                                         e.returncode, e.output)
-            finally:
-                DEBUG_PARALLEL_COUNT -= 1
+        stack.enter_context((yield from plugin_context.parallelism_semaphore))
+
+        # We use this debug counter for our parallelism tests. It's important
+        # that it comes after all locks have been taken (so the job it's
+        # counting is actually running).
+        stack.enter_context(debug_parallel_count_context())
+
+        try:
+            yield from create_subprocess_with_handle(
+                shell_command_line, display_handle, cwd=plugin_context.cwd,
+                env=complete_env, shell=True)
+        except subprocess.CalledProcessError as e:
+            raise PluginRuntimeError(
+                module_type, module_fields, e.returncode, e.output)
 
 
 def _get_plugin_exe(definition, command):
@@ -149,12 +151,12 @@ def _validate_plugin_definition(definition, module_fields):
             'Unknown module fields: ' + ', '.join(unknown_module_fields))
 
 
-def _plugin_env(definition, module_fields, command):
+def _plugin_env(plugin_context, plugin_definition, module_fields, command):
     env = os.environ.copy()
 
     # First, blank out all module field vars. This prevents the calling
     # environment from leaking in when optional fields are undefined.
-    blank_module_vars = {field: '' for field in definition.fields}
+    blank_module_vars = {field: '' for field in plugin_definition.fields}
     env.update(_format_module_fields(blank_module_vars))
     # Then add in the fields that are actually defined.
     env.update(_format_module_fields(module_fields))
@@ -171,6 +173,10 @@ def _plugin_env(definition, module_fields, command):
     # name available in the environment.
     env['PERU_PLUGIN_COMMAND'] = command
 
+    # Create a persistent cache dir for saved files, like repo clones.
+    env['PERU_PLUGIN_CACHE'] = _plugin_cache_path(
+        plugin_context, plugin_definition, module_fields)
+
     return env
@@ -293,6 +299,21 @@ def debug_assert_clean_parallel_count():
         "parallel count should be 0 but it's " + str(DEBUG_PARALLEL_COUNT)
 
 
+@contextlib.contextmanager
+def debug_parallel_count_context():
+    global DEBUG_PARALLEL_COUNT, DEBUG_PARALLEL_MAX
+    DEBUG_PARALLEL_COUNT += 1
+    DEBUG_PARALLEL_MAX = max(DEBUG_PARALLEL_COUNT, DEBUG_PARALLEL_MAX)
+    try:
+        yield
+    finally:
+        DEBUG_PARALLEL_COUNT -= 1
+
+
+def tmp_dir(context):
+    return tempfile.TemporaryDirectory(dir=context.tmp_root)
+
+
 class PluginCandidateError(PrintableError):
     pass