'''
A collection of tools to assist in analyzing encrypted data
through chosen-ciphertext attacks.

Copyright (C) 2012-2013 Virtual Security Research, LLC
Copyright (C) 2014-2016 Blindspot Security LLC
Author: Timothy D. Morgan

 This program is free software: you can redistribute it and/or modify
 it under the terms of the GNU Lesser General Public License, version 3,
 as published by the Free Software Foundation.

 This program is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 GNU Lesser General Public License for more details.

 You should have received a copy of the GNU Lesser General Public License
 along with this program.  If not, see <http://www.gnu.org/licenses/>.
'''

import sys
import threading
import struct
import queue
import hashlib
import codecs


def escape_handler(error):
    '''Codec error handler that spells out undecodable bytes as hex escapes.

    Each byte in the failing span error.object[error.start:error.end] is
    replaced by a backslash, an 'x', and two uppercase hex digits
    (e.g. the byte 0xff becomes the four characters \\xFF).  Decoding then
    resumes at error.end.  Registered below under the name 'decode_escape'.
    '''
    offending = error.object[error.start:error.end]
    replacement = ''.join("\\x%.2X" % byte for byte in offending)
    return (replacement, error.end)

codecs.register_error('decode_escape', escape_handler)


# Wish Python had a better function for this that escaped more characters
_html_escape_table = {
    "&":  "&amp;",
    '"':  "&quot;",
    "'":  "&apos;",
    ">":  "&gt;",
    "<":  "&lt;",
    "\n": "&#x0a;",
    "\r": "&#x0d;",
    }

def _html_escape(text):
    return "".join(_html_escape_table.get(c,c) for c in text)


class ProbeResults:
    '''Collected responses from a probe_bytes() run, renderable as HTML.

    The raw results live in _raw_table, a dict keyed first by byte offset in
    the ciphertext and then by the value XORed into that byte; each entry is
    whatever the checker returned (str or bytes) for that modified
    ciphertext.
    '''
    _values = None     # probed values; toHTML renders one row per value
    _raw_table = None  # indexes are byte offset, then XORed value
    _messages = None   # NOTE(review): not read anywhere in this module; kept for compatibility
    _html_header = """<head>
<script>
function displayMessage(id) 
{
  alert(document.getElementById(id).value);
}
</script>
<style>
td
{
  border-style: solid;
  border-width: medium;
  border-color: #FFFFFF;
  border-spacing: 0px;
  min-width: 100px;
  max-width: 100px;
  word-wrap: break-word;
}
</style></head>"""

    def __init__(self, ct_length, values):
        '''
        Arguments:
        ct_length -- Length of the probed ciphertext; toHTML renders one
                     column per offset in range(ct_length).
        values -- Sequence of integers that were XORed into each byte;
                  toHTML renders one row per value.
        '''
        self._ct_length = ct_length
        self._values = values
        self._raw_table = {}
        self._messages = {}

    def _generate_colors(self, s):
        '''Derive two "#RRGGBB" colors deterministically from a response.

        Identical responses get identical colors, so equivalent behaviors are
        grouped visually in the HTML table.  str input is UTF-8 encoded
        before hashing; anything else is handed to MD5 as-is.

        Returns (color1, color2), built from the first six digest bytes.
        '''
        if isinstance(s, str):
            s = s.encode('utf-8')
        digest = hashlib.md5(s).digest()[:6]
        color1 = "#%.2X%.2X%.2X" % tuple(digest[:3])
        color2 = "#%.2X%.2X%.2X" % tuple(digest[3:])
        return color1, color2

    def toHTML(self, maxlen=20):
        '''Render the results as an HTML fragment (a <head> plus a table).

        The table has one column per offset in range(ct_length) and one row
        per probed value.  Each cell is colored by a hash of the checker's
        response.  It shows the response truncated to maxlen characters, and
        clicking it pops up the full response.

        Arguments:
        maxlen -- Number of characters of each response shown in its cell
                  before it is cut off and suffixed with '...'.

        Raises KeyError if _raw_table lacks an entry for some
        (offset, value) pair.
        '''
        parts = [self._html_header,
                 '<table><tr><td>&nbsp;&nbsp;&nbsp;OFFSET<br /><br />VALUE</td>']
        offsets = range(self._ct_length)

        # Label columns from the same offsets the rows iterate, so the
        # header always lines up with the cells beneath it.
        for offset in offsets:
            parts.append('<td>%d<br /><br /></td>' % offset)
        parts.append('</tr>')

        for v in self._values:
            parts.append('<tr><td>0x%.2X</td>' % v)
            for offset in offsets:
                message = self._raw_table[offset][v]
                # Colors are derived from the raw response, before decoding.
                bg, fg = self._generate_colors(message)
                if not isinstance(message, str):
                    # Invalid UTF-8 bytes are rendered through the
                    # 'decode_escape' error handler registered in this module.
                    message = message.decode('utf-8', 'decode_escape')

                truncated = message[:maxlen]
                if len(message) > maxlen:
                    truncated += '...'
                msg_id = 'cell_%.2X_%.2X' % (offset, v)
                parts.append(('''<td style="background-color:%s; border-color:%s" onclick="displayMessage('%s')">'''
                              '''<input type="hidden" id="%s" value="%s" />%s</td>\n''')
                             % (bg, fg, msg_id, msg_id, _html_escape(message), _html_escape(truncated)))
            parts.append('</tr>')

        parts.append('</table>')
        # A single join avoids the quadratic cost of repeated += on strings.
        return ''.join(parts)


def probe_bytes(checker, ciphertext, values, max_threads=1):
    '''For each offset in the ciphertext, XORs each of the values with
    it and sends it to the checker to determine what kind of response or
    error message was generated.

    Arguments:
    checker -- A function which sends a specified ciphertext to the targeted
               application and returns a string describing the kind of response
               that was encountered.  This function should be thread-safe when
               max_threads > 1.

               This function should implement the prototype:
                 def myChecker(ciphertext): ...

               It receives a fresh bytearray for every probe.

               The function should return strings that are relevant to
               the kind of overall response generated by the targeted
               system or application.  For instance, if detailed error
               messages are returned, then the important parts of those
               errors should be returned.  If error messages are not
               returned in some cases, then simple tokens that describe
               the behavior of the response should suffice.  For
               instance, if in some cases the application returns a
               generic HTTP 500 error, in other cases it drops the TCP
               connection, and still in other cases it doesn't return an
               error, then the checker function could return "500",
               "dropped", and "success" respectively for those cases.

    ciphertext -- A ciphertext buffer (bytes/bytearray) that will be repeatedly
               modified and tested using the checker function.

    values --  A sequence of integers in the range [0..255].  These values
               will be XORed with each byte in the ciphertext and tested, one
               after another.  To make a single change to each byte in the
               ciphertext, provide something like [1].  To flip every bit
               in the entire ciphertext individually, supply: [1,2,4,8,16,32,64,128]

    max_threads -- The maximum number of threads to run in parallel while
               testing modified ciphertexts.

    Returns a ProbeResults holding the checker's response for every
    (offset, value) pair, or None if max_threads < 1.

    If the checker raises an exception, no further probes are started and
    the first such exception is re-raised here once running probes finish.
    '''
    if max_threads < 1:
        return None

    ciphertext = bytearray(ciphertext)
    values = bytearray(values)

    ret_val = ProbeResults(len(ciphertext), values)
    # XXX: add functions to ProbeResults class to add results here,
    #      rather than accessing members directly.
    for offset in range(len(ciphertext)):
        ret_val._raw_table[offset] = {}

    # One job per (offset, value) pair, produced lazily in offset-major
    # order so memory stays flat however many probes there are.
    jobs = ((offset, value)
            for offset in range(len(ciphertext))
            for value in values)
    # Guards `jobs` (generators are not thread-safe), `errors`, and the
    # result table.
    lock = threading.Lock()
    errors = []

    def _worker():
        # Persistent worker: keeps pulling jobs until none are left or some
        # worker has recorded a failure.
        while True:
            with lock:
                if errors:
                    return
                job = next(jobs, None)
            if job is None:
                return
            offset, value = job
            modified = bytearray(ciphertext)
            modified[offset] ^= value
            try:
                response = checker(modified)
            except Exception as e:
                with lock:
                    errors.append(e)
                return
            with lock:
                ret_val._raw_table[offset][value] = response

    # Never start more threads than there are probes to run.
    num_threads = min(max_threads, len(ciphertext) * len(values))
    threads = [threading.Thread(target=_worker) for _ in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    if errors:
        # Surface checker failures instead of returning a partial table.
        raise errors[0]

    return ret_val


def probe_worker(checker, prefix, suffix, target, value_subset, results):
    '''Probe a single byte position with each value in value_subset.

    For every value, the buffer prefix + (target XOR value) + suffix is passed
    to checker, and a single-entry dict {value: response} is put on results.

    Arguments:
    checker -- Callable that takes the modified ciphertext and returns a
               response.
    prefix, suffix -- Ciphertext bytes before and after the probed byte.
    target -- Original integer value of the probed byte.
    value_subset -- Iterable of integers to XOR into target.
    results -- Object with a put() method (e.g. a queue.Queue) that receives
               each response.
    '''
    for value in value_subset:
        flipped = bytearray([value ^ target])
        response = checker(prefix + flipped + suffix)
        results.put({value: response})
