/** * TLS-Attacker - A Modular Penetration Testing Framework for TLS * * Copyright 2014-2016 Ruhr University Bochum / Hackmanit GmbH * * Licensed under Apache License 2.0 * http://www.apache.org/licenses/LICENSE-2.0 */ package de.rub.nds.tlsattacker.tls.protocol.handshake; import java.util.Arrays; import de.rub.nds.tlsattacker.tls.constants.CipherSuite; import de.rub.nds.tlsattacker.tls.constants.CompressionMethod; import de.rub.nds.tlsattacker.tls.constants.ExtensionByteLength; import de.rub.nds.tlsattacker.tls.constants.ExtensionType; import de.rub.nds.tlsattacker.tls.constants.HandshakeByteLength; import de.rub.nds.tlsattacker.tls.constants.HandshakeMessageType; import de.rub.nds.tlsattacker.tls.constants.ProtocolVersion; import de.rub.nds.tlsattacker.tls.constants.RecordByteLength; import de.rub.nds.tlsattacker.tls.exceptions.InvalidMessageTypeException; import de.rub.nds.tlsattacker.tls.exceptions.WorkflowExecutionException; import de.rub.nds.tlsattacker.tls.protocol.extension.ExtensionHandler; import de.rub.nds.tlsattacker.tls.protocol.extension.ExtensionMessage; import de.rub.nds.tlsattacker.tls.workflow.TlsContext; import de.rub.nds.tlsattacker.util.ArrayConverter; import de.rub.nds.tlsattacker.util.RandomHelper; import de.rub.nds.tlsattacker.util.Time; /** * @author Juraj Somorovsky <juraj.somorovsky@rub.de> * @author Philip Riese <philip.riese@rub.de> * @param <HandshakeMessage> */ public class ClientHelloHandler<Message extends ClientHelloMessage> extends HandshakeMessageHandler<Message> { @SuppressWarnings("unchecked") public ClientHelloHandler(TlsContext tlsContext) { super(tlsContext); this.correctProtocolMessageClass = (Class<? extends Message>) ClientHelloMessage.class; } @Override public byte[] prepareMessageAction() { protocolMessage.setProtocolVersion(tlsContext.getProtocolVersion().getValue()); // supporting Session Resumption with Session IDs if (tlsContext.isSessionResumption()) { protocolMessage.setSessionId(tlsContext.getSessionID()); } else { // by default we do not use a session id protocolMessage.setSessionId(new byte[0]); } int length = protocolMessage.getSessionId().getValue().length; protocolMessage.setSessionIdLength(length); if (tlsContext.isMitMAttack()) { protocolMessage.setUnixTime(protocolMessage.getUnixTime()); protocolMessage.setRandom(protocolMessage.getRandom()); } else { // random handling final long unixTime = Time.getUnixTime(); protocolMessage.setUnixTime(ArrayConverter.longToUint32Bytes(unixTime)); byte[] random = new byte[HandshakeByteLength.RANDOM]; RandomHelper.getRandom().nextBytes(random); protocolMessage.setRandom(random); } tlsContext.setClientRandom(ArrayConverter.concatenate(protocolMessage.getUnixTime().getValue(), protocolMessage .getRandom().getValue())); byte[] cookieArray = new byte[0]; if (tlsContext.getProtocolVersion() == ProtocolVersion.DTLS12 || tlsContext.getProtocolVersion() == ProtocolVersion.DTLS10) { de.rub.nds.tlsattacker.dtls.protocol.handshake.ClientHelloDtlsMessage dtlsClientHello = (de.rub.nds.tlsattacker.dtls.protocol.handshake.ClientHelloDtlsMessage) protocolMessage; dtlsClientHello.setCookie(tlsContext.getDtlsHandshakeCookie()); dtlsClientHello.setCookieLength((byte) tlsContext.getDtlsHandshakeCookie().length); cookieArray = ArrayConverter.concatenate(new byte[] { dtlsClientHello.getCookieLength().getValue() }, dtlsClientHello.getCookie().getValue()); } byte[] cipherSuites = null; for (CipherSuite cs : protocolMessage.getSupportedCipherSuites()) { cipherSuites = ArrayConverter.concatenate(cipherSuites, cs.getByteValue()); } protocolMessage.setCipherSuites(cipherSuites); int cipherSuiteLength = protocolMessage.getCipherSuites().getValue().length; protocolMessage.setCipherSuiteLength(cipherSuiteLength); byte[] compressionMethods = null; for (CompressionMethod cm : protocolMessage.getSupportedCompressionMethods()) { compressionMethods = ArrayConverter.concatenate(compressionMethods, cm.getArrayValue()); } protocolMessage.setCompressions(compressionMethods); int compressionMethodLength = protocolMessage.getCompressions().getValue().length; protocolMessage.setCompressionLength(compressionMethodLength); byte[] result = ArrayConverter.concatenate(protocolMessage.getProtocolVersion().getValue(), protocolMessage .getUnixTime().getValue(), protocolMessage.getRandom().getValue(), ArrayConverter.intToBytes( protocolMessage.getSessionIdLength().getValue(), 1), protocolMessage.getSessionId().getValue(), cookieArray, ArrayConverter.intToBytes(protocolMessage.getCipherSuiteLength().getValue(), HandshakeByteLength.CIPHER_SUITE), protocolMessage.getCipherSuites().getValue(), ArrayConverter.intToBytes(protocolMessage.getCompressionLength().getValue(), HandshakeByteLength.COMPRESSION), protocolMessage.getCompressions().getValue()); byte[] extensionBytes = null; if (tlsContext.isMitMAttack()) { extensionBytes = protocolMessage.getExtensionBytes(); result = ArrayConverter.concatenate(result, extensionBytes); } else { for (ExtensionMessage extension : protocolMessage.getExtensions()) { ExtensionHandler handler = extension.getExtensionHandler(); handler.initializeClientHelloExtension(extension); extensionBytes = ArrayConverter.concatenate(extensionBytes, extension.getExtensionBytes().getValue()); } if (extensionBytes != null && extensionBytes.length != 0) { byte[] extensionLength = ArrayConverter.intToBytes(extensionBytes.length, ExtensionByteLength.EXTENSIONS); result = ArrayConverter.concatenate(result, extensionLength, extensionBytes); } } protocolMessage.setLength(result.length); long header = (HandshakeMessageType.CLIENT_HELLO.getValue() << 24) + protocolMessage.getLength().getValue(); protocolMessage.setCompleteResultingMessage(ArrayConverter.concatenate( ArrayConverter.longToUint32Bytes(header), result)); return protocolMessage.getCompleteResultingMessage().getValue(); } @Override public int parseMessageAction(byte[] message, int pointer) { if (message[pointer] != HandshakeMessageType.CLIENT_HELLO.getValue()) { throw new InvalidMessageTypeException("This is not a client hello message"); } protocolMessage.setType(message[pointer]); int currentPointer = pointer + HandshakeByteLength.MESSAGE_TYPE; int nextPointer = currentPointer + HandshakeByteLength.MESSAGE_TYPE_LENGTH; int length = ArrayConverter.bytesToInt(Arrays.copyOfRange(message, currentPointer, nextPointer)); protocolMessage.setLength(length); currentPointer = nextPointer; nextPointer = currentPointer + RecordByteLength.PROTOCOL_VERSION; ProtocolVersion serverProtocolVersion = ProtocolVersion.getProtocolVersion(Arrays.copyOfRange(message, currentPointer, nextPointer)); protocolMessage.setProtocolVersion(serverProtocolVersion.getValue()); currentPointer = nextPointer; nextPointer = currentPointer + HandshakeByteLength.UNIX_TIME; protocolMessage.setUnixTime(Arrays.copyOfRange(message, currentPointer, nextPointer)); currentPointer = nextPointer; nextPointer = currentPointer + HandshakeByteLength.RANDOM; protocolMessage.setRandom(Arrays.copyOfRange(message, currentPointer, nextPointer)); tlsContext.setClientRandom(ArrayConverter.concatenate(protocolMessage.getUnixTime().getValue(), protocolMessage .getRandom().getValue())); currentPointer = nextPointer; nextPointer += HandshakeByteLength.SESSION_ID_LENGTH; int sessionIdLength = ArrayConverter.bytesToInt(Arrays.copyOfRange(message, currentPointer, nextPointer)); protocolMessage.setSessionIdLength(sessionIdLength); currentPointer = nextPointer; nextPointer += sessionIdLength; protocolMessage.setSessionId(Arrays.copyOfRange(message, currentPointer, nextPointer)); // handle unknown SessionID during Session resumption if (tlsContext.isSessionResumption() && !(Arrays.equals(tlsContext.getSessionID(), protocolMessage.getSessionId().getValue()))) { throw new WorkflowExecutionException("Session ID is unknown to the Server"); } if (tlsContext.getProtocolVersion() == ProtocolVersion.DTLS12 || tlsContext.getProtocolVersion() == ProtocolVersion.DTLS10) { de.rub.nds.tlsattacker.dtls.protocol.handshake.ClientHelloDtlsMessage dtlsClientHello = (de.rub.nds.tlsattacker.dtls.protocol.handshake.ClientHelloDtlsMessage) protocolMessage; currentPointer = nextPointer; nextPointer += HandshakeByteLength.DTLS_HANDSHAKE_COOKIE_LENGTH; byte cookieLength = message[currentPointer]; dtlsClientHello.setCookieLength(cookieLength); currentPointer = nextPointer; nextPointer += cookieLength; dtlsClientHello.setCookie(Arrays.copyOfRange(message, currentPointer, nextPointer)); } currentPointer = nextPointer; nextPointer += HandshakeByteLength.CIPHER_SUITE; int cipherSuitesLength = ArrayConverter.bytesToInt(Arrays.copyOfRange(message, currentPointer, nextPointer)); protocolMessage.setCipherSuiteLength(cipherSuitesLength); currentPointer = nextPointer; nextPointer += cipherSuitesLength; protocolMessage.setCipherSuites(Arrays.copyOfRange(message, currentPointer, nextPointer)); currentPointer = nextPointer; nextPointer += HandshakeByteLength.COMPRESSION; int compressionsLength = ArrayConverter.bytesToInt(Arrays.copyOfRange(message, currentPointer, nextPointer)); protocolMessage.setCompressionLength(compressionsLength); currentPointer = nextPointer; nextPointer += compressionsLength; protocolMessage.setCompressions(Arrays.copyOfRange(message, currentPointer, nextPointer)); byte[] compression = protocolMessage.getCompressions().getValue(); tlsContext.setCompressionMethod(CompressionMethod.getCompressionMethod(compression[0])); currentPointer = nextPointer; if ((currentPointer - pointer) < length) { currentPointer += ExtensionByteLength.EXTENSIONS; while ((currentPointer - pointer) < length) { nextPointer = currentPointer + ExtensionByteLength.TYPE; byte[] extensionType = Arrays.copyOfRange(message, currentPointer, nextPointer); // Not implemented/unknown extensions will generate an Exception // ... try { ExtensionHandler<? extends ExtensionMessage> eh = ExtensionType.getExtensionType(extensionType).getExtensionHandler(); currentPointer = eh.parseExtension(message, currentPointer); protocolMessage.addExtension(eh.getExtensionMessage()); } // ... which we catch, then disregard that extension and carry // on. catch (Exception ex) { currentPointer = nextPointer; nextPointer += 2; currentPointer += ArrayConverter.bytesToInt(Arrays .copyOfRange(message, currentPointer, nextPointer)); nextPointer += 2; currentPointer += 2; } } } protocolMessage.setCompleteResultingMessage(Arrays.copyOfRange(message, pointer, currentPointer)); return (currentPointer - pointer); } }