From 3bd84c1dac0dc7085287bd6a7c822dfa2663cf71 Mon Sep 17 00:00:00 2001 From: Bryan Bishop Date: Sat, 24 Mar 2012 18:01:37 -0500 Subject: [PATCH] lots of asm-related code and tests --- extras/crystal.py | 520 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 520 insertions(+) diff --git a/extras/crystal.py b/extras/crystal.py index 786d7899a..e807c6fee 100644 --- a/extras/crystal.py +++ b/extras/crystal.py @@ -4405,6 +4405,463 @@ for map_group_id in map_names.keys(): #set the value in the original dictionary map_names[map_group_id][map_id]["label"] = cleaned_name +#### asm utilities #### +#these are pulled in from pokered/extras/analyze_incbins.py + +#store each line of source code here +asm = None + +#store each incbin line separately +incbin_lines = [] + +#storage for processed incbin lines +processed_incbins = {} + +def isolate_incbins(): + "find each incbin line" + global incbin_lines + incbin_lines = [] + for line in asm: + if line == "": continue + if line.count(" ") == len(line): continue + + #clean up whitespace at beginning of line + while line[0] == " ": + line = line[1:] + + if line[0:6] == "INCBIN" and "baserom.gbc" in line: + incbin_lines.append(line) + return incbin_lines + +def process_incbins(): + "parse incbin lines into memory" + global incbins + incbins = {} #reset + for incbin in incbin_lines: + processed_incbin = {} + + line_number = asm.index(incbin) + + partial_start = incbin[21:] + start = partial_start.split(",")[0].replace("$", "0x") + start = eval(start) + start_hex = hex(start).replace("0x", "$") + + partial_interval = incbin[21:].split(",")[1] + partial_interval = partial_interval.replace(";", "#") + partial_interval = partial_interval.replace("$", "0x").replace("0xx", "0x") + interval = eval(partial_interval) + interval_hex = hex(interval).replace("0x", "$").replace("x", "") + + end = start + interval + end_hex = hex(end).replace("0x", "$") + + processed_incbin = { + "line_number": line_number, + "line": incbin, + "start": start, + "interval": interval, + "end": end, + } + + #don't add this incbin if the interval is 0 + if interval != 0: + processed_incbins[line_number] = processed_incbin + +def reset_incbins(): + "reset asm before inserting another diff" + asm = None + incbin_lines = [] + processed_incbins = {} + load_asm() + isolate_incbins() + process_incbins() + +def find_incbin_to_replace_for(address, debug=False, rom_file="../baserom.gbc"): + """returns a line number for which incbin to edit + if you were to insert bytes into main.asm""" + if type(address) == str: address = int(address, 16) + if not (0 <= address <= os.lstat(rom_file).st_size): + raise IndexError, "address is out of bounds" + for incbin_key in processed_incbins.keys(): + incbin = processed_incbins[incbin_key] + start = incbin["start"] + end = incbin["end"] + if debug: + print "start is: " + str(start) + print "end is: " + str(end) + print "address is: " + str(type(address)) + print "checking.... " + hex(start) + " <= " + hex(address) + " <= " + hex(end) + if start <= address <= end: + return incbin_key + return None + +def split_incbin_line_into_three(line, start_address, byte_count): + """ + splits an incbin line into three pieces. + you can replace the middle one with the new content of length bytecount + + start_address: where you want to start inserting bytes + byte_count: how many bytes you will be inserting + """ + if type(start_address) == str: start_address = int(start_address, 16) + if not (0 <= start_address <= os.lstat(rom_file).st_size): + raise IndexError, "start_address is out of bounds" + if len(processed_incbins) == 0: + raise Exception, "processed_incbins must be populated" + + original_incbin = processed_incbins[line] + start = original_incbin["start"] + end = original_incbin["end"] + + #start, end1, end2 (to be printed as start, end1 - end2) + if start_address - start > 0: + first = (start, start_address, start) + else: + first = (None) #skip this one because we're not including anything + + #this is the one you will replace with whatever content + second = (start_address, byte_count) + + third = (start_address + byte_count, end - (start_address + byte_count)) + + output = "" + + if first: + output += "INCBIN \"baserom.gbc\",$" + hex(first[0])[2:] + ",$" + hex(first[1])[2:] + " - $" + hex(first[2])[2:] + "\n" + output += "INCBIN \"baserom.gbc\",$" + hex(second[0])[2:] + "," + str(byte_count) + "\n" + output += "INCBIN \"baserom.gbc\",$" + hex(third[0])[2:] + ",$" + hex(third[1])[2:] #no newline + return output + +def generate_diff_insert(line_number, newline): + original = "\n".join(line for line in asm) + newfile = deepcopy(asm) + newfile[line_number] = newline #possibly inserting multiple lines + newfile = "\n".join(line for line in newfile) + + original_filename = "ejroqjfoad.temp" + newfile_filename = "fjiqefo.temp" + + original_fh = open(original_filename, "w") + original_fh.write(original) + original_fh.close() + + newfile_fh = open(newfile_filename, "w") + newfile_fh.write(newfile) + newfile_fh.close() + + try: + diffcontent = subprocess.check_output("diff -u ../main.asm " + newfile_filename, shell=True) + except AttributeError, exc: + raise exc + except Exception, exc: + diffcontent = exc.output + + os.system("rm " + original_filename) + os.system("rm " + newfile_filename) + + return diffcontent + +def apply_diff(diff, try_fixing=True, do_compile=True): + print "... Applying diff." + + #write the diff to a file + fh = open("temp.patch", "w") + fh.write(diff) + fh.close() + + #apply the patch + os.system("cp ../main.asm ../main1.asm") + os.system("patch ../main.asm temp.patch") + + #remove the patch + os.system("rm temp.patch") + + #confirm it's working + if do_compile: + try: + subprocess.check_call("cd ../; make clean; LC_CTYPE=C make", shell=True) + return True + except Exception, exc: + if try_fixing: + os.system("mv ../main1.asm ../main.asm") + return False + +def index(seq, f): + """return the index of the first item in seq + where f(item) == True.""" + return next((i for i in xrange(len(seq)) if f(seq[i])), None) + +def is_probably_pointer(input): + try: + blah = int(input, 16) + return True + except: + return False + +def analyze_intervals(): + """find the largest baserom.gbc intervals""" + global asm, processed_incbins + if asm == None: + load_asm() + if processed_incbins == {}: + isolate_incbins() + process_incbins() + results = [] + ordered_keys = sorted(processed_incbins, key=lambda entry: processed_incbins[entry]["interval"]) + ordered_keys.reverse() + for key in ordered_keys: + results.append(processed_incbins[key]) + return results + +def write_all_labels(all_labels): + fh = open("labels.json", "w") + fh.write(json.dumps(all_labels)) + fh.close() + +def remove_quoted_text(line): + """get rid of content inside quotes + and also removes the quotes from the input string""" + while line.count("\"") % 2 == 0 and line.count("\"") > 0: + first = line.find("\"") + second = line.find("\"", first+1) + line = line[0:first] + line[second+1:] + while line.count("\'") % 2 == 0 and line.count("'") > 0: + first = line.find("\'") + second = line.find("\'", first+1) + line = line[0:first] + line[second+1:] + return line + +def line_has_comment_address(line, returnable={}): + """checks that a given line has a comment + with a valid address""" + #first set the bank/offset to nada + returnable["bank"] = None + returnable["offset"] = None + returnable["address"] = None + #only valid characters are 0-9A-F + valid = [str(x) for x in range(0,10)] + [chr(x) for x in range(97, 102+1)] + #check if there is a comment in this line + if ";" not in line: + return False + #first throw away anything in quotes + if (line.count("\"") % 2 == 0 and line.count("\"")!=0) \ + or (line.count("\'") % 2 == 0 and line.count("\'")!=0): + line = remove_quoted_text(line) + #check if there is still a comment in this line after quotes removed + if ";" not in line: + return False + #but even if there's a semicolon there must be later text + if line[-1] == ";": + return False + #and just a space doesn't count + if line[-2:] == "; ": + return False + #and multiple whitespace doesn't count either + line = line.rstrip(" ") + if line[-1] == ";": + return False + #there must be more content after the semicolon + if len(line)-1 == line.find(";"): + return False + #split it up into the main comment part + comment = line[line.find(";")+1:] + #don't want no leading whitespace + comment = comment.lstrip(" ").rstrip(" ") + #split up multi-token comments into single tokens + token = comment + if " " in comment: + #use the first token in the comment + token = comment.split(" ")[0] + if token in ["0x", "$", "x", ":"]: + return False + bank, offset = None, None + #process a token with a A:B format + if ":" in token: #3:3F0A, $3:$3F0A, 0x3:0x3F0A, 3:3F0A + #split up the token + bank_piece = token.split(":")[0].lower() + offset_piece = token.split(":")[1].lower() + #filter out blanks/duds + if bank_piece in ["$", "0x", "x"] \ + or offset_piece in ["$", "0x", "x"]: + return False + #they can't have both "$" and "x" + if "$" in bank_piece and "x" in bank_piece: + return False + if "$" in offset_piece and "x" in offset_piece: + return False + #process the bank piece + if "$" in bank_piece: + bank_piece = bank_piece.replace("$", "0x") + #check characters for validity? + for c in bank_piece.replace("x", ""): + if c not in valid: + return False + bank = int(bank_piece, 16) + #process the offset piece + if "$" in offset_piece: + offset_piece = offset_piece.replace("$", "0x") + #check characters for validity? + for c in offset_piece.replace("x", ""): + if c not in valid: + return False + offset = int(offset_piece, 16) + #filter out blanks/duds + elif token in ["$", "0x", "x"]: + return False + #can't have both "$" and "x" in the number + elif "$" in token and "x" in token: + return False + elif "x" in token and not "0x" in token: #it should be 0x + return False + elif "$" in token and not "x" in token: + token = token.replace("$", "0x") + offset = int(token, 16) + bank = calculate_bank(offset) + elif "0x" in token and not "$" in token: + offset = int(token, 16) + bank = calculate_bank(offset) + else: #might just be "1" at this point + token = token.lower() + #check if there are bad characters + for c in token: + if c not in valid: + return False + offset = int(token, 16) + bank = calculate_bank(offset) + if offset == None and bank == None: + return False + returnable["bank"] = bank + returnable["offset"] = offset + returnable["address"] = calculate_pointer(offset, bank=bank) + return True +def line_has_label(line): + """returns True if the line has an asm label""" + if not isinstance(line, str): + raise Exception, "can't check this type of object" + line = line.rstrip(" ").lstrip(" ") + line = remove_quoted_text(line) + if ";" in line: + line = line.split(";")[0] + if 0 <= len(line) <= 1: + return False + if ":" not in line: + return False + if line[0] == ";": + return False + if line[0] == "\"": + return False + if "::" in line: + return False + return True +def get_label_from_line(line): + """returns the label from the line""" + #check if the line has a label + if not line_has_label(line): + return None + #split up the line + label = line.split(":")[0] + return label +def find_labels_without_addresses(): + """scans the asm source and finds labels that are unmarked""" + without_addresses = [] + for (line_number, line) in enumerate(asm): + if line_has_label(line): + label = get_label_from_line(line) + if not line_has_comment_address(line): + without_addresses.append({"line_number": line_number, "line": line, "label": label}) + return without_addresses + +label_errors = "" +def get_labels_between(start_line_id, end_line_id, bank_id): + labels = [] + #label = { + # "line_number": 15, + # "bank": 32, + # "label": "PalletTownText1", + # "offset": 0x5315, + # "address": 0x75315, + #} + sublines = asm[start_line_id : end_line_id + 1] + for (current_line_offset, line) in enumerate(sublines): + #skip lines without labels + if not line_has_label(line): continue + #reset some variables + line_id = start_line_id + current_line_offset + line_label = get_label_from_line(line) + address = None + offset = None + #setup a place to store return values from line_has_comment_address + returnable = {} + #get the address from the comment + has_comment = line_has_comment_address(line, returnable=returnable) + #skip this line if it has no address in the comment + if not has_comment: continue + #parse data from line_has_comment_address + address = returnable["address"] + bank = returnable["bank"] + offset = returnable["offset"] + #dump all this info into a single structure + label = { + "line_number": line_id, + "bank": bank, + "label": line_label, + "offset": offset, + "address": address, + } + #store this structure + labels.append(label) + return labels + +def scan_for_predefined_labels(): + """looks through the asm file for labels at specific addresses, + this relies on the label having its address after. ex: + + ViridianCity_h: ; 0x18357 to 0x18384 (45 bytes) (bank=6) (id=1) + PalletTownText1: ; 4F96 0x18f96 + ViridianCityText1: ; 0x19102 + + It would be more productive to use rgbasm to spit out all label + addresses, but faster to write this script. rgbasm would be able + to grab all label addresses better than this script.. + """ + bank_intervals = {} + all_labels = [] + + #figure out line numbers for each bank + for bank_id in range(0x7F+1): + abbreviation = ("%.x" % (bank_id)).upper() + abbreviation_next = ("%.x" % (bank_id+1)).upper() + if bank_id == 0: + abbreviation = "0" + abbreviation_next = "1" + + start_line_id = index(asm, lambda line: "\"bank" + abbreviation + "\"" in line) + + if bank_id != 0x2c: + end_line_id = index(asm, lambda line: "\"bank" + abbreviation_next + "\"" in line) + else: + end_line_id = len(asm) - 1 + + print "bank" + abbreviation + " starts at " + str(start_line_id) + " to " + str(end_line_id) + + bank_intervals[bank_id] = { + "start": start_line_id, + "end": end_line_id, + } + for bank_id in bank_intervals.keys(): + bank_data = bank_intervals[bank_id] + + start_line_id = bank_data["start"] + end_line_id = bank_data["end"] + + labels = get_labels_between(start_line_id, end_line_id, bank_id) + #bank_intervals[bank_id]["labels"] = labels + all_labels.extend(labels) + + write_all_labels(all_labels) + return all_labels + #### generic testing #### class TestCram(unittest.TestCase): @@ -4615,6 +5072,69 @@ class TestAsmList(unittest.TestCase): self.assertEquals(len(base), asm.length()) self.assertEquals(len(base), len(list(asm))) self.assertEquals(len(asm), asm.length()) + def test_remove_quoted_text(self): + x = remove_quoted_text + self.assertEqual(x("hello world"), "hello world") + self.assertEqual(x("hello \"world\""), "hello ") + input = 'hello world "testing 123"' + self.assertNotEqual(x(input), input) + input = "hello world 'testing 123'" + self.assertNotEqual(x(input), input) + self.failIf("testing" in x(input)) + def test_line_has_comment_address(self): + x = line_has_comment_address + self.assertFalse(x("")) + self.assertFalse(x(";")) + self.assertFalse(x(";;;")) + self.assertFalse(x(":;")) + self.assertFalse(x(":;:")) + self.assertFalse(x(";:")) + self.assertFalse(x(" ")) + self.assertFalse(x("".join(" " * 5))) + self.assertFalse(x("".join(" " * 10))) + self.assertFalse(x("hello world")) + self.assertFalse(x("hello_world")) + self.assertFalse(x("hello_world:")) + self.assertFalse(x("hello_world:;")) + self.assertFalse(x("hello_world: ;")) + self.assertFalse(x("hello_world: ; ")) + self.assertFalse(x("hello_world: ;" + "".join(" " * 5))) + self.assertFalse(x("hello_world: ;" + "".join(" " * 10))) + self.assertTrue(x(";1")) + self.assertTrue(x(";F")) + self.assertTrue(x(";$00FF")) + self.assertTrue(x(";0x00FF")) + self.assertTrue(x("; 0x00FF")) + self.assertTrue(x(";$3:$300")) + self.assertTrue(x(";0x3:$300")) + self.assertTrue(x(";$3:0x300")) + self.assertTrue(x(";3:300")) + self.assertTrue(x(";3:FFAA")) + self.assertFalse(x('hello world "how are you today;0x1"')) + self.assertTrue(x('hello world "how are you today:0x1";1')) + def test_line_has_label(self): + x = line_has_label + self.assertTrue(x("hi:")) + self.assertTrue(x("Hello: ")) + self.assertTrue(x("MyLabel: ; test xyz")) + self.assertFalse(x(":")) + self.assertFalse(x(";HelloWorld:")) + self.assertFalse(x("::::")) + self.assertFalse(x(":;:;:;:::")) + def test_get_label_from_line(self): + x = get_label_from_line + self.assertEqual(x("HelloWorld: "), "HelloWorld") + self.assertEqual(x("HiWorld:"), "HiWorld") + self.assertEqual(x("HiWorld"), None) + def test_find_labels_without_addresses(self): + global asm + asm = ["hello_world: ; 0x1", "hello_world2: ;"] + labels = find_labels_without_addresses() + self.failUnless(labels[0]["label"] == "hello_world2") + asm = ["hello world: ;1", "hello_world: ;2"] + labels = find_labels_without_addresses() + self.failUnless(len(labels) == 0) + asm = None class TestMapParsing(unittest.TestCase): #def test_parse_warp_bytes(self): # pass #or raise NotImplementedError, bryan_message