make unit tests work again

This commit is contained in:
Bryan Bishop 2012-04-18 23:20:28 -05:00
parent b7295080d5
commit b7cca3a548
1 changed files with 36 additions and 46 deletions

View File

@ -1095,20 +1095,20 @@ def find_all_text_pointers_in_script_engine_script(script, bank=None, debug=Fals
#TODO: recursively follow any jumps in the script
if script == None: return []
addresses = set()
for (k, command) in script.commands.items():
for (k, command) in enumerate(script.commands):
if debug:
print "command is: " + str(command)
if command["type"] == 0x4B:
addresses.add(command["pointer"])
elif command["type"] == 0x4C:
addresses.add(command["pointer"])
elif command["type"] == 0x51:
addresses.add(command["pointer"])
elif command["type"] == 0x53:
addresses.add(command["pointer"])
elif command["type"] == 0x64:
addresses.add(command["won_pointer"])
addresses.add(command["lost_pointer"])
if command.id == 0x4B:
addresses.add(command.params[0].parsed_address)
elif command.id == 0x4C:
addresses.add(command.params[0].parsed_address)
elif command.id == 0x51:
addresses.add(command.params[0].parsed_address)
elif command.id == 0x53:
addresses.add(command.params[0].parsed_address)
elif command.id == 0x64:
addresses.add(command.params[0].parsed_address)
addresses.add(command.params[1].parsed_address)
return addresses
def translate_command_byte(crystal=None, gold=None):
@ -1167,14 +1167,9 @@ class SingleByteParam():
else: return str(self.byte)
class HexByte(SingleByteParam):
def to_asm(self): return "$%.2x" % (self.byte)
class DollarSignByte(SingleByteParam):
#def to_asm(self): return "$%.2x"%self.byte
def to_asm(self): return hex(self.byte).replace("0x", "$")
HexByte=DollarSignByte
class ItemLabelByte(DollarSignByte):
def to_asm(self):
@ -1206,7 +1201,12 @@ class MultiByteParam():
raise Exception, "don't know how many bytes to read (size)"
self.parse()
def parse(self): self.bytes = rom_interval(self.address, self.size, strings=False)
def parse(self):
self.bytes = rom_interval(self.address, self.size, strings=False)
if hasattr(self, "bank"):
self.parsed_address = calculate_pointer_from_bytes_at(self.address, bank=self.bank)
else:
self.parsed_address = calculate_pointer_from_bytes_at(self.address, bank=None)
#you won't actually use this to_asm because it's too generic
#def to_asm(self): return ", ".join([(self.prefix+"%.2x")%x for x in self.bytes])
@ -1239,6 +1239,10 @@ class PointerLabelParam(MultiByteParam):
#continue instantiation.. self.bank will be set down the road
MultiByteParam.__init__(self, *args, **kwargs)
def parse(self):
self.parsed_address = calculate_pointer_from_bytes_at(self.address, bank=self.bank)
MultiByteParam.parse(self)
def to_asm(self):
bank = self.bank
#we pass bank= for whether or not to include a bank byte when reading
@ -1443,8 +1447,6 @@ class MovementPointerLabelParam(PointerLabelParam):
class MapDataPointerParam(PointerLabelParam):
pass
class Command:
"""
Note: when dumping to asm, anything in script_parse_table that directly
@ -4019,7 +4021,6 @@ for map_group_id in map_names.keys():
#generate map constants (like 1=PALLET_TOWN)
generate_map_constant_labels()
#### asm utilities ####
#these are pulled in from pokered/extras/analyze_incbins.py
@ -4284,7 +4285,8 @@ def get_label_for(address):
else:
return "AlreadyParsedNoDefaultUnknownLabel_" + hex(address)
return "NotYetParsed_"+hex(address)
#return "NotYetParsed_"+hex(address)
return "$%.2x"%(address)
def remove_quoted_text(line):
"""get rid of content inside quotes
@ -5106,38 +5108,25 @@ class TestAsmList(unittest.TestCase):
class TestMapParsing(unittest.TestCase):
#def test_parse_warp_bytes(self):
# pass #or raise NotImplementedError, bryan_message
#def test_parse_xy_trigger_bytes(self):
# pass #or raise NotImplementedError, bryan_message
#def test_parse_people_event_bytes(self):
# pass #or raise NotImplementedError, bryan_message
#def test_parse_map_header_at(self):
# pass #or raise NotImplementedError, bryan_message
#def test_parse_second_map_header_at(self):
# pass #or raise NotImplementedError, bryan_message
#def test_parse_map_event_header_at(self):
# pass #or raise NotImplementedError, bryan_message
#def test_parse_map_script_header_at(self):
# pass #or raise NotImplementedError, bryan_message
#def test_parse_map_header_by_id(self):
# pass #or raise NotImplementedError, bryan_message
def test_parse_all_map_headers(self):
global parse_map_header_at, counter
global parse_map_header_at, old_parse_map_header_at, counter
counter = 0
for k in map_names.keys():
if "offset" not in map_names[k].keys():
map_names[k]["offset"] = 0
temp = parse_map_header_at
temp2 = old_parse_map_header_at
def parse_map_header_at(address, map_group=None, map_id=None, debug=False):
global counter
counter += 1
return {}
old_parse_map_header_at = parse_map_header_at
parse_all_map_headers(debug=False)
self.assertEqual(counter, 388)
#parse_all_map_headers is currently doing it 2x
#because of the new/old map header parsing routines
self.assertEqual(counter, 388 * 2)
parse_map_header_at = temp
old_parse_map_header_at = temp2
class TestTextScript(unittest.TestCase):
"""for testing 'in-script' commands, etc."""
@ -5166,8 +5155,8 @@ class TestEncodedText(unittest.TestCase):
def test_parse_text_engine_script_at(self):
p = parse_text_engine_script_at(0x197185, debug=False)
self.assertEqual(len(p), 2)
self.assertEqual(len(p[0]["lines"]), 41)
self.assertEqual(len(p.commands), 2)
self.assertEqual(len(p.commands[0]["lines"]), 41)
#don't really care about these other two
def test_parse_text_from_bytes(self): pass
@ -5230,10 +5219,11 @@ class TestByteParams(unittest.TestCase):
self.sbp.should_be_decimal = True
self.assertEqual(self.sbp.to_asm(), str(45))
#HexByte and DollarSignByte are the same now
def test_HexByte_to_asm(self):
h = HexByte(address=10)
a = h.to_asm()
self.assertEqual(a, "0x2d")
self.assertEqual(a, "$2d")
def test_DollarSignByte_to_asm(self):
d = DollarSignByte(address=10)