package com.trilead.ssh2.signature; import com.trilead.ssh2.IOWarningException; import com.trilead.ssh2.crypto.CertificateDecoder; import com.trilead.ssh2.crypto.PEMStructure; import com.trilead.ssh2.crypto.SimpleDERReader; import com.trilead.ssh2.packets.TypesReader; import com.trilead.ssh2.packets.TypesWriter; import java.io.IOException; import java.math.BigInteger; import java.security.GeneralSecurityException; import java.security.KeyFactory; import java.security.KeyPair; import java.security.PrivateKey; import java.security.PublicKey; import java.security.interfaces.RSAPublicKey; import java.security.interfaces.RSAPrivateKey; import java.security.spec.RSAPrivateCrtKeySpec; import java.security.spec.RSAPrivateKeySpec; import java.security.spec.RSAPublicKeySpec; import java.util.Arrays; import java.util.List; /** * @author Michael Clarke */ public class RSAKeyAlgorithm extends KeyAlgorithm<RSAPublicKey, RSAPrivateKey> { public RSAKeyAlgorithm() { super("SHA1WithRSA", "ssh-rsa", RSAPrivateKey.class); } @Override public byte[] encodeSignature(byte[] signature) throws IOException { final TypesWriter tw = new TypesWriter(); tw.writeString(getKeyFormat()); /* S is NOT an MPINT. "The value for 'rsa_signature_blob' is encoded as a string * containing s (which is an integer, without lengths or padding, unsigned and in * network byte order)." */ /* Remove first zero sign byte, if present */ if ((signature.length > 1) && (signature[0] == 0x00)) { tw.writeString(signature, 1, signature.length - 1); } else { tw.writeString(signature, 0, signature.length); } return tw.getBytes(); } @Override public byte[] decodeSignature(byte[] encodedSignature) throws IOException { final TypesReader tr = new TypesReader(encodedSignature); final String sig_format = tr.readString(); if (!sig_format.equals(getKeyFormat())) { throw new IOException("Peer sent wrong signature format"); } /* S is NOT an MPINT. "The value for 'rsa_signature_blob' is encoded as a string * containing s (which is an integer, without lengths or padding, unsigned and in * network byte order)." See also below. */ final byte[] s = tr.readByteString(); if (s.length == 0) { throw new IOException("Error in RSA signature, S is empty."); } if (tr.remain() != 0) { throw new IOException("Padding in RSA signature!"); } return s; } @Override public byte[] encodePublicKey(RSAPublicKey publicKey) throws IOException { final TypesWriter tw = new TypesWriter(); tw.writeString(getKeyFormat()); tw.writeMPInt(publicKey.getPublicExponent()); tw.writeMPInt(publicKey.getModulus()); return tw.getBytes(); } @Override public RSAPublicKey decodePublicKey(byte[] encodedPublicKey) throws IOException { final TypesReader tr = new TypesReader(encodedPublicKey); final String key_format = tr.readString(); if (!key_format.equals(getKeyFormat())) { throw new IOWarningException("Unsupported key format found '" + key_format + "' while expecting " + getKeyFormat()); } final BigInteger e = tr.readMPINT(); final BigInteger n = tr.readMPINT(); if (tr.remain() != 0) { throw new IOException("Padding in RSA public key!"); } try { final KeyFactory generator = KeyFactory.getInstance("RSA"); return (RSAPublicKey) generator.generatePublic(new RSAPublicKeySpec(n, e)); } catch (GeneralSecurityException ex) { throw new IOException("Could not generate RSA key", ex); } } @Override public List<CertificateDecoder> getCertificateDecoders() { return Arrays.asList(new RSACertificateDecoder(), new OpenSshCertificateDecoder("ssh-rsa") { @Override KeyPair generateKeyPair(TypesReader typesReader) throws GeneralSecurityException, IOException { BigInteger n = typesReader.readMPINT(); BigInteger e = typesReader.readMPINT(); BigInteger d = typesReader.readMPINT(); BigInteger c = typesReader.readMPINT(); BigInteger p = typesReader.readMPINT(); RSAPublicKeySpec publicKeySpec = new RSAPublicKeySpec(n, e); RSAPrivateKeySpec privateKeySpec; if (null == p || null == c) { privateKeySpec = new RSAPrivateKeySpec(n, d); } else { BigInteger q = c.modInverse(p); BigInteger pE = d.mod(p.subtract(BigInteger.ONE)); BigInteger qE = d.mod(q.subtract(BigInteger.ONE)); privateKeySpec = new RSAPrivateCrtKeySpec(n, e, d, p, q, pE, qE, c); } KeyFactory factory = KeyFactory.getInstance("RSA"); return new KeyPair(factory.generatePublic(publicKeySpec), factory.generatePrivate(privateKeySpec)); } }); } private static class RSACertificateDecoder extends CertificateDecoder { @Override public String getStartLine() { return "-----BEGIN RSA PRIVATE KEY-----"; } @Override public String getEndLine() { return "-----END RSA PRIVATE KEY-----"; } @Override protected KeyPair createKeyPair(PEMStructure pemStructure) throws IOException { SimpleDERReader dr = new SimpleDERReader(pemStructure.getData()); byte[] seq = dr.readSequenceAsByteArray(); if (dr.available() != 0) throw new IOException("Padding in RSA PRIVATE KEY DER stream."); dr.resetInput(seq); BigInteger version = dr.readInt(); if ((version.compareTo(BigInteger.ZERO) != 0) && (version.compareTo(BigInteger.ONE) != 0)) throw new IOException("Wrong version (" + version + ") in RSA PRIVATE KEY DER stream."); BigInteger n = dr.readInt(); BigInteger e = dr.readInt(); BigInteger d = dr.readInt(); BigInteger p = dr.readInt(); BigInteger q = dr.readInt(); BigInteger pE = dr.readInt(); BigInteger qE = dr.readInt(); BigInteger c = dr.readInt(); try { RSAPrivateKeySpec privateKeySpec = new RSAPrivateCrtKeySpec(n, e, d, p, q, pE, qE, c); RSAPublicKeySpec publicKeySpec = new RSAPublicKeySpec(n, e); KeyFactory factory = KeyFactory.getInstance("RSA"); PrivateKey privateKey = factory.generatePrivate(privateKeySpec); PublicKey publicKey = factory.generatePublic(publicKeySpec); return new KeyPair(publicKey, privateKey); } catch (GeneralSecurityException ex) { throw new IOException("Could not decode RSA Key Pair"); } } } }