diff --git a/infra/helper.py b/infra/helper.py index 6048d9771..6c2950ab4 100755 --- a/infra/helper.py +++ b/infra/helper.py @@ -64,12 +64,56 @@ LANGUAGES_WITH_COVERAGE_SUPPORT = ['c', 'c++', 'go', 'rust'] # pylint: disable=too-many-lines -def main(): # pylint: disable=too-many-branches,too-many-return-statements,too-many-statements +def main(): # pylint: disable=too-many-branches,too-many-return-statements """Get subcommand from program arguments and do it.""" os.chdir(OSS_FUZZ_DIR) if not os.path.exists(BUILD_DIR): os.mkdir(BUILD_DIR) + args = parse_args() + + # We have different default values for `sanitizer` depending on the `engine`. + # Some commands do not have `sanitizer` argument, so `hasattr` is necessary. + if hasattr(args, 'sanitizer') and not args.sanitizer: + if args.engine == 'dataflow': + args.sanitizer = 'dataflow' + else: + args.sanitizer = 'address' + + if args.command == 'generate': + return generate(args) + if args.command == 'build_image': + return build_image(args) + if args.command == 'build_fuzzers': + return build_fuzzers(args) + if args.command == 'check_build': + return check_build(args) + if args.command == 'download_corpora': + return download_corpora(args) + if args.command == 'run_fuzzer': + return run_fuzzer(args) + if args.command == 'coverage': + return coverage(args) + if args.command == 'reproduce': + return reproduce(args) + if args.command == 'shell': + return shell(args) + if args.command == 'pull_images': + return pull_images(args) + + return 0 + + +def parse_args(args=None): + """Parses args using argparser and returns parsed args.""" + # Use default argument None for args so that in production, argparse does its + # normal behavior, but unittesting is easier. + parser = get_parser() + return parser.parse_args(args) + + +def get_parser(): # pylint: disable=too-many-statements + """Returns an argparse parser.""" parser = argparse.ArgumentParser('helper.py', description='oss-fuzz helpers') subparsers = parser.add_subparsers(dest='command') @@ -192,39 +236,7 @@ def main(): # pylint: disable=too-many-branches,too-many-return-statements,too- _add_environment_args(shell_parser) subparsers.add_parser('pull_images', help='Pull base images.') - - args = parser.parse_args() - - # We have different default values for `sanitizer` depending on the `engine`. - # Some commands do not have `sanitizer` argument, so `hasattr` is necessary. - if hasattr(args, 'sanitizer') and not args.sanitizer: - if args.engine == 'dataflow': - args.sanitizer = 'dataflow' - else: - args.sanitizer = 'address' - - if args.command == 'generate': - return generate(args) - if args.command == 'build_image': - return build_image(args) - if args.command == 'build_fuzzers': - return build_fuzzers(args) - if args.command == 'check_build': - return check_build(args) - if args.command == 'download_corpora': - return download_corpora(args) - if args.command == 'run_fuzzer': - return run_fuzzer(args) - if args.command == 'coverage': - return coverage(args) - if args.command == 'reproduce': - return reproduce(args) - if args.command == 'shell': - return shell(args) - if args.command == 'pull_images': - return pull_images(args) - - return 0 + return parser def is_base_image(image_name): diff --git a/infra/helper_test.py b/infra/helper_test.py new file mode 100644 index 000000000..d899a835b --- /dev/null +++ b/infra/helper_test.py @@ -0,0 +1,35 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for helper.py""" + +import unittest +from unittest import mock + +import helper + + +class TestShell(unittest.TestCase): + """Tests 'shell' command.""" + + @mock.patch('helper.docker_run') + @mock.patch('helper.build_image_impl') + def test_base_runner_debug(self, mocked_build_image_impl, _): + """Tests that shell base-runner-debug works as intended.""" + image_name = 'base-runner-debug' + unparsed_args = ['shell', image_name] + args = helper.parse_args(unparsed_args) + args.sanitizer = 'address' + result = helper.shell(args) + mocked_build_image_impl.assert_called_with(image_name) + self.assertEqual(result, 0)