diff --git a/infra/base-images/base-builder/detect_repo.py b/infra/base-images/base-builder/detect_repo.py index 8969e974f..e677e1023 100644 --- a/infra/base-images/base-builder/detect_repo.py +++ b/infra/base-images/base-builder/detect_repo.py @@ -107,20 +107,25 @@ def get_repo(repo_path): return None -def check_for_repo_name(repo_path, repo_name): - """Check to see if the repo_name matches the remote repository repo name. +def check_for_repo_name(repo_path, expected_repo_name): + """Returns True if the repo at |repo_path| repo_name matches + |expected_repo_name|. Args: - repo_path: The directory of the git repo. - repo_name: The name of the target git repo. + repo_path: The directory of a git repo. + expected_repo_name: The name of the target git repo. """ if not os.path.exists(os.path.join(repo_path, '.git')): return False - out, _ = execute(['git', 'config', '--get', 'remote.origin.url'], - location=repo_path) - out = out.split('/')[-1].replace('.git', '').rstrip() - return out == repo_name + repo_url, _ = execute(['git', 'config', '--get', 'remote.origin.url'], + location=repo_path) + # Handle two common cases: + # https://github.com/google/syzkaller/ + # https://github.com/google/syzkaller.git + repo_url = repo_url.replace('.git', '').rstrip().rstrip('/') + actual_repo_name = repo_url.split('/')[-1] + return actual_repo_name == expected_repo_name def check_for_commit(repo_path, commit): diff --git a/infra/base-images/base-builder/detect_repo_test.py b/infra/base-images/base-builder/detect_repo_test.py index 21f64af44..0243b3ac5 100644 --- a/infra/base-images/base-builder/detect_repo_test.py +++ b/infra/base-images/base-builder/detect_repo_test.py @@ -23,6 +23,7 @@ import re import sys import tempfile import unittest +from unittest import mock import detect_repo @@ -36,6 +37,33 @@ import test_repos # pylint: enable=wrong-import-position +class TestCheckForRepoName(unittest.TestCase): + """Tests for check_for_repo_name.""" + + @mock.patch('os.path.exists', return_value=True) + @mock.patch('detect_repo.execute', + return_value=('https://github.com/google/syzkaller/', None)) + def test_go_get_style_url(self, _, __): + """Tests that check_for_repo_name works on repos that were downloaded using + go get.""" + self.assertTrue(detect_repo.check_for_repo_name('fake-path', 'syzkaller')) + + @mock.patch('os.path.exists', return_value=True) + @mock.patch('detect_repo.execute', + return_value=('https://github.com/google/syzkaller', None)) + def test_missing_git_and_slash_url(self, _, __): + """Tests that check_for_repo_name works on repos who's URLs do not end in + ".git" or "/".""" + self.assertTrue(detect_repo.check_for_repo_name('fake-path', 'syzkaller')) + + @mock.patch('os.path.exists', return_value=True) + @mock.patch('detect_repo.execute', + return_value=('https://github.com/google/syzkaller.git', None)) + def test_normal_style_repo_url(self, _, __): + """Tests that check_for_repo_name works on normally cloned repos.""" + self.assertTrue(detect_repo.check_for_repo_name('fake-path', 'syzkaller')) + + @unittest.skipIf(not os.getenv('INTEGRATION_TESTS'), 'INTEGRATION_TESTS=1 not set') class DetectRepoIntegrationTest(unittest.TestCase):