/*
 * Decompiled with CFR 0.152.
 */
package com.google.crypto.tink.subtle;

import com.google.crypto.tink.PublicKeySign;
import com.google.crypto.tink.subtle.EngineFactory;
import com.google.crypto.tink.subtle.Enums;
import com.google.crypto.tink.subtle.Random;
import com.google.crypto.tink.subtle.SubtleUtil;
import com.google.crypto.tink.subtle.Validators;
import java.math.BigInteger;
import java.security.GeneralSecurityException;
import java.security.KeyFactory;
import java.security.MessageDigest;
import java.security.interfaces.RSAPrivateCrtKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.RSAPublicKeySpec;
import javax.crypto.Cipher;

/**
 * {@link PublicKeySign} implementation that produces RSASSA-PSS signatures.
 *
 * <p>The message is first encoded by {@link #emsaPssEncode(byte[], int)}, which follows the
 * structure of EMSA-PSS-ENCODE (RFC 8017, section 9.1.1), and the encoded block is then run
 * through a raw ({@code RSA/ECB/NOPADDING}) private-key operation. Every signature is checked
 * against the matching public key before it is returned.
 */
public final class RsaSsaPssSignJce
implements PublicKeySign {
    /** JCE transformation used for the unpadded RSA private- and public-key operations. */
    private static final String RAW_RSA_ALGORITHM = "RSA/ECB/NOPADDING";
    /** Last octet of the encoded message (0xBC). */
    private static final byte TRAILER_FIELD = (byte)0xBC;

    private final RSAPrivateCrtKey privateKey;
    /** Public key rebuilt from the private key; supplies the modulus size and the self-check. */
    private final RSAPublicKey publicKey;
    /** Hash used for the message digest and for the digest over M'. */
    private final Enums.HashType sigHash;
    /** Hash handed to {@code SubtleUtil.mgf1} when the data-block mask is generated. */
    private final Enums.HashType mgf1Hash;
    /** Number of random salt bytes mixed into every signature. */
    private final int saltLength;

    /**
     * Creates a signer for the given private key and PSS parameters.
     *
     * @param priv RSA private key; its modulus size and public exponent are validated
     * @param sigHash hash applied to the message
     * @param mgf1Hash hash used by the MGF1 mask generation
     * @param saltLength salt length in bytes; must not be negative
     * @throws GeneralSecurityException if one of the validators rejects its input, if
     *     {@code saltLength} is negative, or if the public key cannot be derived
     */
    public RsaSsaPssSignJce(RSAPrivateCrtKey priv, Enums.HashType sigHash, Enums.HashType mgf1Hash, int saltLength) throws GeneralSecurityException {
        Validators.validateSignatureHash(sigHash);
        Validators.validateRsaModulusSize(priv.getModulus().bitLength());
        Validators.validateRsaPublicExponent(priv.getPublicExponent());
        // A negative salt length can never produce a signature; reject it here instead of
        // letting it surface later, while signing, as an array-size/index failure.
        if (saltLength < 0) {
            throw new GeneralSecurityException("salt length must not be negative: " + saltLength);
        }
        this.privateKey = priv;
        KeyFactory keyFactory = EngineFactory.KEY_FACTORY.getInstance("RSA");
        this.publicKey = (RSAPublicKey)keyFactory.generatePublic(new RSAPublicKeySpec(priv.getModulus(), priv.getPublicExponent()));
        this.sigHash = sigHash;
        this.mgf1Hash = mgf1Hash;
        this.saltLength = saltLength;
    }

    /**
     * Signs {@code data}.
     *
     * @param data message to sign
     * @return the raw RSA signature bytes
     * @throws GeneralSecurityException if encoding or the RSA operation fails
     */
    @Override
    public byte[] sign(byte[] data) throws GeneralSecurityException {
        int modBits = this.publicKey.getModulus().bitLength();
        // The encoded message is limited to modBits - 1 bits.
        byte[] encodedMessage = this.emsaPssEncode(data, modBits - 1);
        return this.rsasp1(encodedMessage);
    }

    /**
     * Applies the raw RSA private-key operation to {@code m3} and verifies the result by
     * applying the public-key operation and comparing with the input.
     *
     * @param m3 encoded message
     * @return the signature
     * @throws GeneralSecurityException if a cipher cannot be created, initialised or run
     * @throws RuntimeException if the public-key operation does not reproduce {@code m3}
     */
    private byte[] rsasp1(byte[] m3) throws GeneralSecurityException {
        Cipher decryptCipher = EngineFactory.CIPHER.getInstance(RAW_RSA_ALGORITHM);
        decryptCipher.init(Cipher.DECRYPT_MODE, this.privateKey);
        byte[] signature = decryptCipher.doFinal(m3);

        // Self-check: invert the private-key operation with the public key.
        Cipher encryptCipher = EngineFactory.CIPHER.getInstance(RAW_RSA_ALGORITHM);
        encryptCipher.init(Cipher.ENCRYPT_MODE, this.publicKey);
        byte[] recovered = encryptCipher.doFinal(signature);

        // Compare as non-negative integers so differing leading zero bytes do not matter.
        if (!new BigInteger(1, m3).equals(new BigInteger(1, recovered))) {
            throw new RuntimeException("Security bug: RSA signature computation error");
        }
        return signature;
    }

    /**
     * Builds the encoded message {@code maskedDB || H || 0xBC} for {@code m3}.
     *
     * @param m3 message to encode
     * @param emBits maximal bit length of the integer formed by the encoded message
     * @return the encoded message, {@code ceil(emBits / 8)} bytes long
     * @throws GeneralSecurityException if the hash is rejected or unavailable, or if the
     *     encoded message is too short to hold the hash, the salt and two extra bytes
     */
    private byte[] emsaPssEncode(byte[] m3, int emBits) throws GeneralSecurityException {
        // Hash the message.
        Validators.validateSignatureHash(this.sigHash);
        MessageDigest digest = EngineFactory.MESSAGE_DIGEST.getInstance(SubtleUtil.toDigestAlgo(this.sigHash));
        byte[] mHash = digest.digest(m3);

        // Check that hash, salt, the 0x01 separator and the trailer fit.
        int hLen = digest.getDigestLength();
        int emLen = (emBits - 1) / 8 + 1;
        if (emLen < hLen + this.saltLength + 2) {
            throw new GeneralSecurityException("encoding error");
        }

        // Fresh random salt for every signature.
        byte[] salt = Random.randBytes(this.saltLength);

        // M' = eight zero bytes || mHash || salt.
        byte[] mPrime = new byte[8 + hLen + this.saltLength];
        System.arraycopy(mHash, 0, mPrime, 8, hLen);
        System.arraycopy(salt, 0, mPrime, 8 + hLen, salt.length);

        // H = Hash(M').
        byte[] h = digest.digest(mPrime);

        // DB = zero padding || 0x01 || salt.
        int dbLen = emLen - hLen - 1;
        byte[] db = new byte[dbLen];
        db[dbLen - this.saltLength - 1] = 1;
        System.arraycopy(salt, 0, db, dbLen - this.saltLength, salt.length);

        // maskedDB = DB xor MGF1(H, dbLen).
        byte[] dbMask = SubtleUtil.mgf1(h, dbLen, this.mgf1Hash);
        byte[] maskedDb = new byte[dbLen];
        for (int i = 0; i < dbLen; ++i) {
            maskedDb[i] = (byte)(db[i] ^ dbMask[i]);
        }

        // Clear the leftmost (8 * emLen - emBits) bits of maskedDB.
        long bitsToClear = (long)emLen * 8L - (long)emBits;
        for (int i = 0; (long)i < bitsToClear; ++i) {
            int bytePos = i / 8;
            int bitPos = 7 - i % 8;
            maskedDb[bytePos] = (byte)(maskedDb[bytePos] & ~(1 << bitPos));
        }

        // EM = maskedDB || H || trailer.
        byte[] encoded = new byte[maskedDb.length + hLen + 1];
        System.arraycopy(maskedDb, 0, encoded, 0, maskedDb.length);
        System.arraycopy(h, 0, encoded, maskedDb.length, h.length);
        encoded[maskedDb.length + hLen] = TRAILER_FIELD;
        return encoded;
    }
}

