/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sshd.server.kex; import java.io.FileNotFoundException; import java.io.IOException; import java.math.BigInteger; import java.net.URL; import java.security.KeyPair; import java.util.ArrayList; import java.util.List; import java.util.Objects; import org.apache.sshd.common.Factory; import org.apache.sshd.common.FactoryManager; import org.apache.sshd.common.NamedFactory; import org.apache.sshd.common.SshConstants; import org.apache.sshd.common.SshException; import org.apache.sshd.common.kex.DHFactory; import org.apache.sshd.common.kex.DHG; import org.apache.sshd.common.kex.DHGroupData; import org.apache.sshd.common.kex.KexProposalOption; import org.apache.sshd.common.kex.KeyExchange; import org.apache.sshd.common.kex.KeyExchangeFactory; import org.apache.sshd.common.random.Random; import org.apache.sshd.common.session.Session; import org.apache.sshd.common.signature.Signature; import org.apache.sshd.common.util.GenericUtils; import org.apache.sshd.common.util.ValidateUtils; import org.apache.sshd.common.util.buffer.Buffer; import org.apache.sshd.common.util.buffer.BufferUtils; import org.apache.sshd.common.util.buffer.ByteArrayBuffer; import org.apache.sshd.common.util.security.SecurityUtils; import org.apache.sshd.server.ServerFactoryManager; import org.apache.sshd.server.session.ServerSession; /** * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a> */ public class DHGEXServer extends AbstractDHServerKeyExchange { protected final DHFactory factory; protected DHG dh; protected int min; protected int prf; protected int max; protected byte expected; protected boolean oldRequest; protected DHGEXServer(DHFactory factory) { this.factory = Objects.requireNonNull(factory, "No factory"); } @Override public final String getName() { return factory.getName(); } public static KeyExchangeFactory newFactory(final DHFactory factory) { return new KeyExchangeFactory() { @Override public KeyExchange create() { return new DHGEXServer(factory); } @Override public String getName() { return factory.getName(); } @Override public String toString() { return NamedFactory.class.getSimpleName() + "<" + KeyExchange.class.getSimpleName() + ">" + "[" + getName() + "]"; } }; } @Override public void init(Session s, byte[] v_s, byte[] v_c, byte[] i_s, byte[] i_c) throws Exception { super.init(s, v_s, v_c, i_s, i_c); expected = SshConstants.SSH_MSG_KEX_DH_GEX_REQUEST; } @Override public boolean next(int cmd, Buffer buffer) throws Exception { ServerSession session = getServerSession(); if (log.isDebugEnabled()) { log.debug("next({})[{}] process command={}", this, session, KeyExchange.getGroupKexOpcodeName(cmd)); } if (cmd == SshConstants.SSH_MSG_KEX_DH_GEX_REQUEST_OLD && expected == SshConstants.SSH_MSG_KEX_DH_GEX_REQUEST) { oldRequest = true; min = SecurityUtils.MIN_DHGEX_KEY_SIZE; prf = buffer.getInt(); max = SecurityUtils.getMaxDHGroupExchangeKeySize(); if (max < min || prf < min || max < prf) { throw new SshException(SshConstants.SSH2_DISCONNECT_KEY_EXCHANGE_FAILED, "Protocol error: bad parameters " + min + " !< " + prf + " !< " + max); } dh = chooseDH(min, prf, max); f = dh.getE(); hash = dh.getHash(); hash.init(); if (log.isDebugEnabled()) { log.debug("next({})[{}] send SSH_MSG_KEX_DH_GEX_GROUP", this, session); } buffer = session.createBuffer(SshConstants.SSH_MSG_KEX_DH_GEX_GROUP); buffer.putMPInt(dh.getP()); buffer.putMPInt(dh.getG()); session.writePacket(buffer); expected = SshConstants.SSH_MSG_KEX_DH_GEX_INIT; return false; } if (cmd == SshConstants.SSH_MSG_KEX_DH_GEX_REQUEST && expected == SshConstants.SSH_MSG_KEX_DH_GEX_REQUEST) { min = buffer.getInt(); prf = buffer.getInt(); max = buffer.getInt(); if (prf < min || max < prf) { throw new SshException(SshConstants.SSH2_DISCONNECT_KEY_EXCHANGE_FAILED, "Protocol error: bad parameters " + min + " !< " + prf + " !< " + max); } dh = chooseDH(min, prf, max); f = dh.getE(); hash = dh.getHash(); hash.init(); if (log.isDebugEnabled()) { log.debug("next({})[{}] Send SSH_MSG_KEX_DH_GEX_GROUP", this, session); } buffer = session.createBuffer(SshConstants.SSH_MSG_KEX_DH_GEX_GROUP); buffer.putMPInt(dh.getP()); buffer.putMPInt(dh.getG()); session.writePacket(buffer); expected = SshConstants.SSH_MSG_KEX_DH_GEX_INIT; return false; } if (cmd != expected) { throw new SshException(SshConstants.SSH2_DISCONNECT_KEY_EXCHANGE_FAILED, "Protocol error: expected packet " + KeyExchange.getGroupKexOpcodeName(expected) + ", got " + KeyExchange.getGroupKexOpcodeName(cmd)); } if (cmd == SshConstants.SSH_MSG_KEX_DH_GEX_INIT) { e = buffer.getMPIntAsBytes(); dh.setF(e); k = dh.getK(); byte[] k_s; KeyPair kp = Objects.requireNonNull(session.getHostKey(), "No server key pair available"); String algo = session.getNegotiatedKexParameter(KexProposalOption.SERVERKEYS); Signature sig = ValidateUtils.checkNotNull( NamedFactory.create(session.getSignatureFactories(), algo), "Unknown negotiated server keys: %s", algo); sig.initSigner(kp.getPrivate()); buffer = new ByteArrayBuffer(); buffer.putRawPublicKey(kp.getPublic()); k_s = buffer.getCompactData(); buffer.clear(); buffer.putBytes(v_c); buffer.putBytes(v_s); buffer.putBytes(i_c); buffer.putBytes(i_s); buffer.putBytes(k_s); if (oldRequest) { buffer.putInt(prf); } else { buffer.putInt(min); buffer.putInt(prf); buffer.putInt(max); } buffer.putMPInt(dh.getP()); buffer.putMPInt(dh.getG()); buffer.putMPInt(e); buffer.putMPInt(f); buffer.putMPInt(k); hash.update(buffer.array(), 0, buffer.available()); h = hash.digest(); sig.update(h); buffer.clear(); buffer.putString(algo); byte[] sigBytes = sig.sign(); buffer.putBytes(sigBytes); byte[] sigH = buffer.getCompactData(); if (log.isTraceEnabled()) { log.trace("next({})[{}][K_S]: {}", this, session, BufferUtils.toHex(k_s)); log.trace("next({})[{}][f]: {}", this, session, BufferUtils.toHex(f)); log.trace("next({})[{}][sigH]: {}", this, session, BufferUtils.toHex(sigH)); } // Send response if (log.isDebugEnabled()) { log.debug("next({})[{}] Send SSH_MSG_KEX_DH_GEX_REPLY", this, session); } buffer = session.prepareBuffer(SshConstants.SSH_MSG_KEX_DH_GEX_REPLY, BufferUtils.clear(buffer)); buffer.putBytes(k_s); buffer.putBytes(f); buffer.putBytes(sigH); session.writePacket(buffer); return true; } return false; } private DHG chooseDH(int min, int prf, int max) throws Exception { List<Moduli.DhGroup> groups = loadModuliGroups(); min = Math.max(min, SecurityUtils.MIN_DHGEX_KEY_SIZE); prf = Math.max(prf, SecurityUtils.MIN_DHGEX_KEY_SIZE); prf = Math.min(prf, SecurityUtils.getMaxDHGroupExchangeKeySize()); max = Math.min(max, SecurityUtils.getMaxDHGroupExchangeKeySize()); int bestSize = 0; List<Moduli.DhGroup> selected = new ArrayList<>(); for (Moduli.DhGroup group : groups) { if (group.size < min || group.size > max) { continue; } if ((group.size > prf && group.size < bestSize) || (group.size > bestSize && bestSize < prf)) { bestSize = group.size; selected.clear(); } if (group.size == bestSize) { selected.add(group); } } ServerSession session = getServerSession(); if (selected.isEmpty()) { log.warn("chooseDH({})[{}] No suitable primes found, defaulting to DHG1", this, session); return getDH(new BigInteger(DHGroupData.getP1()), new BigInteger(DHGroupData.getG())); } FactoryManager manager = Objects.requireNonNull(session.getFactoryManager(), "No factory manager"); Factory<Random> factory = Objects.requireNonNull(manager.getRandomFactory(), "No random factory"); Random random = Objects.requireNonNull(factory.create(), "No random generator"); int which = random.random(selected.size()); Moduli.DhGroup group = selected.get(which); return getDH(group.p, group.g); } protected List<Moduli.DhGroup> loadModuliGroups() throws IOException { ServerSession session = getServerSession(); String moduliStr = session.getString(ServerFactoryManager.MODULI_URL); List<Moduli.DhGroup> groups = null; URL moduli; if (!GenericUtils.isEmpty(moduliStr)) { try { moduli = new URL(moduliStr); groups = Moduli.parseModuli(moduli); } catch (IOException e) { // OK - use internal moduli log.warn("Error (" + e.getClass().getSimpleName() + ") loading external moduli from " + moduliStr + ": " + e.getMessage()); } } if (groups == null) { moduliStr = "/org/apache/sshd/moduli"; try { moduli = getClass().getResource(moduliStr); if (moduli == null) { throw new FileNotFoundException("Missing internal moduli file"); } moduliStr = moduli.toExternalForm(); groups = Moduli.parseModuli(moduli); } catch (IOException e) { log.warn("Error (" + e.getClass().getSimpleName() + ") loading internal moduli from " + moduliStr + ": " + e.getMessage()); throw e; // this time we MUST throw the exception } } if (log.isDebugEnabled()) { log.debug("Loaded moduli groups from {}", moduliStr); } return groups; } protected DHG getDH(BigInteger p, BigInteger g) throws Exception { return (DHG) factory.create(p, g); } }