import os
import angr
import angrop  # pylint: disable=unused-import
import pickle

import logging
# Logger for this test module.
l = logging.getLogger("angrop.tests.test_rop")

# Test binaries, resolved relative to the real (symlink-resolved) location of this file.
public_bin_location = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    '../../binaries/tests',
)
# Directory holding the pickled reference gadget caches loaded by the tests below.
test_data_location = os.path.join(
    public_bin_location, "..", "tests_data", "angrop_gadgets_cache"
)


"""
Suggestions on how to debug angr changes that break angrop.

If a gadget is missing entirely after your changes, pick the address that didn't work and run the following.
The DEBUG logging should say why the gadget was discarded.

rop = p.analyses.ROP()
angrop.gadget_analyzer.l.setLevel("DEBUG")
rop._gadget_analyzer.analyze_gadget(addr)

If a gadget is missing memory reads / memory writes / memory changes, the actions are probably missing.
Memory changes require a read action followed by a write action to the same address.
"""


def assert_mem_access_equal(m1, m2):
    """Assert that two memory-access records describe the same access.

    Dependency and controller collections are compared as sets, so their order does
    not matter. Constants and sizes must be exactly equal.

    Raises:
        AssertionError: on the first differing attribute.
    """
    unordered_fields = ("addr_dependencies", "addr_controllers", "data_dependencies", "data_controllers")
    for field in unordered_fields:
        assert set(getattr(m1, field)) == set(getattr(m2, field))

    exact_fields = ("addr_constant", "data_constant", "addr_size", "data_size")
    for field in exact_fields:
        assert getattr(m1, field) == getattr(m2, field)


def assert_gadgets_equal(known_gadget, test_gadget):
    """Assert that *test_gadget* matches the reference *known_gadget* field by field.

    Register effects and stack change are compared directly. ``makes_syscall`` is only
    compared when the reference gadget has that attribute. Memory reads, writes and
    changes are compared pairwise, in list order, with ``assert_mem_access_equal``.

    Raises:
        AssertionError: on the first mismatching field.
    """
    assert known_gadget.addr == test_gadget.addr
    assert known_gadget.changed_regs == test_gadget.changed_regs
    assert known_gadget.popped_regs == test_gadget.popped_regs
    assert known_gadget.reg_dependencies == test_gadget.reg_dependencies
    assert known_gadget.reg_controllers == test_gadget.reg_controllers
    assert known_gadget.stack_change == test_gadget.stack_change
    # The reference may lack this attribute (NOTE(review): presumably caches pickled
    # before it existed -- confirm), so only compare it when it is present.
    if hasattr(known_gadget, "makes_syscall"):
        assert known_gadget.makes_syscall == test_gadget.makes_syscall

    # Accesses are matched positionally, so both lists must agree in length and order.
    for kind in ("mem_reads", "mem_writes", "mem_changes"):
        known_accesses = getattr(known_gadget, kind)
        test_accesses = getattr(test_gadget, kind)
        assert len(known_accesses) == len(test_accesses), kind
        for m1, m2 in zip(known_accesses, test_accesses):
            assert_mem_access_equal(m1, m2)


def compare_gadgets(test_gadgets, known_gadgets):
    """Check that every reference gadget was rediscovered with identical properties.

    New gadgets in *test_gadgets* whose address is not in *known_gadgets* are allowed
    and ignored. Only the reference gadgets are checked for correctness.

    Raises:
        AssertionError: if a reference address was not found (the message lists every
            missing address), if the number of matched gadgets differs from the number
            of reference gadgets, or if a matched gadget differs from its reference.
    """
    known_gadgets = sorted(known_gadgets, key=lambda x: x.addr)
    expected_addrs = {g.addr for g in known_gadgets}

    # We allow new gadgets to be found but only check the ones in known_gadgets,
    # so filter out any newly found gadgets.
    test_gadgets = [g for g in test_gadgets if g.addr in expected_addrs]
    test_gadget_dict = {g.addr: g for g in test_gadgets}

    # Check that each expected gadget addr was found as a gadget. If one wasn't, the
    # best way to debug it is to run:
    # angrop.gadget_analyzer.l.setLevel("DEBUG"); rop._gadget_analyzer.analyze_gadget(addr)
    missing = sorted(expected_addrs.difference(test_gadget_dict))
    assert not missing, "expected gadgets not found at: " + ", ".join(hex(a) for a in missing)

    # All expected addresses are present. Equal counts additionally catch duplicate
    # addresses appearing on only one side.
    assert len(test_gadgets) == len(known_gadgets)

    for g in known_gadgets:
        assert_gadgets_equal(g, test_gadget_dict[g.addr])


def execute_chain(project, chain):
    """Symbolically execute *chain* from a blank state and return the final state.

    The chain payload is written at the stack pointer, followed by ``b"A"`` padding.
    Execution is stepped until it reaches the address made of ``0x41`` bytes (truncated
    to the architecture's word size). Every step must leave exactly one active state.

    Returns:
        The single active state at that address.
    """
    initial = project.factory.blank_state()
    initial.memory.store(initial.regs.sp, chain.payload_str() + b"AAAAAAAAA")
    # The first chain word becomes the starting instruction pointer.
    initial.ip = initial.stack_pop()

    simgr = project.factory.simulation_manager(initial)
    sentinel = 0x4141414141414141 % (1 << project.arch.bits)
    while simgr.one_active.addr != sentinel:
        simgr.step()
        assert len(simgr.active) == 1

    return simgr.one_active


def test_rop_x86_64():
    """Find gadgets in x86_64 ``datadep_test``, compare them with the cached reference, and run chains.

    If the cache file does not exist yet, it is seeded from the gadgets just found.
    """
    b = angr.Project(os.path.join(public_bin_location, "x86_64/datadep_test"), auto_load_libs=False)
    rop = b.analyses.ROP()
    rop.find_gadgets_single_threaded(show_progress=False)

    cache_path = os.path.join(test_data_location, "datadep_test_gadgets")
    if not os.path.exists(cache_path):
        rop.save_gadgets(cache_path)

    # check gadgets (the cache is local test data, so unpickling it is acceptable here)
    with open(cache_path, "rb") as f:
        tup = pickle.load(f)
    compare_gadgets(rop._all_gadgets, tup[0])

    # test creating a rop chain
    chain = rop.set_regs(rbp=0x1212, rbx=0x1234567890123456)
    # smallest possible chain
    assert chain.payload_len == 24
    # chain is correct
    result_state = execute_chain(b, chain)
    assert result_state.solver.eval(result_state.regs.rbp) == 0x1212
    assert result_state.solver.eval(result_state.regs.rbx) == 0x1234567890123456

    # test setting the filler value
    rop.set_roparg_filler(0x4141414141414141)
    chain = rop.set_regs(rbx=0x121212)
    assert chain._concretize_chain_values()[2][0] == 0x4141414141414141


def test_rop_i386_cgc():
    """Find gadgets in a CGC i386 binary, compare them with the cached reference, and run chains.

    Covers a register-setting chain and a memory-write chain. If the cache file does not
    exist yet, it is seeded from the gadgets just found.
    """
    b = angr.Project(os.path.join(public_bin_location, "cgc/sc1_0b32aa01_01"), auto_load_libs=False)
    rop = b.analyses.ROP()
    rop.find_gadgets_single_threaded(show_progress=False)

    cache_path = os.path.join(test_data_location, "0b32aa01_01_gadgets")
    if not os.path.exists(cache_path):
        rop.save_gadgets(cache_path)

    # check gadgets: load from the same cache_path used above instead of rebuilding it
    with open(cache_path, "rb") as f:
        tup = pickle.load(f)
    compare_gadgets(rop._all_gadgets, tup[0])

    # test creating a rop chain
    chain = rop.set_regs(ebx=0x98765432, ecx=0x12345678)
    # smallest possible chain
    assert chain.payload_len == 12
    # chain is correct
    result_state = execute_chain(b, chain)
    assert result_state.solver.eval(result_state.regs.ebx) == 0x98765432
    assert result_state.solver.eval(result_state.regs.ecx) == 0x12345678

    # test memwrite chain
    chain = rop.write_to_mem(0x41414141, b"ABCDEFGH")
    result_state = execute_chain(b, chain)
    assert result_state.solver.eval(result_state.memory.load(0x41414141, 8), cast_to=bytes) == b"ABCDEFGH"

def test_rop_arm():
    """Find gadgets in the armel ``manysum`` binary, compare them with the cached reference, and run chains.

    Covers register-setting chains and a memory-write chain. If the cache file does not
    exist yet, it is seeded from the gadgets just found.
    """
    b = angr.Project(os.path.join(public_bin_location, "armel/manysum"), load_options={"auto_load_libs": False})
    rop = b.analyses.ROP()
    rop.find_gadgets_single_threaded(show_progress=False)

    cache_path = os.path.join(test_data_location, "arm_manysum_test_gadgets")
    if not os.path.exists(cache_path):
        rop.save_gadgets(cache_path)

    # check gadgets: load from the same cache_path used above instead of rebuilding it
    with open(cache_path, "rb") as f:
        tup = pickle.load(f)
    compare_gadgets(rop._all_gadgets, tup[0])

    # test creating a rop chain
    chain = rop.set_regs(r11=0x99887766)
    # smallest possible chain
    assert chain.payload_len == 8
    # correct chains, using a more complicated chain here
    chain = rop.set_regs(r4=0x99887766, r9=0x44556677, r11=0x11223344)
    result_state = execute_chain(b, chain)
    assert result_state.solver.eval(result_state.regs.r4) == 0x99887766
    assert result_state.solver.eval(result_state.regs.r9) == 0x44556677
    assert result_state.solver.eval(result_state.regs.r11) == 0x11223344

    # test memwrite chain
    chain = rop.write_to_mem(0x41414141, b"ABCDEFGH")
    result_state = execute_chain(b, chain)
    assert result_state.solver.eval(result_state.memory.load(0x41414141, 8), cast_to=bytes) == b"ABCDEFGH"

def test_roptest_x86_64():
    """Build an execve("/bin/sh") chain on the tiny x86_64 ``roptest`` binary and check its gadget sequence."""
    p = angr.Project(os.path.join(public_bin_location, "x86_64/roptest"), auto_load_libs=False)
    r = p.analyses.ROP(only_check_near_rets=False)
    r.find_gadgets_single_threaded(show_progress=False)
    c = r.execve(path=b"/bin/sh")

    # verifying this is a giant pain, partially because the binary is so tiny, and there's no code beyond the syscall
    assert len(c._gadgets) == 8

    # verify the chain is valid
    addrs = [gadget.addr for gadget in c._gadgets]
    # Slots 1 and 5 may hold either of two accepted gadgets; normalize them before the exact comparison.
    accepted = [0x4000b2, 0x4000bd]
    for slot in (1, 5):
        assert addrs[slot] in accepted
        addrs[slot] = 0x4000b2
    expected = [0x4000b0, 0x4000b2, 0x4000b4, 0x4000b0, 0x4000bb, 0x4000b2, 0x4000bf, 0x4000c1]
    assert addrs == expected

def test_roptest_mips():
    """Set three MIPS registers with a chain built for ``darpa_ping`` and verify them after execution."""
    proj = angr.Project(os.path.join(public_bin_location, "mipsel/darpa_ping"), auto_load_libs=False)
    rop = proj.analyses.ROP()
    rop.find_gadgets_single_threaded(show_progress=False)

    # The same mapping drives both the chain request and the checks below.
    wanted = {"s0": 0x41414141, "s1": 0x42424242, "v0": 0x43434343}
    chain = rop.set_regs(**wanted)
    result_state = execute_chain(proj, chain)
    for reg_name, value in wanted.items():
        assert result_state.solver.eval(getattr(result_state.regs, reg_name)) == value


def run_all():
    """Call every callable module-level ``test_*`` object, in alphabetical order by name."""
    # Snapshot matching names first so calling the tests cannot disturb the iteration.
    tests = {name: obj for name, obj in globals().items() if name.startswith('test_')}
    for name in sorted(tests):
        if callable(tests[name]):
            tests[name]()


if __name__ == "__main__":
    import sys

    # Enable DEBUG output from the angrop.rop logger while running tests by hand.
    logging.getLogger("angrop.rop").setLevel(logging.DEBUG)

    # Usage: python <this file> [name] -- runs test_<name>, or every test when no name is given.
    if len(sys.argv) <= 1:
        run_all()
    else:
        globals()['test_' + sys.argv[1]]()
