/* $Id: PKCS1_PSS.java,v 1.1 2011/05/04 22:37:58 willuhn Exp $
This file is part of CryptAlgs4Java
Copyright (C) 2001-2010 Stefan Palme
CryptAlgs4Java is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
CryptAlgs4Java is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*/
package org.kapott.cryptalgs;
import java.io.ByteArrayOutputStream;
import java.math.BigInteger;
import java.security.InvalidAlgorithmParameterException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.SignatureException;
import java.security.SignatureSpi;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.AlgorithmParameterSpec;
import java.util.Arrays;
public class PKCS1_PSS
extends SignatureSpi
{
private RSAPublicKey pubKey;
private PrivateKey privKey;
private SignatureParamSpec param;
private ByteArrayOutputStream plainmsg;
// ----- some interface stuff ---------------------------------------------
@Override
@Deprecated
protected void engineSetParameter(String param1, Object value)
{
// do nothing
}
@Override
protected void engineSetParameter(AlgorithmParameterSpec param1)
throws InvalidAlgorithmParameterException
{
if (param1 instanceof SignatureParamSpec)
this.param=(SignatureParamSpec)(param1);
else {
throw new InvalidAlgorithmParameterException();
}
}
@Override
@Deprecated
protected Object engineGetParameter(String parameter)
{
return null;
}
public static MessageDigest getMessageDigest(SignatureParamSpec spec)
{
MessageDigest result;
try {
String provider=spec.getProvider();
if (provider!=null) {
result=MessageDigest.getInstance(spec.getHashAlg(),provider);
} else {
result=MessageDigest.getInstance(spec.getHashAlg());
}
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException(e);
} catch (NoSuchProviderException e) {
throw new RuntimeException(e);
}
return result;
}
@Override
protected void engineInitSign(PrivateKey privateKey)
{
this.privKey=privateKey;
this.plainmsg=new ByteArrayOutputStream();
}
@Override
protected void engineInitVerify(PublicKey publicKey)
{
this.pubKey=(RSAPublicKey)publicKey;
this.plainmsg=new ByteArrayOutputStream();
}
@Override
protected void engineUpdate(byte b)
{
this.plainmsg.write(b);
}
@Override
protected void engineUpdate(byte[] b,int offset,int length)
{
for (int i=0;i<length;i++) {
engineUpdate(b[offset+i]);
}
}
@Override
protected int engineSign(byte[] output,int offset,int len)
throws SignatureException
{
byte[] sig=engineSign();
if (offset+len>output.length)
throw new SignatureException("output result too large for buffer");
System.arraycopy(sig,0,output,offset,sig.length);
return sig.length;
}
@Override
protected byte[] engineSign()
{
return pss_sign(this.privKey, this.plainmsg.toByteArray());
}
@Override
protected boolean engineVerify(byte[] sig)
{
return pss_verify(this.pubKey, this.plainmsg.toByteArray(), sig);
}
// --- stuff from the PKCS#1-PSS specification ---------------------------
private static byte[] i2os(BigInteger x, int outLen)
{
byte[] bytes=x.toByteArray();
if (bytes.length>outLen) {
// created output len does not fit into outLen
// maybe this are only leading zeroes, so we will check this
for (int i=0; i<(bytes.length-outLen); i++) {
if (bytes[i]!=0) {
throw new RuntimeException("value too large");
}
}
// ok, now remove leading zeroes
byte[] out=new byte[outLen];
System.arraycopy(bytes, bytes.length-outLen, out, 0, outLen);
bytes = out;
} else if (bytes.length<outLen) {
// created output is too small, so create leading zeroes
byte[] out=new byte[outLen];
System.arraycopy(bytes, 0, out, outLen-bytes.length, bytes.length);
bytes = out;
}
return bytes;
}
private static BigInteger os2i(byte[] bytes)
{
return new BigInteger(+1, bytes);
}
private static BigInteger sp1(PrivateKey key, BigInteger m)
{
BigInteger result;
if (key instanceof RSAPrivateKey) {
BigInteger d=((RSAPrivateKey)key).getPrivateExponent();
BigInteger n=((RSAPrivateKey)key).getModulus();
result = m.modPow(d,n);
} else {
RSAPrivateCrtKey2 key2=(RSAPrivateCrtKey2)key;
BigInteger p=key2.getP();
BigInteger q=key2.getQ();
BigInteger dP=key2.getdP();
BigInteger dQ=key2.getdQ();
BigInteger qInv=key2.getQInv();
BigInteger s1 = m.modPow(dP,p);
BigInteger s2 = m.modPow(dQ,q);
BigInteger h = s1.subtract(s2).multiply(qInv).mod(p);
result = s2.add(q.multiply(h));
}
return result;
}
private static BigInteger vp1(RSAPublicKey key, BigInteger s)
{
BigInteger e=key.getPublicExponent();
BigInteger n=key.getModulus();
BigInteger m=s.modPow(e,n);
return m;
}
private static byte[] concat(byte[] x1, byte[] x2)
{
byte[] result=new byte[x1.length+x2.length];
System.arraycopy(x1,0, result,0, x1.length);
System.arraycopy(x2,0, result,x1.length, x2.length);
return result;
}
private static byte[] hash(SignatureParamSpec spec, byte[] data)
{
MessageDigest dig=getMessageDigest(spec);
dig.reset();
return dig.digest(data);
}
private static byte[] mgf1(SignatureParamSpec spec, byte[] mgfSeed, int maskLen)
{
MessageDigest dig=getMessageDigest(spec);
int hLen=dig.getDigestLength();
byte[] T=new byte[0];
for (int i=0; i<Math.ceil(maskLen/(double)hLen); i++) {
byte[] c=i2os(new BigInteger(Integer.toString(i)), 4);
T = concat(T, hash(spec, concat(mgfSeed,c)));
}
byte[] result=new byte[maskLen];
System.arraycopy(T,0, result,0, maskLen);
return result;
}
private static byte[] random_os(int len)
{
byte[] result=new byte[len];
for (int i=0; i<len; i++) {
result[i]=(byte)(256*Math.random());
}
return result;
}
private static byte[] xor_os(byte[] a1, byte[] a2)
{
if (a1.length!=a2.length) {
throw new RuntimeException("a1.len != a2.len");
}
byte[] result=new byte[a1.length];
for (int i=0; i<result.length; i++) {
result[i] = (byte)(a1[i] ^ a2[i]);
}
return result;
}
public static byte[] emsa_pss_encode(SignatureParamSpec spec, byte[] msg, int emBits)
{
int emLen=emBits>>3;
if ((emBits&7) != 0) {
emLen++;
}
// System.out.println("message: "+Utils.bytes2String(msg));
byte[] mHash = hash(spec, msg);
// System.out.println("mHash: "+Utils.bytes2String(mHash));
MessageDigest dig=getMessageDigest(spec);
int hLen = dig.getDigestLength();
int sLen = hLen;
byte[] salt = random_os(sLen);
byte[] zeroes = new byte[8];
byte[] m2 = concat(concat(zeroes,mHash),salt);
// System.out.println("M': "+Utils.bytes2String(m2));
byte[] H = hash(spec, m2);
// System.out.println("H: "+Utils.bytes2String(H));
byte[] PS = new byte[emLen-sLen-hLen-2];
byte[] DB = concat(concat(PS, new byte[] {0x01}), salt);
// System.out.println("DB: "+Utils.bytes2String(DB));
byte[] dbMask = mgf1(spec, H, emLen-hLen-1);
// System.out.println("dbMask: "+Utils.bytes2String(dbMask));
byte[] maskedDB = xor_os(DB, dbMask);
// System.out.println("maskedDB: "+Utils.bytes2String(maskedDB));
// set leftmost X bits in maskedDB to zero
int tooMuchBits=(emLen<<3)-emBits;
byte mask=(byte)(0xFF>>>tooMuchBits);
maskedDB[0] &= mask;
byte[] EM = concat(concat(maskedDB,H), new byte[] {(byte)0xBC});
// System.out.println("EM: "+Utils.bytes2String(EM));
return EM;
}
public static boolean emsa_pss_verify(SignatureParamSpec spec, byte[] msg, byte[] EM, int emBits)
{
int emLen=emBits>>3;
if ((emBits&7) != 0) {
emLen++;
}
byte[] mHash = hash(spec, msg);
// System.out.println("mHash: "+Utils.bytes2String(mHash));
MessageDigest dig=getMessageDigest(spec);
int hLen = dig.getDigestLength();
// System.out.println("hLen: "+hLen);
int sLen = hLen;
if (EM[EM.length-1]!=(byte)0xBC) {
// System.out.println("no BC at the end");
return false;
}
byte[] maskedDB = new byte[emLen-hLen-1];
byte[] H = new byte[hLen];
System.arraycopy(EM,0, maskedDB,0, emLen-hLen-1);
System.arraycopy(EM,emLen-hLen-1, H,0, hLen);
// TODO: verify if first X bits of maskedDB are zero
byte[] dbMask = mgf1(spec, H, emLen-hLen-1);
byte[] DB = xor_os(maskedDB, dbMask);
// set leftmost X bits of DB to zero
int tooMuchBits=(emLen<<3)-emBits;
byte mask=(byte)(0xFF>>>tooMuchBits);
DB[0] &= mask;
// TODO: another consistency check
byte[] salt = new byte[sLen];
System.arraycopy(DB,DB.length-sLen, salt,0, sLen);
byte[] zeroes = new byte[8];
byte[] m2 = concat(concat(zeroes,mHash),salt);
byte[] H2 = hash(spec, m2);
return Arrays.equals(H,H2);
}
public static int calculateEMBitLen(BigInteger modulus)
{
return modulus.bitLength()-1;
}
private byte[] pss_sign(PrivateKey key, byte[] msg)
{
// Modulus holen, weil dessen Bitlänge benötigt wird
BigInteger bModulus;
if (key instanceof RSAPrivateKey) {
bModulus=((RSAPrivateKey)key).getModulus();
} else {
bModulus=((RSAPrivateCrtKey2)key).getP().multiply(((RSAPrivateCrtKey2)key).getQ());
}
int modBits = bModulus.bitLength();
int k = modBits>>3;
if ((modBits&7) != 0) {
k++;
}
byte[] EM = emsa_pss_encode(this.param, msg, modBits-1);
BigInteger m = os2i(EM);
BigInteger s = sp1(key, m);
byte[] S = i2os(s, k);
// System.out.println("S: "+Utils.bytes2String(S));
return S;
}
private boolean pss_verify(RSAPublicKey key, byte[] msg, byte[] S)
{
BigInteger s = os2i(S);
BigInteger m = vp1(key, s);
BigInteger n = key.getModulus();
int emBits = n.bitLength()-1;
int emLen = emBits>>3;
if ((emBits&7)!=0) {
emLen++;
}
byte[] EM = i2os(m, emLen);
// System.out.println("EM: "+Utils.bytes2String(EM));
return emsa_pss_verify(this.param, msg, EM, emBits);
}
}