diff --git a/peru/imports.py b/peru/imports.py index dd6ff99..4246ebd 100644 --- a/peru/imports.py +++ b/peru/imports.py @@ -7,12 +7,8 @@ from .merge import merge_imports_tree @asyncio.coroutine -def checkout(runtime, *, path=None, imports=None): - if path is None: - path = runtime.root - if imports is None: - imports = runtime.imports - target_trees = yield from get_trees(runtime, runtime.imports.targets) +def checkout(runtime, scope, imports, path): + target_trees = yield from get_trees(runtime, scope, imports.targets) imports_tree = merge_imports_tree(runtime.cache, imports, target_trees) last_imports_tree = _get_last_imports(runtime) runtime.cache.export_tree(imports_tree, path, last_imports_tree, @@ -21,15 +17,15 @@ def checkout(runtime, *, path=None, imports=None): @asyncio.coroutine -def get_trees(runtime, targets): - futures = [get_tree(runtime, target) for target in targets] +def get_trees(runtime, scope, targets): + futures = [get_tree(runtime, scope, target) for target in targets] trees = yield from stable_gather(*futures) return dict(zip(targets, trees)) @asyncio.coroutine -def get_tree(runtime, target_str): - module, rules = _parse_target(runtime, target_str) +def get_tree(runtime, scope, target_str): + module, rules = yield from scope.parse_target(runtime, target_str) tree = yield from module.get_tree(runtime) if module.default_rule: tree = yield from module.default_rule.get_tree(runtime, tree) @@ -38,13 +34,6 @@ def get_tree(runtime, target_str): return tree -def _parse_target(runtime, target_str): - module_name, *rule_names = target_str.split('|') - module = runtime.get_modules([module_name])[0] - rules = runtime.get_rules(rule_names) - return module, rules - - def _last_imports_path(runtime): return os.path.join(runtime.peru_dir, 'lastimports') diff --git a/peru/main.py b/peru/main.py index 175b4a5..c678eee 100644 --- a/peru/main.py +++ b/peru/main.py @@ -11,7 +11,7 @@ from . import async from . import compat from .error import PrintableError from . import imports -from .parser import build_imports +from .parser import parse_file, build_imports from .runtime import Runtime __doc__ = """\ @@ -75,6 +75,7 @@ class Main: matching_command = find_matching_command(self.args) if matching_command: self.runtime = Runtime(self.args, env) + self.scope, self.imports = parse_file(self.runtime.peru_file) async.run_task(matching_command(self)) else: if self.args["--version"]: @@ -85,14 +86,16 @@ class Main: @command("sync") def do_sync(self): - yield from imports.checkout(self.runtime) + yield from imports.checkout( + self.runtime, self.scope, self.imports, self.runtime.root) @command('reup') def do_reup(self): - if not self.args['']: - modules = self.runtime.modules.values() + names = self.args[''] + if not names: + modules = self.scope.modules.values() else: - modules = self.runtime.get_modules(self.args['']) + modules = self.scope.get_modules_for_reup(names) futures = [module.reup(self.runtime) for module in modules] yield from async.stable_gather(*futures) if not self.args['--nosync']: @@ -121,7 +124,7 @@ class Main: else: dest = self.args[''] tree = yield from imports.get_tree( - self.runtime, self.args['']) + self.runtime, self.scope, self.args['']) self.runtime.cache.export_tree(tree, dest, force=self.runtime.force) if not self.args['']: print(dest) @@ -129,7 +132,8 @@ class Main: @command('clean') def do_clean(self): empty_imports = build_imports({}) - yield from imports.checkout(self.runtime, imports=empty_imports) + yield from imports.checkout( + self.runtime, self.scope, empty_imports, self.runtime.root) def get_version(): diff --git a/peru/parser.py b/peru/parser.py index e81f1d7..ba97d7d 100644 --- a/peru/parser.py +++ b/peru/parser.py @@ -6,6 +6,7 @@ import yaml from .error import PrintableError from .module import Module from .rule import Rule +from .scope import Scope DEFAULT_PERU_FILE_NAME = 'peru.yaml' @@ -15,10 +16,6 @@ class ParserError(PrintableError): pass -ParseResult = collections.namedtuple( - "ParseResult", ["modules", "rules", "imports"]) - - def parse_file(file_path, name_prefix=""): with open(file_path) as f: return parse_string(f.read(), name_prefix) @@ -41,7 +38,7 @@ def _parse_toplevel(blob, name_prefix): if blob: raise ParserError("Unknown toplevel fields: " + ", ".join(blob.keys())) - return ParseResult(modules, rules, imports) + return Scope(modules, rules), imports def _extract_named_rules(blob, name_prefix): diff --git a/peru/runtime.py b/peru/runtime.py index 1537d1f..0f55319 100644 --- a/peru/runtime.py +++ b/peru/runtime.py @@ -24,9 +24,6 @@ class Runtime: 'PERU_DIR', os.path.join(self.root, '.peru')) compat.makedirs(self.peru_dir) - self.modules, self.rules, self.imports = \ - parser.parse_file(self.peru_file) - cache_dir = env.get('PERU_CACHE', os.path.join(self.peru_dir, 'cache')) self.cache = cache.Cache(cache_dir) @@ -91,22 +88,6 @@ class Runtime: plugin_cache_locks=self.plugin_cache_locks, tmp_root=self._tmp_root) - def get_rules(self, rule_names): - rules = [] - for name in rule_names: - if name not in self.rules: - raise PrintableError('rule "{}" does not exist'.format(name)) - rules.append(self.rules[name]) - return rules - - def get_modules(self, names): - modules = [] - for name in names: - if name not in self.modules: - raise PrintableError('module "{}" does not exist'.format(name)) - modules.append(self.modules[name]) - return modules - def find_peru_file(start_dir, name): '''Walk up the directory tree until we find a file of the given name.''' diff --git a/peru/scope.py b/peru/scope.py new file mode 100644 index 0000000..fed3809 --- /dev/null +++ b/peru/scope.py @@ -0,0 +1,92 @@ +import asyncio + +from .error import PrintableError + + +SCOPE_SEPARATOR = '.' +RULE_SEPARATOR = '|' + + +class Scope: + '''A Scope holds the elements that are parsed out of a single peru.yaml + file. This is kept separate from a Runtime, because recursive modules need + to work with a Scope that makes sense to them, rather than a single global + scope.''' + + def __init__(self, modules, rules): + self.modules = modules + self.rules = rules + + @asyncio.coroutine + def parse_target(self, runtime, target_str): + '''A target is a pipeline of a module into zero or more rules, and each + module and rule can itself be scoped with zero or more module names.''' + pipeline_parts = target_str.split(RULE_SEPARATOR) + module = yield from self.resolve_module( + runtime, pipeline_parts[0], target_str) + rules = tuple((yield from self.resolve_rule(runtime, part)) + for part in pipeline_parts[1:]) + return module, rules + + @asyncio.coroutine + def resolve_module(self, runtime, module_str, logging_target_name=None): + logging_target_name = logging_target_name or module_str + module_names = module_str.split(SCOPE_SEPARATOR) + return (yield from self._resolve_module_from_names( + runtime, module_names, logging_target_name)) + + @asyncio.coroutine + def _resolve_module_from_names(self, runtime, module_names, + logging_target_name): + next_module = self.modules[module_names[0]] + for name in module_names[1:]: + next_scope = yield from _get_scope_or_fail( + runtime, logging_target_name, next_module) + if name not in next_scope.modules: + _error(logging_target_name, 'module {} not found in {}', name, + next_module.name) + next_module = next_scope.modules[name] + return next_module + + @asyncio.coroutine + def resolve_rule(self, runtime, rule_str, logging_target_name=None): + logging_target_name = logging_target_name or rule_str + *module_names, rule_name = rule_str.split(SCOPE_SEPARATOR) + scope = self + location_str = '' + if module_names: + module = yield from self._resolve_module_from_names( + runtime, module_names, logging_target_name) + scope = yield from _get_scope_or_fail( + runtime, logging_target_name, module) + location_str = ' in module ' + module.name + if rule_name not in scope.rules: + _error(logging_target_name, 'rule {} not found{}', rule_name, + location_str) + return scope.rules[rule_name] + + def get_modules_for_reup(self, names): + for name in names: + if SCOPE_SEPARATOR in name: + raise PrintableError( + 'Can\'t reup module "{}"; it belongs to another project.' + .format(name)) + if name not in self.modules: + raise PrintableError( + 'Module "{}" isn\'t defined.'.format(name)) + return [self.modules[name] for name in names] + + +@asyncio.coroutine +def _get_scope_or_fail(runtime, logging_target_name, module): + scope, imports = yield from module.parse_peru_file(runtime) + if not scope: + _error(logging_target_name, 'module {} is not a peru project', + module.name) + return scope + + +def _error(logging_target_name, text, *text_format_args): + text = text.format(*text_format_args) + raise PrintableError('Error in target {}: {}'.format( + logging_target_name, text)) diff --git a/tests/test_parser.py b/tests/test_parser.py index 8efe8f5..3ba717e 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -9,19 +9,19 @@ from peru.rule import Rule class ParserTest(unittest.TestCase): def test_parse_empty_file(self): - result = parse_string('') - self.assertDictEqual(result.modules, {}) - self.assertDictEqual(result.rules, {}) - self.assertEqual(result.imports, build_imports({})) + scope, imports = parse_string('') + self.assertDictEqual(scope.modules, {}) + self.assertDictEqual(scope.rules, {}) + self.assertEqual(imports, build_imports({})) def test_parse_rule(self): input = dedent("""\ rule foo: export: out/ """) - result = parse_string(input) - self.assertIn("foo", result.rules) - rule = result.rules["foo"] + scope, imports = parse_string(input) + self.assertIn("foo", scope.rules) + rule = scope.rules["foo"] self.assertIsInstance(rule, Rule) self.assertEqual(rule.name, "foo") self.assertEqual(rule.export, "out/") @@ -32,9 +32,9 @@ class ParserTest(unittest.TestCase): url: http://www.example.com/ rev: abcdefg """) - result = parse_string(input) - self.assertIn("foo", result.modules) - module = result.modules["foo"] + scope, imports = parse_string(input) + self.assertIn("foo", scope.modules) + module = scope.modules["foo"] self.assertIsInstance(module, Module) self.assertEqual(module.name, "foo") self.assertEqual(module.type, "sometype") @@ -47,9 +47,9 @@ class ParserTest(unittest.TestCase): git module bar: export: bar """) - result = parse_string(input) - self.assertIn("bar", result.modules) - module = result.modules["bar"] + scope, imports = parse_string(input) + self.assertIn("bar", scope.modules) + module = scope.modules["bar"] self.assertIsInstance(module, Module) self.assertIsInstance(module.default_rule, Rule) self.assertEqual(module.default_rule.export, "bar") @@ -59,29 +59,29 @@ class ParserTest(unittest.TestCase): imports: foo: bar/ """) - result = parse_string(input) - self.assertDictEqual(result.modules, {}) - self.assertDictEqual(result.rules, {}) - self.assertEqual(result.imports, build_imports({'foo': 'bar/'})) + scope, imports = parse_string(input) + self.assertDictEqual(scope.modules, {}) + self.assertDictEqual(scope.rules, {}) + self.assertEqual(imports, build_imports({'foo': 'bar/'})) def test_parse_list_imports(self): input = dedent('''\ imports: - foo: bar/ ''') - result = parse_string(input) - self.assertDictEqual(result.modules, {}) - self.assertDictEqual(result.rules, {}) - self.assertEqual(result.imports, build_imports({'foo': 'bar/'})) + scope, imports = parse_string(input) + self.assertDictEqual(scope.modules, {}) + self.assertDictEqual(scope.rules, {}) + self.assertEqual(imports, build_imports({'foo': 'bar/'})) def test_parse_empty_imports(self): input = dedent('''\ imports: ''') - result = parse_string(input) - self.assertDictEqual(result.modules, {}) - self.assertDictEqual(result.rules, {}) - self.assertEqual(result.imports, build_imports({})) + scope, imports = parse_string(input) + self.assertDictEqual(scope.modules, {}) + self.assertDictEqual(scope.rules, {}) + self.assertEqual(imports, build_imports({})) def test_parse_wrong_type_imports_throw(self): with self.assertRaises(ParserError): @@ -178,8 +178,8 @@ class ParserTest(unittest.TestCase): rule bar: export: more stuff ''') - result = parse_string(input, name_prefix='x') + scope, imports = parse_string(input, name_prefix='x') # Lookup keys should be unaffected, but the names that modules and # rules give for themselves should have the prefix. - assert result.modules['foo'].name == 'xfoo' - assert result.rules['bar'].name == 'xbar' + assert scope.modules['foo'].name == 'xfoo' + assert scope.rules['bar'].name == 'xbar' diff --git a/tests/test_scope.py b/tests/test_scope.py new file mode 100644 index 0000000..24cfe77 --- /dev/null +++ b/tests/test_scope.py @@ -0,0 +1,75 @@ +import asyncio +import unittest + +from peru.async import run_task +import peru.scope + + +class ScopeTest(unittest.TestCase): + def test_parse_target(self): + scope = scope_tree_to_scope({ + 'modules': { + 'a': { + 'modules': { + 'b': { + 'modules': {'c': {}}, + 'rules': ['r'], + } + } + } + } + }) + c, (r,) = run_task(scope.parse_target(DummyRuntime(), 'a.b.c|a.b.r')) + assert type(c) is DummyModule and c.name == 'a.b.c' + assert type(r) is DummyRule and r.name == 'a.b.r' + + +def scope_tree_to_scope(tree, prefix=""): + '''This function is for generating dummy scope/module/rule hierarchies for + testing. A scope tree contains a modules dictionary and a rules list, both + optional. The values of the modules dictionary are themselves scope trees. + So if module A contains module B and rule R, that's represented as: + + { + 'modules': { + 'A': { + 'modules': { + 'B': {}, + }, + 'rules': ['R'], + } + } + } + ''' + modules = {} + if 'modules' in tree: + for module_name, sub_tree in tree['modules'].items(): + full_name = prefix + module_name + new_prefix = full_name + peru.scope.SCOPE_SEPARATOR + module_scope = scope_tree_to_scope(sub_tree, new_prefix) + modules[module_name] = DummyModule(full_name, module_scope) + rules = {} + if 'rules' in tree: + for rule_name in tree['rules']: + full_name = prefix + rule_name + rules[rule_name] = DummyRule(full_name) + return peru.scope.Scope(modules, rules) + + +class DummyModule: + def __init__(self, name, scope): + self.name = name + self.scope = scope + + @asyncio.coroutine + def parse_peru_file(self, dummy_runtime): + return self.scope, None + + +class DummyRule: + def __init__(self, name): + self.name = name + + +class DummyRuntime: + pass