/*
 * Copyright (C)2009 - SSHJ Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package net.schmizz.sshj.transport.kex;

import net.schmizz.sshj.common.*;
import net.schmizz.sshj.signature.Signature;
import net.schmizz.sshj.transport.Transport;
import net.schmizz.sshj.transport.TransportException;
import net.schmizz.sshj.transport.digest.Digest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.crypto.spec.DHParameterSpec;
import java.math.BigInteger;
import java.security.GeneralSecurityException;

/**
 * Base class for Diffie-Hellman key exchanges in which the group (p, g) is supplied by the
 * server instead of being fixed in advance.
 *
 * <p>The exchange implemented here has three steps:
 * <ol>
 *   <li>{@link #init} sends {@code KEX_DH_GEX_REQUEST} carrying the minimum, preferred and
 *       maximum group sizes in bits;</li>
 *   <li>on {@code KEXDH_31} the server's p and g are read, the size of p is checked against the
 *       requested range, and {@code KEX_DH_GEX_INIT} is sent with the client's public value;</li>
 *   <li>on {@code KEX_DH_GEX_REPLY} the shared secret is computed, the exchange hash is derived
 *       and the server's signature over that hash is verified.</li>
 * </ol>
 *
 * <p>Subclasses choose the digest by passing it to the constructor.
 */
public abstract class AbstractDHGex extends AbstractDH {

    // Per-instance logger so that messages are attributed to the concrete subclass.
    private final Logger log = LoggerFactory.getLogger(getClass());

    /** Smallest group size, in bits, that is requested and accepted. */
    private final int minBits = 1024;

    /** Largest group size, in bits, that is requested and accepted. */
    private final int maxBits = 8192;

    /** Group size, in bits, that is sent as the preferred size. */
    private final int preferredBits = 2048;

    /**
     * Creates the exchange with a fresh {@link DH} instance and the given digest.
     *
     * @param digest digest used to compute the exchange hash
     */
    public AbstractDHGex(Digest digest) {
        super(new DH(), digest);
    }

    /**
     * Delegates to the superclass initialisation, initialises the digest and sends
     * {@code KEX_DH_GEX_REQUEST} with the minimum, preferred and maximum group sizes.
     *
     * @throws GeneralSecurityException if thrown by the superclass or the digest
     * @throws TransportException       if the request cannot be written
     */
    @Override
    public void init(Transport trans, String V_S, String V_C, byte[] I_S, byte[] I_C)
            throws GeneralSecurityException, TransportException {
        super.init(trans, V_S, V_C, I_S, I_C);
        digest.init();

        log.debug("Sending {}", Message.KEX_DH_GEX_REQUEST);
        // Field order on the wire: min, preferred, max.
        trans.write(new SSHPacket(Message.KEX_DH_GEX_REQUEST)
                .putUInt32(minBits)
                .putUInt32(preferredBits)
                .putUInt32(maxBits));
    }

    /**
     * Handles the next key exchange message from the server.
     *
     * @param msg    type of the received message
     * @param buffer packet holding the message payload
     * @return {@code false} after the group message has been processed, {@code true} after the
     *         reply has been processed and its signature verified
     * @throws GeneralSecurityException if the received group size is out of range, or a
     *                                  cryptographic operation fails
     * @throws TransportException       if the message type is unexpected, the packet cannot be
     *                                  parsed, or signature verification fails
     */
    @Override
    public boolean next(Message msg, SSHPacket buffer)
            throws GeneralSecurityException, TransportException {
        log.debug("Got message {}", msg);
        try {
            switch (msg) {
                case KEXDH_31:
                    return parseGexGroup(buffer);
                case KEX_DH_GEX_REPLY:
                    return parseGexReply(buffer);
                default:
                    throw new TransportException("Unexpected message " + msg);
            }
        } catch (Buffer.BufferException be) {
            // A malformed packet is reported to the caller as a transport-level failure.
            throw new TransportException(be);
        }
    }

    /**
     * Processes {@code KEX_DH_GEX_REPLY}: reads the host key blob, the server's public value and
     * the signature, computes the shared secret and the exchange hash, and verifies the signature
     * over that hash with the host key.
     *
     * @param buffer packet holding the reply payload
     * @return always {@code true}
     * @throws Buffer.BufferException   if the packet or the host key blob cannot be read
     * @throws GeneralSecurityException if a cryptographic operation fails
     * @throws TransportException       if no signature implementation is available for the host
     *                                  key type, or the signature does not verify
     */
    private boolean parseGexReply(SSHPacket buffer)
            throws Buffer.BufferException, GeneralSecurityException, TransportException {
        byte[] K_S = buffer.readBytes();
        byte[] f = buffer.readBytes();
        byte[] sig = buffer.readBytes();
        hostKey = new Buffer.PlainBuffer(K_S).readPublicKey();

        dh.computeK(f);
        BigInteger k = dh.getK();

        // Input to the exchange hash; the order of the fields below is significant.
        // initializedBuffer() is inherited and supplies whatever precedes K_S.
        final Buffer.PlainBuffer buf = initializedBuffer()
                .putString(K_S)
                .putUInt32(minBits)
                .putUInt32(preferredBits)
                .putUInt32(maxBits)
                .putMPInt(((DH) dh).getP())
                .putMPInt(((DH) dh).getG())
                .putBytes(dh.getE())
                .putBytes(f)
                .putMPInt(k);
        digest.update(buf.array(), buf.rpos(), buf.available());
        H = digest.digest();

        final String keyTypeName = KeyType.fromKey(hostKey).toString();
        Signature signature =
                Factory.Named.Util.create(trans.getConfig().getSignatureFactories(), keyTypeName);
        // NOTE(review): Factory.Named.Util.create is not visible in this file; it appears to
        // return null when no configured factory matches the name. Guard so that this case is
        // reported as a key exchange failure rather than a NullPointerException.
        if (signature == null) {
            throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED,
                    "No signature implementation available for host key type " + keyTypeName);
        }
        signature.init(hostKey, null);
        signature.update(H, 0, H.length);
        if (!signature.verify(sig)) {
            throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED,
                    "KeyExchange signature verification failed");
        }
        return true;
    }

    /**
     * Processes {@code KEXDH_31}: reads p and g from the server, rejects a p whose bit length is
     * outside {@code [minBits, maxBits]}, initialises the DH computation with the received group
     * and sends {@code KEX_DH_GEX_INIT} with the client's public value.
     *
     * @param buffer packet holding the group payload
     * @return always {@code false}
     * @throws Buffer.BufferException   if the packet cannot be read
     * @throws GeneralSecurityException if p is out of range or DH initialisation fails
     * @throws TransportException       if the init message cannot be written
     */
    private boolean parseGexGroup(SSHPacket buffer)
            throws Buffer.BufferException, GeneralSecurityException, TransportException {
        BigInteger p = buffer.readMPInt();
        BigInteger g = buffer.readMPInt();

        // Refuse a group outside the range that was requested in init().
        int bitLength = p.bitLength();
        if (bitLength < minBits || bitLength > maxBits) {
            throw new GeneralSecurityException(
                    "Server generated gex p is out of range (" + bitLength + " bits)");
        }
        log.debug("Received server p bitlength {}", bitLength);

        dh.init(new DHParameterSpec(p, g), trans.getConfig().getRandomFactory());

        log.debug("Sending {}", Message.KEX_DH_GEX_INIT);
        trans.write(new SSHPacket(Message.KEX_DH_GEX_INIT).putBytes(dh.getE()));
        return false;
    }
}