# -*- coding: utf-8 -*-
# Mirror of https://github.com/pret/pokecrystal.git (Python, 1016 lines, 37 KiB)
# -*- coding: utf-8 -*-
|
|
|
|
import hashlib
import inspect
import json
import os
import random
import sys
from copy import copy

from interval_map import IntervalMap
from chars import chars, jap_chars
from romstr import (
    RomStr,
    AsmList,
)
from item_constants import (
    item_constants,
    find_item_label_by_id,
    generate_item_constants,
)
from pointers import (
    calculate_bank,
    calculate_pointer,
)
from pksv import (
    pksv_gs,
    pksv_crystal,
)
from labels import (
    remove_quoted_text,
    line_has_comment_address,
    line_has_label,
    get_label_from_line,
)
# The crystal module itself is imported so tests can read and patch its
# module-level globals (rom, asm, all_labels, ...): functions imported from
# crystal resolve globals in crystal's namespace, not in this module's.
import crystal
from crystal import (
    rom,
    load_rom,
    rom_until,
    direct_load_rom,
    parse_script_engine_script_at,
    parse_text_engine_script_at,
    parse_text_at2,
    find_all_text_pointers_in_script_engine_script,
    SingleByteParam,
    HexByte,
    MultiByteParam,
    PointerLabelParam,
    ItemLabelByte,
    DollarSignByte,
    DecimalParam,
    rom_interval,
    map_names,
    Label,
    scan_for_predefined_labels,
    all_labels,
    write_all_labels,
    parse_map_header_at,
    old_parse_map_header_at,
    process_00_subcommands,
    parse_all_map_headers,
    translate_command_byte,
    map_name_cleaner,
    load_map_group_offsets,
    load_asm,
    asm,
    is_valid_address,
    index,
    how_many_until,
    grouper,
    get_pokemon_constant_by_id,
    generate_map_constant_labels,
    get_map_constant_label_by_id,
    get_id_for_map_constant_label,
    calculate_pointer_from_bytes_at,
    isolate_incbins,
    process_incbins,
    get_labels_between,
    generate_diff_insert,
    find_labels_without_addresses,
    rom_text_at,
    get_label_for,
    split_incbin_line_into_three,
    reset_incbins,
)
|
|
|
|
# for testing all this crap
try:
    import unittest2 as unittest
except ImportError:
    # fall back to the standard library module
    import unittest

# Check for things we need in unittest: the test classes below use
# class-level fixtures (setUpClass / tearDownClass).
if not hasattr(unittest.TestCase, 'setUpClass'):
    # Trailing newline so the message does not run into the shell prompt.
    sys.stderr.write("The unittest2 module or Python 2.7 is required to run this script.\n")
    sys.exit(1)
|
|
|
|
class TestCram(unittest.TestCase):
    "this is where i cram all of my unit tests together"

    @classmethod
    def setUpClass(cls):
        """Load the ROM once and share it with every test in this class."""
        global rom
        cls.rom = direct_load_rom()
        # This rebinds only this module's ``rom`` name; crystal's own
        # module-level ``rom`` is a separate binding left untouched here.
        rom = cls.rom

    @classmethod
    def tearDownClass(cls):
        del cls.rom

    def test_generic_useless(self):
        "do i know how to write a test?"
        self.assertEqual(1, 1)

    def test_map_name_cleaner(self):
        name = "hello world"
        cleaned_name = map_name_cleaner(name)
        self.assertNotEqual(name, cleaned_name)
        self.assertNotIn(" ", cleaned_name)
        name = "Some Random Pokémon Center"
        cleaned_name = map_name_cleaner(name)
        self.assertNotEqual(name, cleaned_name)
        self.assertNotIn(" ", cleaned_name)
        self.assertNotIn("é", cleaned_name)

    def test_grouper(self):
        data = range(0, 10)
        groups = grouper(data, count=2)
        self.assertEqual(len(groups), 5)
        data = range(0, 20)
        groups = grouper(data, count=2)
        self.assertEqual(len(groups), 10)
        self.assertNotEqual(data, groups)
        self.assertNotEqual(len(data), len(groups))

    def test_direct_load_rom(self):
        rom = self.rom
        self.assertEqual(len(rom), 2097152)
        self.assertIsInstance(rom, RomStr)

    def test_load_rom(self):
        # load_rom() is expected to (re)bind crystal's module-level ``rom``.
        # The ``rom`` imported into this module is a separate binding that
        # load_rom() never updates, so the checks must look at crystal.rom.
        crystal.rom = None
        load_rom()
        self.assertIsNotNone(crystal.rom)
        crystal.rom = RomStr(None)
        load_rom()
        self.assertFalse(crystal.rom == RomStr(None))

    def test_load_asm(self):
        asm = load_asm()
        joined_lines = "\n".join(asm)
        # assertTrue rather than assertIn: a failure message would otherwise
        # dump the entire assembly source.
        self.assertTrue("SECTION" in joined_lines)
        self.assertTrue("bank" in joined_lines)
        self.assertIsInstance(asm, AsmList)

    def test_rom_file_existence(self):
        "ROM file must exist"
        self.assertIn("baserom.gbc", os.listdir("../"))

    def test_rom_md5(self):
        "ROM file must have the correct md5 sum"
        rom = self.rom
        correct = "9f2922b235a5eeb78d65594e82ef5dde"
        md5 = hashlib.md5()
        md5.update(rom)
        md5sum = md5.hexdigest()
        self.assertEqual(md5sum, correct)

    def test_bizarre_http_presence(self):
        rom_segment = self.rom[0x112116:0x112116+8]
        self.assertEqual(rom_segment, "HTTP/1.0")

    def test_rom_interval(self):
        address = 0x100
        interval = 10
        correct_strings = ['0x0', '0xc3', '0x6e', '0x1', '0xce',
                           '0xed', '0x66', '0x66', '0xcc', '0xd']
        byte_strings = rom_interval(address, interval, strings=True)
        self.assertEqual(byte_strings, correct_strings)
        correct_ints = [0, 195, 110, 1, 206, 237, 102, 102, 204, 13]
        ints = rom_interval(address, interval, strings=False)
        self.assertEqual(ints, correct_ints)

    def test_rom_until(self):
        address = 0x1337
        byte = 0x13
        # ``result`` rather than ``bytes`` to avoid shadowing the builtin
        result = rom_until(address, byte, strings=True)
        self.assertEqual(len(result), 3)
        self.assertEqual(result[0], '0xd5')
        result = rom_until(address, byte, strings=False)
        self.assertEqual(len(result), 3)
        self.assertEqual(result[0], 0xd5)

    def test_how_many_until(self):
        how_many = how_many_until(chr(0x13), 0x1337)
        self.assertEqual(how_many, 3)

    def test_calculate_bank(self):
        self.assertEqual(calculate_bank(0x8000), 2)
        self.assertEqual(calculate_bank("0x9000"), 2)
        self.assertEqual(calculate_bank(0), 0)
        for address in [0x4000, 0x5000, 0x6000, 0x7000]:
            self.assertRaises(Exception, calculate_bank, address)

    def test_calculate_pointer(self):
        # for offset < 0x4000
        self.assertEqual(calculate_pointer(0x0000), 0x0000)
        self.assertEqual(calculate_pointer(0x3FFF), 0x3FFF)
        # for 0x4000 <= offset <= 0x7FFF
        self.assertEqual(calculate_pointer(0x430F, bank=5), 0x1430F)
        # for offset >= 0x8000
        self.assertEqual(calculate_pointer(0x8FFF, bank=6), calculate_pointer(0x8FFF, bank=7))

    def test_calculate_pointer_from_bytes_at(self):
        addr1 = calculate_pointer_from_bytes_at(0x100, bank=False)
        self.assertEqual(addr1, 0xc300)
        addr2 = calculate_pointer_from_bytes_at(0x100, bank=True)
        self.assertEqual(addr2, 0x2ec3)

    def test_rom_text_at(self):
        self.assertEqual(rom_text_at(0x112116, 8), "HTTP/1.0")

    def test_translate_command_byte(self):
        self.assertEqual(translate_command_byte(crystal=0x0), 0x0)
        self.assertEqual(translate_command_byte(crystal=0x10), 0x10)
        self.assertEqual(translate_command_byte(crystal=0x40), 0x40)
        self.assertEqual(translate_command_byte(gold=0x0), 0x0)
        self.assertEqual(translate_command_byte(gold=0x10), 0x10)
        self.assertEqual(translate_command_byte(gold=0x40), 0x40)
        self.assertEqual(translate_command_byte(gold=0x0), translate_command_byte(crystal=0x0))
        self.assertEqual(translate_command_byte(gold=0x52), 0x53)
        self.assertEqual(translate_command_byte(gold=0x53), 0x54)
        self.assertEqual(translate_command_byte(crystal=0x53), 0x52)
        self.assertIsNone(translate_command_byte(crystal=0x52))
        self.assertRaises(Exception, translate_command_byte, None, gold=0xA4)

    def test_pksv_integrity(self):
        "does pksv_gs look okay?"
        self.assertEqual(pksv_gs[0x00], "2call")
        self.assertEqual(pksv_gs[0x2D], "givepoke")
        self.assertEqual(pksv_gs[0x85], "waitbutton")
        self.assertEqual(pksv_crystal[0x00], "2call")
        self.assertEqual(pksv_crystal[0x86], "waitbutton")
        self.assertEqual(pksv_crystal[0xA2], "credits")

    def test_chars_integrity(self):
        self.assertEqual(chars[0x80], "A")
        self.assertEqual(chars[0xA0], "a")
        self.assertEqual(chars[0xF0], "¥")
        self.assertEqual(jap_chars[0x44], "ぱ")

    def test_map_names_integrity(self):
        def map_name(map_group, map_id):
            return map_names[map_group][map_id]["name"]
        self.assertEqual(map_name(2, 7), "Mahogany Town")
        self.assertEqual(map_name(3, 0x34), "Ilex Forest")
        self.assertEqual(map_name(7, 0x11), "Cerulean City")

    def test_load_map_group_offsets(self):
        addresses = load_map_group_offsets()
        self.assertEqual(len(addresses), 26, msg="there should be 26 map groups")
        addresses = load_map_group_offsets()
        self.assertEqual(len(addresses), 26, msg="there should still be 26 map groups")
        self.assertIn(0x94034, addresses)
        for address in addresses:
            self.assertGreaterEqual(address, 0x4000)
            self.assertFalse(0x4000 <= address <= 0x7FFF)
            self.assertFalse(address <= 0x4000)

    def test_index(self):
        self.assertEqual(index([1, 2, 3, 4], lambda f: True), 0)
        self.assertEqual(index([1, 2, 3, 4], lambda f: f == 3), 2)

    def test_get_pokemon_constant_by_id(self):
        x = get_pokemon_constant_by_id
        self.assertEqual(x(1), "BULBASAUR")
        self.assertEqual(x(151), "MEW")
        self.assertEqual(x(250), "HO_OH")

    def test_find_item_label_by_id(self):
        x = find_item_label_by_id
        self.assertEqual(x(249), "HM_07")
        self.assertEqual(x(173), "BERRY")
        self.assertIsNone(x(45))

    def test_generate_item_constants(self):
        x = generate_item_constants
        r = x()
        self.assertTrue("HM_07" in r)
        self.assertTrue("EQU" in r)

    def test_get_label_for(self):
        # get_label_for is defined in crystal, so any module-level all_labels
        # it reads is crystal's; patch that one and always restore it.
        saved_labels = crystal.all_labels
        # this is based on the format defined in get_labels_between
        crystal.all_labels = [{"label": "poop", "address": 0x5,
                               "offset": 0x5, "bank": 0,
                               "line_number": 2
                               }]
        try:
            self.assertEqual(get_label_for(5), "poop")
        finally:
            crystal.all_labels = saved_labels

    def test_generate_map_constant_labels(self):
        ids = generate_map_constant_labels()
        self.assertEqual(ids[0]["label"], "OLIVINE_POKECENTER_1F")
        self.assertEqual(ids[1]["label"], "OLIVINE_GYM")

    def test_get_id_for_map_constant_label(self):
        # the lookup runs in crystal, so install the table on crystal
        crystal.map_internal_ids = generate_map_constant_labels()
        self.assertEqual(get_id_for_map_constant_label("OLIVINE_GYM"), 1)
        self.assertEqual(get_id_for_map_constant_label("OLIVINE_POKECENTER_1F"), 0)

    def test_get_map_constant_label_by_id(self):
        # the lookup runs in crystal, so install the table on crystal
        crystal.map_internal_ids = generate_map_constant_labels()
        self.assertEqual(get_map_constant_label_by_id(0), "OLIVINE_POKECENTER_1F")
        self.assertEqual(get_map_constant_label_by_id(1), "OLIVINE_GYM")

    def test_is_valid_address(self):
        self.assertTrue(is_valid_address(0))
        self.assertTrue(is_valid_address(1))
        self.assertTrue(is_valid_address(10))
        self.assertTrue(is_valid_address(100))
        self.assertTrue(is_valid_address(1000))
        self.assertTrue(is_valid_address(10000))
        self.assertFalse(is_valid_address(2097153))
        self.assertFalse(is_valid_address(2098000))
        addresses = [random.randrange(0, 2097153) for _ in range(0, 9+1)]
        for address in addresses:
            self.assertTrue(is_valid_address(address))
|
|
|
|
class TestIntervalMap(unittest.TestCase):
    """Checks IntervalMap slice assignment, point lookup and items()."""

    def test_intervals(self):
        i = IntervalMap()
        first = "hello world"
        second = "testing 123"
        i[0:5] = first
        i[5:10] = second
        self.assertEqual(i[0], first)
        self.assertEqual(i[1], first)
        self.assertNotEqual(i[5], first)
        self.assertEqual(i[6], second)
        # overwrite part of the first interval with the second value
        i[3:10] = second
        self.assertEqual(i[3], second)
        self.assertNotEqual(i[4], first)

    def test_items(self):
        i = IntervalMap()
        first = "hello world"
        second = "testing 123"
        i[0:5] = first
        i[5:10] = second
        results = list(i.items())
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0], ((0, 5), "hello world"))
        self.assertEqual(results[1], ((5, 10), "testing 123"))
|
|
|
|
class TestRomStr(unittest.TestCase):
    """RomStr is a class that should act exactly like str()
    except that it never shows the contents of its string
    unless explicitly forced"""
    sample_text = "hello world!"
    sample = None

    def setUp(self):
        # build the shared sample lazily, once, and cache it on the class
        if self.sample is None:
            self.__class__.sample = RomStr(self.sample_text)

    def test_equals(self):
        "check if RomStr() == str()"
        self.assertEqual(self.sample_text, self.sample)

    def test_not_equal(self):
        "check if RomStr('a') != RomStr('b')"
        self.assertNotEqual(RomStr('a'), RomStr('b'))

    def test_appending(self):
        "check if RomStr()+'a'==str()+'a'"
        self.assertEqual(self.sample_text+'a', self.sample+'a')

    def test_conversion(self):
        "check if RomStr() -> str() works"
        self.assertEqual(str(self.sample), self.sample_text)

    def test_inheritance(self):
        self.assertTrue(issubclass(RomStr, str))

    def test_length(self):
        self.assertEqual(len(self.sample_text), len(self.sample))
        self.assertEqual(len(self.sample_text), self.sample.length())
        self.assertEqual(len(self.sample), self.sample.length())

    def test_rom_interval(self):
        load_rom()
        # load_rom() (re)binds crystal.rom; the ``rom`` name imported into
        # this module is a separate binding that it never updates.
        rom = crystal.rom
        address = 0x100
        interval = 10
        correct_strings = ['0x0', '0xc3', '0x6e', '0x1', '0xce',
                           '0xed', '0x66', '0x66', '0xcc', '0xd']
        byte_strings = rom.interval(address, interval, strings=True)
        self.assertEqual(byte_strings, correct_strings)
        correct_ints = [0, 195, 110, 1, 206, 237, 102, 102, 204, 13]
        ints = rom.interval(address, interval, strings=False)
        self.assertEqual(ints, correct_ints)

    def test_rom_until(self):
        load_rom()
        # see test_rom_interval: read the ROM crystal actually loaded
        rom = crystal.rom
        address = 0x1337
        byte = 0x13
        # ``result`` rather than ``bytes`` to avoid shadowing the builtin
        result = rom.until(address, byte, strings=True)
        self.assertEqual(len(result), 3)
        self.assertEqual(result[0], '0xd5')
        result = rom.until(address, byte, strings=False)
        self.assertEqual(len(result), 3)
        self.assertEqual(result[0], 0xd5)
|
|
|
|
class TestAsmList(unittest.TestCase):
    """AsmList is a class that should act exactly like list()
    except that it never shows the contents of its list
    unless explicitly forced"""

    # Module-level globals of ``crystal`` that the tests below overwrite.
    # Functions imported from crystal resolve globals in crystal's namespace,
    # so fixtures must be installed as crystal.<name>; rebinding a same-named
    # global in this module has no effect on them. setUp/tearDown snapshot
    # and restore these so one test's fixture cannot leak into the next.
    _patched_crystal_globals = ("asm", "incbin_lines", "processed_incbins")

    def setUp(self):
        self._saved_crystal_globals = {}
        for name in self._patched_crystal_globals:
            if hasattr(crystal, name):
                self._saved_crystal_globals[name] = getattr(crystal, name)

    def tearDown(self):
        for name in self._patched_crystal_globals:
            if name in self._saved_crystal_globals:
                setattr(crystal, name, self._saved_crystal_globals[name])
            elif hasattr(crystal, name):
                # the test created a global that crystal did not have before
                delattr(crystal, name)

    def test_equals(self):
        base = [1, 2, 3]
        asm = AsmList(base)
        self.assertEqual(base, asm)
        self.assertEqual(asm, base)
        self.assertEqual(base, list(asm))

    def test_inheritance(self):
        self.assertTrue(issubclass(AsmList, list))

    def test_length(self):
        base = range(0, 10)
        asm = AsmList(base)
        self.assertEqual(len(base), len(asm))
        self.assertEqual(len(base), asm.length())
        self.assertEqual(len(base), len(list(asm)))
        self.assertEqual(len(asm), asm.length())

    def test_remove_quoted_text(self):
        x = remove_quoted_text
        self.assertEqual(x("hello world"), "hello world")
        self.assertEqual(x("hello \"world\""), "hello ")
        # ``text`` rather than ``input`` to avoid shadowing the builtin
        text = 'hello world "testing 123"'
        self.assertNotEqual(x(text), text)
        text = "hello world 'testing 123'"
        self.assertNotEqual(x(text), text)
        self.assertNotIn("testing", x(text))

    def test_line_has_comment_address(self):
        x = line_has_comment_address
        self.assertFalse(x(""))
        self.assertFalse(x(";"))
        self.assertFalse(x(";;;"))
        self.assertFalse(x(":;"))
        self.assertFalse(x(":;:"))
        self.assertFalse(x(";:"))
        self.assertFalse(x(" "))
        self.assertFalse(x(" " * 5))
        self.assertFalse(x(" " * 10))
        self.assertFalse(x("hello world"))
        self.assertFalse(x("hello_world"))
        self.assertFalse(x("hello_world:"))
        self.assertFalse(x("hello_world:;"))
        self.assertFalse(x("hello_world: ;"))
        self.assertFalse(x("hello_world: ; "))
        self.assertFalse(x("hello_world: ;" + " " * 5))
        self.assertFalse(x("hello_world: ;" + " " * 10))
        self.assertTrue(x(";1"))
        self.assertTrue(x(";F"))
        self.assertTrue(x(";$00FF"))
        self.assertTrue(x(";0x00FF"))
        self.assertTrue(x("; 0x00FF"))
        self.assertTrue(x(";$3:$300"))
        self.assertTrue(x(";0x3:$300"))
        self.assertTrue(x(";$3:0x300"))
        self.assertTrue(x(";3:300"))
        self.assertTrue(x(";3:FFAA"))
        self.assertFalse(x('hello world "how are you today;0x1"'))
        self.assertTrue(x('hello world "how are you today:0x1";1'))
        returnable = {}
        self.assertTrue(x("hello_world: ; 0x4050", returnable=returnable, bank=5))
        self.assertEqual(returnable["address"], 0x14050)

    def test_line_has_label(self):
        x = line_has_label
        self.assertTrue(x("hi:"))
        self.assertTrue(x("Hello: "))
        self.assertTrue(x("MyLabel: ; test xyz"))
        self.assertFalse(x(":"))
        self.assertFalse(x(";HelloWorld:"))
        self.assertFalse(x("::::"))
        self.assertFalse(x(":;:;:;:::"))

    def test_get_label_from_line(self):
        x = get_label_from_line
        self.assertEqual(x("HelloWorld: "), "HelloWorld")
        self.assertEqual(x("HiWorld:"), "HiWorld")
        self.assertIsNone(x("HiWorld"))

    def test_find_labels_without_addresses(self):
        crystal.asm = ["hello_world: ; 0x1", "hello_world2: ;"]
        labels = find_labels_without_addresses()
        self.assertEqual(labels[0]["label"], "hello_world2")
        crystal.asm = ["hello world: ;1", "hello_world: ;2"]
        labels = find_labels_without_addresses()
        self.assertEqual(len(labels), 0)

    def test_get_labels_between(self):
        x = get_labels_between  # (start_line_id, end_line_id, bank)
        crystal.asm = ["HelloWorld: ;1",
                       "hi:",
                       "no label on this line",
                       ]
        labels = x(0, 2, 0x12)
        self.assertEqual(len(labels), 1)
        self.assertEqual(labels[0]["label"], "HelloWorld")

    # this test takes a lot of time :(
    # (the "xtest_" prefix keeps the unittest loader from collecting it)
    def xtest_scan_for_predefined_labels(self):
        # label keys: line_number, bank, label, offset, address
        load_asm()
        found_labels = scan_for_predefined_labels()
        label_names = [x["label"] for x in found_labels]
        self.assertIn("GetFarByte", label_names)
        self.assertIn("AddNTimes", label_names)
        self.assertIn("CheckShininess", label_names)

    def test_write_all_labels(self):
        """dumping json into a file"""
        filename = "test_labels.json"
        # remove the current file
        if os.path.exists(filename):
            os.remove(filename)
        # make up some labels
        labels = [
            # fake label 1
            {"line_number": 5, "bank": 0, "label": "SomeLabel", "address": 0x10},
            # fake label 2
            {"line_number": 15, "bank": 2, "label": "SomeOtherLabel", "address": 0x9F0A},
        ]
        try:
            # dump to file
            write_all_labels(labels, filename=filename)
            # open the file and parse the contents as JSON (the stdlib json
            # module has no ``read``; ``loads`` parses a string)
            with open(filename, "r") as file_handler:
                obj = json.loads(file_handler.read())
        finally:
            # do not leave the scratch file behind
            if os.path.exists(filename):
                os.remove(filename)
        # begin testing
        self.assertEqual(len(obj), len(labels))
        self.assertEqual(len(obj), 2)
        self.assertEqual(obj, labels)

    def test_isolate_incbins(self):
        asm_lines = ["123", "456", "789", "abc", "def", "ghi",
                     'INCBIN "baserom.gbc",$12DA,$12F8 - $12DA',
                     "jkl",
                     'INCBIN "baserom.gbc",$137A,$13D0 - $137A']
        crystal.asm = asm_lines
        lines = isolate_incbins()
        self.assertIn(asm_lines[6], lines)
        self.assertIn(asm_lines[8], lines)
        for line in lines:
            self.assertIn("baserom", line)

    def test_process_incbins(self):
        incbin_lines = ['INCBIN "baserom.gbc",$12DA,$12F8 - $12DA',
                        'INCBIN "baserom.gbc",$137A,$13D0 - $137A']
        asm_lines = copy(incbin_lines)
        asm_lines.insert(1, "some other random line")
        crystal.incbin_lines = incbin_lines
        crystal.asm = asm_lines
        processed_incbins = process_incbins()
        self.assertEqual(len(processed_incbins), len(incbin_lines))
        self.assertEqual(processed_incbins[0]["line"], incbin_lines[0])
        self.assertEqual(processed_incbins[2]["line"], incbin_lines[1])

    def test_reset_incbins(self):
        # Temporarily replace the helpers reset_incbins() calls. They must be
        # replaced on crystal, because that is where reset_incbins looks them up.
        def _noop(*args, **kwargs):
            return None

        originals = (crystal.load_asm, crystal.isolate_incbins, crystal.process_incbins)
        crystal.load_asm = _noop
        crystal.isolate_incbins = _noop
        crystal.process_incbins = _noop
        try:
            # call reset
            reset_incbins()
        finally:
            # restore the original functions even if reset_incbins raised
            crystal.load_asm, crystal.isolate_incbins, crystal.process_incbins = originals
        # check the results
        self.assertTrue(crystal.asm == [] or crystal.asm is None)
        self.assertEqual(crystal.incbin_lines, [])
        self.assertEqual(crystal.processed_incbins, {})

    def test_find_incbin_to_replace_for(self):
        crystal.asm = ['first line', 'second line', 'third line',
                       'INCBIN "baserom.gbc",$90,$200 - $90',
                       'fifth line', 'last line']
        isolate_incbins()
        process_incbins()
        # find_incbin_to_replace_for is not in this module's import list, so it
        # is looked up on crystal. NOTE(review): assumes crystal defines it,
        # like the other incbin helpers — confirm.
        line_num = crystal.find_incbin_to_replace_for(0x100)
        # must be the 4th line (the INCBIN line)
        self.assertEqual(line_num, 3)

    def test_split_incbin_line_into_three(self):
        crystal.asm = ['first line', 'second line', 'third line',
                       'INCBIN "baserom.gbc",$90,$200 - $90',
                       'fifth line', 'last line']
        isolate_incbins()
        process_incbins()
        content = split_incbin_line_into_three(3, 0x100, 10)
        # must end up with three INCBINs in output
        self.assertEqual(content.count("INCBIN"), 3)

    def test_analyze_intervals(self):
        crystal.incbin_lines = []
        crystal.processed_incbins = {}
        asm_lines = ['first line', 'second line', 'third line',
                     'INCBIN "baserom.gbc",$90,$200 - $90',
                     'fifth line', 'last line',
                     'INCBIN "baserom.gbc",$33F,$4000 - $33F']
        crystal.asm = asm_lines
        isolate_incbins()
        process_incbins()
        # analyze_intervals is not in this module's import list, so it is
        # looked up on crystal. NOTE(review): assumes crystal defines it — confirm.
        largest = crystal.analyze_intervals()
        self.assertEqual(largest[0]["line_number"], 6)
        self.assertEqual(largest[0]["line"], asm_lines[6])
        self.assertEqual(largest[1]["line_number"], 3)
        self.assertEqual(largest[1]["line"], asm_lines[3])

    def test_generate_diff_insert(self):
        asm_lines = ['first line', 'second line', 'third line',
                     'INCBIN "baserom.gbc",$90,$200 - $90',
                     'fifth line', 'last line',
                     'INCBIN "baserom.gbc",$33F,$4000 - $33F']
        crystal.asm = asm_lines
        diff = generate_diff_insert(0, "the real first line", debug=False)
        self.assertIn("the real first line", diff)
        self.assertIn("INCBIN", diff)
        self.assertNotIn("No newline at end of file", diff)
        self.assertIn("+" + asm_lines[1], diff)
|
|
|
|
class TestMapParsing(unittest.TestCase):
    def test_parse_all_map_headers(self):
        """parse_all_map_headers should call the header parsers once per map
        (currently twice: once for each of the new/old parsing routines)."""
        # map_names is the same dict object crystal uses, so this in-place
        # mutation is visible to parse_all_map_headers.
        for map_group in map_names.keys():
            if "offset" not in map_names[map_group].keys():
                map_names[map_group]["offset"] = 0

        # boxed counter so the nested stub can update it (no ``nonlocal`` in Python 2)
        calls = [0]

        def counting_stub(address, map_group=None, map_id=None, debug=False):
            calls[0] += 1
            return {}

        # parse_all_map_headers resolves the parser names in crystal's
        # namespace, so the stubs must be installed there, and always removed.
        saved = (crystal.parse_map_header_at, crystal.old_parse_map_header_at)
        crystal.parse_map_header_at = counting_stub
        crystal.old_parse_map_header_at = counting_stub
        try:
            parse_all_map_headers(debug=False)
        finally:
            crystal.parse_map_header_at, crystal.old_parse_map_header_at = saved
        # parse_all_map_headers is currently doing it 2x
        # because of the new/old map header parsing routines
        self.assertEqual(calls[0], 388 * 2)
|
|
|
|
class TestTextScript(unittest.TestCase):
    """Placeholder for tests of 'in-script' text commands.

    Nothing is implemented yet. Planned coverage: to_asm, find_addresses
    and parse_text_at.
    """
|
|
|
|
class TestEncodedText(unittest.TestCase):
    """Tests for text chunks encoded through the chars table."""

    def test_process_00_subcommands(self):
        start = 0x197186
        groups = process_00_subcommands(start, start + 601, debug=False)
        self.assertEqual(len(groups), 42)
        self.assertEqual(len(groups[0]), 13)
        expected_second = [184, 174, 180, 211, 164, 127, 20, 231, 81]
        self.assertEqual(groups[1], expected_second)

    def test_parse_text_at2(self):
        oakspeech = parse_text_at2(0x197186, 601, debug=False)
        for word in ("encyclopedia", "researcher", "dependable"):
            self.assertIn(word, oakspeech)

    def test_parse_text_engine_script_at(self):
        script = parse_text_engine_script_at(0x197185, debug=False)
        commands = script.commands
        self.assertEqual(len(commands), 2)
        self.assertEqual(len(commands[0]["lines"]), 41)

    # Placeholders with no assertions (they always pass); nobody has cared
    # enough about these two parsers to write real checks yet.
    def test_parse_text_from_bytes(self):
        pass

    def test_parse_text_at(self):
        pass
|
|
|
|
class TestScript(unittest.TestCase):
    """Tests for parse_script_engine_script_at and script parsing in
    general. Script should be a class."""

    # TODO: add a direct test for parse_script_engine_script_at.

    def test_find_all_text_pointers_in_script_engine_script(self):
        # 0x197634 was also noted next to this address originally
        script_address = 0x197637
        pointers = list(find_all_text_pointers_in_script_engine_script(
            parse_script_engine_script_at(script_address, debug=False),
            bank=calculate_bank(script_address),
            debug=False,
        ))
        self.assertIn(0x197661, pointers)
|
|
|
|
class TestLabel(unittest.TestCase):
    """Checks that Label stores its constructor arguments unchanged."""

    def test_label_making(self):
        line_number = 2
        address = 0xf0c0
        label_name = "poop"
        label = Label(name=label_name, address=address, line_number=line_number)
        self.assertTrue(hasattr(label, "name"))
        self.assertTrue(hasattr(label, "address"))
        self.assertTrue(hasattr(label, "line_number"))
        self.assertNotIsInstance(label.address, str)
        self.assertNotIsInstance(label.line_number, str)
        self.assertIsInstance(label.name, str)
        self.assertEqual(label.line_number, line_number)
        self.assertEqual(label.name, label_name)
        self.assertEqual(label.address, address)
|
|
|
|
class TestByteParams(unittest.TestCase):
    """Tests for the single-byte parameter classes, all read from the ROM
    byte at address 10 (0x2d) unless stated otherwise."""

    @classmethod
    def setUpClass(cls):
        load_rom()
        cls.address = 10
        # shared, read-only fixture: tests must not mutate it
        cls.sbp = SingleByteParam(address=cls.address)

    @classmethod
    def tearDownClass(cls):
        del cls.sbp

    def test__init__(self):
        self.assertEqual(self.sbp.size, 1)
        self.assertEqual(self.sbp.address, self.address)

    def test_parse(self):
        self.sbp.parse()
        self.assertEqual(str(self.sbp.byte), str(45))

    def test_to_asm(self):
        self.assertEqual(self.sbp.to_asm(), "$2d")
        # Flip the decimal flag on a fresh instance. Setting it on the shared
        # class fixture would change self.sbp.to_asm() for any test run later.
        decimal_sbp = SingleByteParam(address=self.address)
        decimal_sbp.should_be_decimal = True
        self.assertEqual(decimal_sbp.to_asm(), str(45))

    # HexByte and DollarSignByte are the same now
    def test_HexByte_to_asm(self):
        h = HexByte(address=10)
        a = h.to_asm()
        self.assertEqual(a, "$2d")

    def test_DollarSignByte_to_asm(self):
        d = DollarSignByte(address=10)
        a = d.to_asm()
        self.assertEqual(a, "$2d")

    def test_ItemLabelByte_to_asm(self):
        i = ItemLabelByte(address=433)
        self.assertEqual(i.byte, 54)
        self.assertEqual(i.to_asm(), "COIN_CASE")
        self.assertEqual(ItemLabelByte(address=10).to_asm(), "$2d")

    def test_DecimalParam_to_asm(self):
        d = DecimalParam(address=10)
        x = d.to_asm()
        self.assertEqual(x, str(0x2d))
|
|
|
|
class TestMultiByteParam(unittest.TestCase):
    """Tests for the multi-byte parameter classes."""

    def setup_for(self, somecls, byte_size=2, address=443, **kwargs):
        """Build ``somecls`` at ``address`` into self.cls and check its basic
        attributes against the raw ROM bytes."""
        self.cls = somecls(address=address, size=byte_size, **kwargs)
        self.assertEqual(self.cls.address, address)
        self.assertEqual(self.cls.bytes, rom_interval(address, byte_size, strings=False))
        self.assertEqual(self.cls.size, byte_size)

    def test_two_byte_param(self):
        self.setup_for(MultiByteParam, byte_size=2)
        self.assertEqual(self.cls.to_asm(), "$f0c0")

    def test_three_byte_param(self):
        self.setup_for(MultiByteParam, byte_size=3)

    def test_PointerLabelParam_no_bank(self):
        self.setup_for(PointerLabelParam, bank=None)
        # assuming no label at this location..
        self.assertEqual(self.cls.to_asm(), "$f0c0")
        # PointerLabelParam is defined in crystal, so any module-level
        # all_labels it consults is crystal's; patch that one and restore it.
        # hm.. maybe all_labels should be using a class?
        saved_labels = crystal.all_labels
        crystal.all_labels = [{"label": "poop", "address": 0xf0c0,
                               "offset": 0xf0c0, "bank": 0,
                               "line_number": 2
                               }]
        try:
            self.assertEqual(self.cls.to_asm(), "poop")
        finally:
            crystal.all_labels = saved_labels
|
|
|
|
class TestPostParsing:  # (unittest.TestCase):
    """Tests that must be run after parsing all maps.

    Currently disabled: without the unittest.TestCase base class the loader
    never collects this class, and self.assertEqual would not exist.
    setUpClass calls run_main(), which this module does not import.
    NOTE(review): bring run_main into scope before re-enabling.
    """

    @classmethod
    def setUpClass(cls):
        run_main()

    def test_signpost_counts(self):
        # (map_group, map_id, expected signpost count)
        expectations = ((1, 1, 0), (1, 2, 2), (10, 5, 7))
        for map_group, map_id, expected in expectations:
            self.assertEqual(len(map_names[map_group][map_id]["signposts"]), expected)

    def test_warp_counts(self):
        expectations = ((10, 5, 9), (18, 5, 3), (15, 1, 2))
        for map_group, map_id, expected in expectations:
            self.assertEqual(map_names[map_group][map_id]["warp_count"], expected)

    def test_map_sizes(self):
        # (map_group, map_id, height, width)
        expectations = ((15, 1, 18, 10), (7, 1, 4, 4))
        for map_group, map_id, height, width in expectations:
            header = map_names[map_group][map_id]
            self.assertEqual(header["height"], height)
            self.assertEqual(header["width"], width)

    def test_map_connection_counts(self):
        expectations = (
            (7, 1, 0),
            (10, 1, 12),
            (10, 2, 12),
            (11, 1, 9),  # or 13?
        )
        for map_group, map_id, expected in expectations:
            self.assertEqual(map_names[map_group][map_id]["connections"], expected)

    def test_second_map_header_address(self):
        expectations = ((11, 1, 0x9509c), (1, 5, 0x95bd0))
        for map_group, map_id, expected in expectations:
            self.assertEqual(map_names[map_group][map_id]["second_map_header_address"], expected)

    def test_event_address(self):
        expectations = ((17, 5, 0x194d67), (23, 3, 0x1a9ec9))
        for map_group, map_id, expected in expectations:
            self.assertEqual(map_names[map_group][map_id]["event_address"], expected)

    def test_people_event_counts(self):
        expectations = ((23, 3, 4), (10, 3, 9))
        for map_group, map_id, expected in expectations:
            self.assertEqual(len(map_names[map_group][map_id]["people_events"]), expected)
|
|
|
|
class TestMetaTesting(unittest.TestCase):
    """test whether or not i am finding at least
    some of the tests in this file"""
    tests = None

    def setUp(self):
        # discover the test classes once and cache them on the class
        if self.tests is None:
            self.__class__.tests = assemble_test_cases()

    def test_assemble_test_cases_count(self):
        "does assemble_test_cases find some tests?"
        self.assertGreater(len(self.tests), 0)

    def test_assemble_test_cases_inclusion(self):
        "is this class found by assemble_test_cases?"
        # i guess it would have to be for this to be running?
        self.assertIn(self.__class__, self.tests)

    def test_assemble_test_cases_others(self):
        "test other inclusions for assemble_test_cases"
        self.assertIn(TestRomStr, self.tests)
        self.assertIn(TestCram, self.tests)

    def test_check_has_test(self):
        self.assertTrue(check_has_test("beaver", ["test_beaver"]))
        self.assertTrue(check_has_test("beaver", ["test_beaver_2"]))
        self.assertFalse(check_has_test("beaver_1", ["test_beaver"]))

    def test_find_untested_methods(self):
        untested = find_untested_methods()
        # the return type must be an iterable
        self.assertTrue(hasattr(untested, "__iter__"))
        # .. basically, a list
        self.assertIsInstance(untested, list)

    def test_find_untested_methods_method(self):
        """create a function and see if it is found"""
        # setup a function in the global namespace
        global some_random_test_method

        def some_random_test_method():
            pass

        try:
            # first make sure it is in the global scope
            members = inspect.getmembers(sys.modules[__name__], inspect.isfunction)
            func_names = [functuple[0] for functuple in members]
            self.assertIn("some_random_test_method", func_names)
            # test whether or not it is found by find_untested_methods
            untested = find_untested_methods()
            self.assertIn("some_random_test_method", untested)
        finally:
            # remove the test method from the global namespace, even on failure
            del some_random_test_method

    def test_load_tests(self):
        loader = unittest.TestLoader()
        suite = load_tests(loader, None, None)
        # the assembled suite must not be empty
        self.assertGreater(suite.countTestCases(), 0)

        def is_test_case_class(member):
            return inspect.isclass(member) and issubclass(member, unittest.TestCase)

        tests = inspect.getmembers(sys.modules[__name__], is_test_case_class)
        classes = [x[1] for x in tests]
        for test in suite._tests:
            self.assertIn(test.__class__, classes)

    def test_report_untested(self):
        untested = find_untested_methods()
        output = report_untested()
        if untested:
            self.assertIn("NOT TESTED", output)
            for name in untested:
                self.assertIn(name, output)
        else:
            # report_untested always emits its "NOT TESTED: [...]" header,
            # so an empty result shows up as a zero total instead.
            self.assertIn("total untested: 0", output)
|
|
|
|
def assemble_test_cases():
    """Return every unittest.TestCase subclass bound in this module's
    namespace, in inspect.getmembers (alphabetical) order.

    Saves having to register each new test class in a hand-maintained list.
    """
    this_module = sys.modules[__name__]
    return [candidate
            for _, candidate in inspect.getmembers(this_module, inspect.isclass)
            if issubclass(candidate, unittest.TestCase)]
|
|
|
|
def load_tests(loader, tests, pattern):
    """Build a single TestSuite holding the test methods of every class
    returned by assemble_test_cases().

    Follows unittest's load_tests protocol; the incoming ``tests`` and
    ``pattern`` arguments are ignored. Each per-class suite is merged with
    addTests, so the result holds individual test cases rather than nested
    suites.
    """
    combined = unittest.TestSuite()
    for per_class_suite in map(loader.loadTestsFromTestCase, assemble_test_cases()):
        combined.addTests(per_class_suite)
    return combined
|
|
|
|
def check_has_test(func_name, tested_names):
    """Return True when some entry of ``tested_names`` contains
    "test_<func_name>" as a substring (an exact match counts), else False."""
    wanted = "test_" + func_name
    return wanted in tested_names or any(wanted in candidate for candidate in tested_names)
|
|
|
|
def find_untested_methods():
    """Find the functions in this module's namespace that have no test.

    A function counts as tested when check_has_test matches it against the
    name of some method starting with "test_" on a unittest.TestCase subclass
    in this module. Names in ``avoid_funcs`` and names starting with "_" are
    skipped.

    Returns:
        list of function names, in inspect.getmembers (alphabetical) order.
    """
    avoid_funcs = ("main", "run_tests", "run_main", "copy", "deepcopy")
    this_module = sys.modules[__name__]

    def _is_routine(member):
        # Python 2 exposes functions defined on a class as (unbound) methods,
        # Python 3 as plain functions. inspect.ismethod alone finds nothing on
        # Python 3, so accept both.
        return inspect.ismethod(member) or inspect.isfunction(member)

    # assemble a list of all test method names (test_x, test_y, ..)
    tested_names = []
    for _, klass in inspect.getmembers(this_module, inspect.isclass):
        # only look at classes that hold tests
        if not issubclass(klass, unittest.TestCase):
            continue
        for method_name, _ in inspect.getmembers(klass, _is_routine):
            if method_name.startswith("test_"):
                tested_names.append(method_name)

    untested = []
    for name, _ in inspect.getmembers(this_module, inspect.isfunction):
        # we don't care about some of these, nor about private helpers
        if name in avoid_funcs or name.startswith("_"):
            continue
        if not check_has_test(name, tested_names):
            untested.append(name)
    return untested
|
|
|
|
def report_untested():
    """
    Build a text report of the functions in this module's global namespace
    that find_untested_methods considers untested: a bracketed,
    comma-separated "NOT TESTED" list followed by a total count.

    This logic originally lived in the crystal module, where it would list
    out the majority of the functions. Maybe it should be moved back.
    """
    untested = find_untested_methods()
    return "NOT TESTED: [%s]\ntotal untested: %d" % (", ".join(untested), len(untested))
|
|
|
|
def run_tests():  # rather than unittest.main()
    """Run every test case found by load_tests with a verbose text runner,
    then print the untested-functions report."""
    loader = unittest.TestLoader()
    suite = load_tests(loader, None, None)
    unittest.TextTestRunner(verbosity=2).run(suite)
    # print(...) with a single argument produces the same output on
    # Python 2 and 3; the bare print statement is a syntax error on Python 3.
    print(report_untested())
|
|
|
|
# run the unit tests when this file is executed directly
# (merely importing this module does not run them)
if __name__ == "__main__":
    run_tests()
|
|
|