/**
* TLS-Attacker - A Modular Penetration Testing Framework for TLS
*
* Copyright 2014-2016 Ruhr University Bochum / Hackmanit GmbH
*
* Licensed under Apache License 2.0
* http://www.apache.org/licenses/LICENSE-2.0
*/
package de.rub.nds.tlsattacker.dtls.protocol.handshake;
import de.rub.nds.tlsattacker.tls.record.Record;
import de.rub.nds.tlsattacker.tls.exceptions.MalformedMessageException;
import de.rub.nds.tlsattacker.tls.protocol.ProtocolMessage;
import de.rub.nds.tlsattacker.util.ArrayConverter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
/**
 * Reassembles fragmented DTLS handshake messages from received records and splits outgoing handshake
 * messages into fragments.
 * <p>
 * Every handshake message (fragment) is preceded by a 12-byte header laid out as: type (byte 0),
 * message length (bytes 1-3), message_seq (bytes 4-5), fragment offset (bytes 6-8) and fragment length
 * (bytes 9-11). All multi-byte fields are read and written as unsigned big-endian integers.
 *
 * @author Florian Pfützenreuter <florian.pfuetzenreuter@rub.de>
 */
public class HandshakeFragmentHandler {

    private static final Logger LOGGER = LogManager.getLogger(HandshakeFragmentHandler.class);

    /** Length of the header that precedes every handshake message (fragment) body. */
    private static final int HANDSHAKE_HEADER_LENGTH = 12;

    /** Records that contributed at least one fragment, keyed by handshake message_seq. */
    final Map<Integer, List<Record>> handshakeMessageRecordMap = new HashMap<>();

    /**
     * Per message_seq, the positions of the reassembly buffer (header included) that have already been
     * filled.
     */
    final Map<Integer, BitSet> handshakeMessageReassembleBitmaskMap = new HashMap<>();

    /**
     * Per message_seq, the reassembly buffer: a synthesized unfragmented header followed by the message
     * body.
     */
    final Map<Integer, byte[]> reassembledHandshakeMessageMap = new HashMap<>();

    private int expectedHandshakeMessageSeq;

    /**
     * Parses every handshake fragment contained in the record and merges it into the reassembly buffer of
     * its message_seq. Fragments of messages older than the expected sequence number, or of messages that
     * are already completely reassembled, are skipped. The record is remembered for every message_seq it
     * contributed a fragment to. Trailing bytes too short to hold a header are ignored.
     *
     * @param handshakeRecord
     *            record whose protocol message bytes hold one or more handshake fragments
     * @throws MalformedMessageException
     *             if a fragment header is inconsistent with the record contents or with earlier fragments
     *             of the same handshake message
     */
    public void processHandshakeRecord(Record handshakeRecord) {
        byte[] recordData = handshakeRecord.getProtocolMessageBytes().getValue();
        List<Integer> affectedHandshakeMessages = new ArrayList<>();
        int workPointer = 0;

        while (workPointer + HANDSHAKE_HEADER_LENGTH <= recordData.length) {
            // All header fields are unsigned; they must be masked to avoid sign extension of bytes >= 0x80.
            int handshakeMessageSeq = readUnsignedInt(recordData, workPointer + 4, 2);
            int handshakeMessageFragSize = readUnsignedInt(recordData, workPointer + 9, 3);
            if (handshakeMessageSeq < expectedHandshakeMessageSeq
                    || checkHandshakeMessageAvailable(handshakeMessageSeq)) {
                // Fragment of an already handled/completed message (e.g. a retransmission): skip it.
                workPointer += handshakeMessageFragSize + HANDSHAKE_HEADER_LENGTH;
                continue;
            }
            byte handshakeMessageType = recordData[workPointer];
            int handshakeMessageSize = readUnsignedInt(recordData, workPointer + 1, 3);
            int handshakeMessageFragOffset = readUnsignedInt(recordData, workPointer + 6, 3);
            workPointer += HANDSHAKE_HEADER_LENGTH;

            if ((handshakeMessageFragSize + workPointer) > recordData.length) {
                throw new MalformedMessageException(
                        "The received handshake message (fragment) claims to contain more data than it actually does.");
            }
            if (handshakeMessageFragSize > handshakeMessageSize) {
                throw new MalformedMessageException(
                        "The received handshake message (fragment) claims to contain a fragment that's bigger than the actual handshake message length.");
            }
            if ((handshakeMessageFragOffset + handshakeMessageFragSize) > handshakeMessageSize) {
                throw new MalformedMessageException(
                        "The received handshake message fragment is out of the the handshake message bounds implicated by its handshake message length.");
            }
            // The reassembly buffer is sized from the first fragment seen; a fragment announcing another
            // total length would otherwise overflow (or silently corrupt) that buffer.
            byte[] existingBuffer = reassembledHandshakeMessageMap.get(handshakeMessageSeq);
            if (existingBuffer != null && existingBuffer.length != handshakeMessageSize + HANDSHAKE_HEADER_LENGTH) {
                throw new MalformedMessageException(
                        "The received handshake message fragment announces a handshake message length that differs from earlier fragments of the same message.");
            }

            if (!affectedHandshakeMessages.contains(handshakeMessageSeq)) {
                affectedHandshakeMessages.add(handshakeMessageSeq);
            }
            processHandshakeMessageFragment(handshakeMessageType, handshakeMessageSize, handshakeMessageSeq,
                    handshakeMessageFragOffset, handshakeMessageFragSize, recordData, workPointer);
            workPointer += handshakeMessageFragSize;
        }
        for (Integer affectedHandshakeMessage : affectedHandshakeMessages) {
            addHandshakeRecordToRecordMap(handshakeRecord, affectedHandshakeMessage);
        }
    }

    /**
     * Copies one validated fragment body into the reassembly buffer of its message_seq and marks the
     * copied positions as filled. On the first fragment of a message, the buffer is created and prefixed
     * with an unfragmented header.
     */
    private void processHandshakeMessageFragment(byte handshakeMessageType, int handshakeMessageSize,
            int handshakeMessageSeq, int handshakeMessageFragOffset, int handshakeMessageFragSize, byte[] recordData,
            int workPointer) {
        if (createKeyInReassembleMaps(handshakeMessageSize, handshakeMessageSeq)) {
            byte[] header = createCompleteHandshakeMessageHeader(handshakeMessageType, handshakeMessageSeq,
                    handshakeMessageSize);
            handshakeMessageReassembleBitmaskMap.get(handshakeMessageSeq).set(0, HANDSHAKE_HEADER_LENGTH, true);
            System.arraycopy(header, 0, reassembledHandshakeMessageMap.get(handshakeMessageSeq), 0,
                    HANDSHAKE_HEADER_LENGTH);
        }
        int bufferOffset = handshakeMessageFragOffset + HANDSHAKE_HEADER_LENGTH;
        handshakeMessageReassembleBitmaskMap.get(handshakeMessageSeq).set(bufferOffset,
                bufferOffset + handshakeMessageFragSize, true);
        System.arraycopy(recordData, workPointer, reassembledHandshakeMessageMap.get(handshakeMessageSeq),
                bufferOffset, handshakeMessageFragSize);
    }

    /**
     * Builds the header of an unfragmented handshake message: fragment offset 0 and fragment length equal
     * to the message length.
     *
     * @param handshakeType
     *            handshake message type byte
     * @param handshakeMessageSeq
     *            message_seq (lower 16 bits are used)
     * @param handshakeMessageSize
     *            body length (lower 24 bits are used)
     * @return a new 12-byte header
     */
    protected byte[] createCompleteHandshakeMessageHeader(byte handshakeType, int handshakeMessageSeq,
            int handshakeMessageSize) {
        byte[] output = new byte[HANDSHAKE_HEADER_LENGTH];
        output[0] = handshakeType;
        writeUint24(output, 1, handshakeMessageSize);
        output[4] = (byte) (handshakeMessageSeq >>> 8);
        output[5] = (byte) handshakeMessageSeq;
        // bytes 6-8 (fragment offset) stay 0
        writeUint24(output, 9, handshakeMessageSize);
        return output;
    }

    /**
     * Returns the reassembled message (header followed by body) for the expected message_seq.
     *
     * @return the internal reassembly buffer if that message is complete, otherwise {@code null}
     */
    public byte[] getHandshakeMessage() {
        if (checkHandshakeMessageAvailable(expectedHandshakeMessageSeq)) {
            return reassembledHandshakeMessageMap.get(expectedHandshakeMessageSeq);
        } else {
            return null;
        }
    }

    /**
     * @return {@code true} if a reassembly buffer exists for {@code seqNum} and every byte of it has been
     *         filled
     */
    protected boolean checkHandshakeMessageAvailable(int seqNum) {
        if (reassembledHandshakeMessageMap.containsKey(seqNum)) {
            return checkHandshakeMessageCompleteness(seqNum);
        }
        return false;
    }

    /** Compares the number of filled positions with the reassembly buffer length. */
    private boolean checkHandshakeMessageCompleteness(int seqNum) {
        return handshakeMessageReassembleBitmaskMap.get(seqNum).cardinality() == reassembledHandshakeMessageMap
                .get(seqNum).length;
    }

    /**
     * Allocates the bitmask and reassembly buffer for {@code seqNum} if they do not exist yet.
     *
     * @return {@code true} if new entries were created
     */
    private boolean createKeyInReassembleMaps(int handshakeMessageSize, int seqNum) {
        if (!handshakeMessageReassembleBitmaskMap.containsKey(seqNum)) {
            handshakeMessageReassembleBitmaskMap.put(seqNum, new BitSet(handshakeMessageSize + HANDSHAKE_HEADER_LENGTH));
            reassembledHandshakeMessageMap.put(seqNum, new byte[handshakeMessageSize + HANDSHAKE_HEADER_LENGTH]);
            return true;
        }
        return false;
    }

    /** Appends {@code record} to the list of records that carried fragments of {@code seqNum}. */
    private void addHandshakeRecordToRecordMap(Record record, int seqNum) {
        if (handshakeMessageRecordMap.containsKey(seqNum)) {
            handshakeMessageRecordMap.get(seqNum).add(record);
        } else {
            ArrayList<Record> recordList = new ArrayList<>();
            recordList.add(record);
            handshakeMessageRecordMap.put(seqNum, recordList);
        }
    }

    /**
     * Attaches the records that carried fragments of the expected message_seq to the given message. The
     * list set may be {@code null} if no record was seen for that sequence number.
     */
    public void addRecordsToHandshakeMessage(ProtocolMessage handshakeMessage) {
        List<Record> recordList = handshakeMessageRecordMap.get(expectedHandshakeMessageSeq);
        handshakeMessage.setRecords(recordList);
    }

    /** Discards all reassembly state; the expected message_seq is left unchanged. */
    public void flush() {
        handshakeMessageRecordMap.clear();
        handshakeMessageReassembleBitmaskMap.clear();
        reassembledHandshakeMessageMap.clear();
    }

    public void incrementExpectedHandshakeMessageSeq() {
        expectedHandshakeMessageSeq++;
    }

    /**
     * Splits a complete handshake message into consecutive fragments, each consisting of a 12-byte header
     * and at most {@code maxMessageSize - 12} body bytes. Type and message_seq are taken from the input
     * header; the message length is recomputed from the input's actual body length, and fragment offset
     * and length are filled in per fragment. A message with an empty body yields one header-only fragment.
     *
     * @param handshakeMessageBytes
     *            complete handshake message (header followed by body)
     * @param maxMessageSize
     *            maximum size of one fragment including its header
     * @return all fragments concatenated
     * @throws IllegalArgumentException
     *             if the input is shorter than a header, or if {@code maxMessageSize} leaves no room for
     *             body bytes while the body is non-empty
     */
    public byte[] fragmentHandshakeMessage(byte[] handshakeMessageBytes, int maxMessageSize) {
        if (handshakeMessageBytes.length < HANDSHAKE_HEADER_LENGTH) {
            throw new IllegalArgumentException("The handshake message must contain at least a "
                    + HANDSHAKE_HEADER_LENGTH + "-byte header, but has only " + handshakeMessageBytes.length
                    + " bytes");
        }
        int messageSize = handshakeMessageBytes.length - HANDSHAKE_HEADER_LENGTH;
        int maxFragmentSize = maxMessageSize - HANDSHAKE_HEADER_LENGTH;
        if (maxFragmentSize < 0 || (maxFragmentSize == 0 && messageSize > 0)) {
            // Without room for body bytes the message could never be transmitted completely.
            throw new IllegalArgumentException("maxMessageSize " + maxMessageSize
                    + " leaves no room for handshake message payload");
        }
        int numFragments;
        if (messageSize == 0) {
            numFragments = 1;
        } else {
            // Ceiling division without floating point and without overflow.
            numFragments = messageSize / maxFragmentSize + (messageSize % maxFragmentSize == 0 ? 0 : 1);
        }
        LOGGER.debug("Splitting the handshake message into {} fragments", numFragments);

        byte[] handshakeHeader = new byte[HANDSHAKE_HEADER_LENGTH];
        handshakeHeader[0] = handshakeMessageBytes[0];
        writeUint24(handshakeHeader, 1, messageSize);
        handshakeHeader[4] = handshakeMessageBytes[4];
        handshakeHeader[5] = handshakeMessageBytes[5];

        // Preallocate the output instead of re-concatenating per fragment (which is quadratic).
        byte[] fragmentArray = new byte[messageSize + numFragments * HANDSHAKE_HEADER_LENGTH];
        int outputPointer = 0;
        for (int i = 0; i < numFragments; i++) {
            int fragmentOffset = i * maxFragmentSize;
            int fragmentLength = Math.min(maxFragmentSize, messageSize - fragmentOffset);
            writeUint24(handshakeHeader, 6, fragmentOffset);
            writeUint24(handshakeHeader, 9, fragmentLength);
            System.arraycopy(handshakeHeader, 0, fragmentArray, outputPointer, HANDSHAKE_HEADER_LENGTH);
            outputPointer += HANDSHAKE_HEADER_LENGTH;
            System.arraycopy(handshakeMessageBytes, fragmentOffset + HANDSHAKE_HEADER_LENGTH, fragmentArray,
                    outputPointer, fragmentLength);
            outputPointer += fragmentLength;
        }
        return fragmentArray;
    }

    /**
     * @return the records that carried fragments of {@code seqNum}, or a new empty list if there are none
     */
    public List<Record> getReceivedHandshakeMessageRecords(int seqNum) {
        if (handshakeMessageRecordMap.containsKey(seqNum)) {
            return handshakeMessageRecordMap.get(seqNum);
        }
        return new ArrayList<>();
    }

    public void setExpectedHandshakeMessageSeq(int seqNum) {
        expectedHandshakeMessageSeq = seqNum;
    }

    public int getExpectedHandshakeMessageSeq() {
        return expectedHandshakeMessageSeq;
    }

    /**
     * Reads an unsigned big-endian integer of {@code length} bytes starting at {@code offset}. Callers pass
     * at most 3 bytes, so the result is always a non-negative int.
     */
    private static int readUnsignedInt(byte[] data, int offset, int length) {
        int value = 0;
        for (int i = 0; i < length; i++) {
            value = (value << 8) | (data[offset + i] & 0xFF);
        }
        return value;
    }

    /** Writes the lower 24 bits of {@code value} big-endian into {@code target[offset..offset+2]}. */
    private static void writeUint24(byte[] target, int offset, int value) {
        target[offset] = (byte) (value >>> 16);
        target[offset + 1] = (byte) (value >>> 8);
        target[offset + 2] = (byte) value;
    }
}