diff --git a/peru/plugin.py b/peru/plugin.py
index 897ce5a..e98762d 100644
--- a/peru/plugin.py
+++ b/peru/plugin.py
@@ -41,8 +41,8 @@ def plugin_fetch(plugin_context, module_type, module_fields, dest,
 @asyncio.coroutine
 def plugin_get_reup_fields(plugin_context, module_type, module_fields,
                            display_handle):
-    with tempfile.TemporaryDirectory(dir=plugin_context.tmp_root) as tmp_dir:
-        output_path = os.path.join(tmp_dir, 'reup_output')
+    with tmp_dir(plugin_context) as output_file_dir:
+        output_path = os.path.join(output_file_dir, 'reup_output')
         env = {'PERU_REUP_OUTPUT': output_path}
         yield from _plugin_job(
             plugin_context, module_type, module_fields, 'reup', env,
@@ -64,29 +64,30 @@ def plugin_get_reup_fields(plugin_context, module_type, module_fields,
 @asyncio.coroutine
 def _plugin_job(plugin_context, module_type, module_fields, command, env,
                 display_handle):
-    global DEBUG_PARALLEL_COUNT, DEBUG_PARALLEL_MAX
+    # We take several locks and other context managers in here. Using an
+    # ExitStack saves us from indentation hell.
+    with contextlib.ExitStack() as stack:
+        definition = _get_plugin_definition(module_type, module_fields,
+                                            command)
+        exe = _get_plugin_exe(definition, command)
 
-    definition = _get_plugin_definition(module_type, module_fields, command)
+        # For Windows to run scripts with the right interpreter, we need to run
+        # as a shell command, rather than exec.
+        shell_command_line = subprocess.list2cmdline([exe])
 
-    exe = _get_plugin_exe(definition, command)
-    # For Windows to run scripts with the right interpreter, we need to use run
-    # as a shell command, rather than exec.
-    shell_command_line = subprocess.list2cmdline([exe])
+        complete_env = _plugin_env(
+            plugin_context, definition, module_fields, command)
+        complete_env.update(env)
 
-    complete_env = _plugin_env(definition, module_fields, command)
-    complete_env.update({
-        'PERU_PLUGIN_CACHE': _plugin_cache_path(
-            plugin_context, definition, module_fields)})
-    complete_env.update(env)
+        # Use a lock to protect the plugin cache. It would be unsafe for two
+        # jobs to read/write to the same plugin cache dir at the same time. The
+        # lock (and the cache dir) are both keyed off the module's "cache
+        # fields" as defined by plugin.yaml. For plugins that don't define
+        # cacheable fields, there is no cache dir (it's set to /dev/null) and
+        # the cache lock is a no-op.
+        stack.enter_context((yield from _plugin_cache_lock(
+            plugin_context, definition, module_fields)))
 
-    # Use a lock to protect the plugin cache. It would be unsafe for two jobs
-    # to read/write to the same plugin cache dir at the same time. The lock
-    # (and the cache dir) are both keyed off the module's "cache fields" as
-    # defined by plugin.yaml. For plugins that don't define cacheable fields,
-    # there is no cache dir (it's set to /dev/null) and the cache lock is a
-    # no-op.
-    cache_lock = _plugin_cache_lock(plugin_context, definition, module_fields)
-    with (yield from cache_lock):
         # Use a semaphore to limit the number of jobs that can run in parallel.
         # Most plugin fetches hit the network, and for performance reasons we
         # don't want to fire off too many network requests at once. See
@@ -94,19 +95,20 @@ def _plugin_job(plugin_context, module_type, module_fields, command, env,
         # parallelism with the --jobs flag. It's important that this is the
         # last lock taken before starting a job, otherwise we might waste a job
         # slot just waiting on other locks.
-        with (yield from plugin_context.parallelism_semaphore):
-            DEBUG_PARALLEL_COUNT += 1
-            DEBUG_PARALLEL_MAX = max(DEBUG_PARALLEL_COUNT, DEBUG_PARALLEL_MAX)
+        stack.enter_context((yield from plugin_context.parallelism_semaphore))
 
-            try:
-                yield from create_subprocess_with_handle(
-                    shell_command_line, display_handle, cwd=plugin_context.cwd,
-                    env=complete_env, shell=True)
-            except subprocess.CalledProcessError as e:
-                raise PluginRuntimeError(module_type, module_fields,
-                                         e.returncode, e.output)
-            finally:
-                DEBUG_PARALLEL_COUNT -= 1
+        # We use this debug counter for our parallelism tests. It's important
+        # that it comes after all locks have been taken (so the job it's
+        # counting is actually running).
+        stack.enter_context(debug_parallel_count_context())
+
+        try:
+            yield from create_subprocess_with_handle(
+                shell_command_line, display_handle, cwd=plugin_context.cwd,
+                env=complete_env, shell=True)
+        except subprocess.CalledProcessError as e:
+            raise PluginRuntimeError(
+                module_type, module_fields, e.returncode, e.output)
 
 
 def _get_plugin_exe(definition, command):
@@ -149,12 +151,12 @@ def _validate_plugin_definition(definition, module_fields):
             'Unknown module fields: ' + ', '.join(unknown_module_fields))
 
 
-def _plugin_env(definition, module_fields, command):
+def _plugin_env(plugin_context, plugin_definition, module_fields, command):
     env = os.environ.copy()
 
     # First, blank out all module field vars. This prevents the calling
     # environment from leaking in when optional fields are undefined.
-    blank_module_vars = {field: '' for field in definition.fields}
+    blank_module_vars = {field: '' for field in plugin_definition.fields}
     env.update(_format_module_fields(blank_module_vars))
     # Then add in the fields that are actually defined.
     env.update(_format_module_fields(module_fields))
@@ -171,6 +173,10 @@ def _plugin_env(definition, module_fields, command):
     # name available in the environment.
     env['PERU_PLUGIN_COMMAND'] = command
 
+    # Create a persistent cache dir for saved files, like repo clones.
+    env['PERU_PLUGIN_CACHE'] = _plugin_cache_path(
+        plugin_context, plugin_definition, module_fields)
+
     return env
 
 
@@ -293,6 +299,21 @@ def debug_assert_clean_parallel_count():
         "parallel count should be 0 but it's " + str(DEBUG_PARALLEL_COUNT)
 
 
+@contextlib.contextmanager
+def debug_parallel_count_context():
+    global DEBUG_PARALLEL_COUNT, DEBUG_PARALLEL_MAX
+    DEBUG_PARALLEL_COUNT += 1
+    DEBUG_PARALLEL_MAX = max(DEBUG_PARALLEL_COUNT, DEBUG_PARALLEL_MAX)
+    try:
+        yield
+    finally:
+        DEBUG_PARALLEL_COUNT -= 1
+
+
+def tmp_dir(context):
+    return tempfile.TemporaryDirectory(dir=context.tmp_root)
+
+
 class PluginCandidateError(PrintableError):
     pass
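
A note on the core trick in this patch: contextlib.ExitStack lets us acquire a
variable number of context managers at a single indentation level, and it
unwinds them in reverse order when the with block exits, exactly as nested
with statements would. Here's a minimal, self-contained sketch of the pattern;
the noisy() context manager is a throwaway stand-in, not one of peru's locks:

    import contextlib

    @contextlib.contextmanager
    def noisy(name):
        # Stand-in for a lock or semaphore context manager.
        print('enter', name)
        try:
            yield
        finally:
            print('exit', name)

    with contextlib.ExitStack() as stack:
        stack.enter_context(noisy('cache lock'))
        stack.enter_context(noisy('semaphore'))
        stack.enter_context(noisy('debug counter'))
        print('job runs here')
    # Prints the enters in order, then the exits in reverse:
    # debug counter, semaphore, cache lock.

One wrinkle specific to asyncio: ExitStack.enter_context can't await anything,
which is why the patch acquires each lock first with yield from and only hands
the resulting, already-acquired context manager to the stack.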
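
The debug counter exists so the parallelism tests can assert that the
semaphore limit actually holds while jobs run. Here's a sketch of that kind of
check, written in modern async/await syntax rather than the yield from style
above; the names are illustrative, not peru's real test code:

    import asyncio
    import contextlib

    COUNT = 0
    MAX = 0

    @contextlib.contextmanager
    def count_running():
        # Mirrors debug_parallel_count_context(): track concurrent entries.
        global COUNT, MAX
        COUNT += 1
        MAX = max(COUNT, MAX)
        try:
            yield
        finally:
            COUNT -= 1

    async def fake_job(semaphore):
        async with semaphore:
            with count_running():
                await asyncio.sleep(0.01)  # stand-in for the plugin subprocess

    async def main(jobs=20, limit=4):
        sem = asyncio.Semaphore(limit)
        await asyncio.gather(*(fake_job(sem) for _ in range(jobs)))
        # The recorded maximum proves the semaphore capped concurrency.
        assert COUNT == 0 and MAX <= limit, (COUNT, MAX)

    asyncio.run(main())

Because the counter is entered only after every lock is held, its maximum
reflects jobs that were genuinely running, which is exactly why the patch
insists that it come last in the stack.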