package org.bouncycastle.crypto.tls;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.io.Streams;

/**
 * A heartbeat message (RFC 6520): a {@link HeartbeatMessageType} value, an opaque payload, and the
 * number of padding bytes that follow the payload on the wire.
 */
public class HeartbeatMessage
{
    /** Exclusive upper bound on the payload length; the length is written as a uint16. */
    private static final int PAYLOAD_LENGTH_LIMIT = 1 << 16;

    /** RFC 6520 4. The padding_length MUST be at least 16. */
    private static final int MIN_PADDING_LENGTH = 16;

    protected short type;
    protected byte[] payload;
    protected int paddingLength;

    /**
     * Creates a heartbeat message after validating its parts.
     *
     * <p>The {@code payload} array is stored as given (it is not copied).
     *
     * @param type
     *            a value accepted by {@link HeartbeatMessageType#isValid(short)}.
     * @param payload
     *            the payload bytes; must be non-null and shorter than 2^16 bytes.
     * @param paddingLength
     *            the number of padding bytes; must be at least 16.
     * @throws IllegalArgumentException
     *             if any argument violates the constraints above. The checks run in the order
     *             type, payload, padding length.
     */
    public HeartbeatMessage(short type, byte[] payload, int paddingLength)
    {
        if (!HeartbeatMessageType.isValid(type))
        {
            throw new IllegalArgumentException("'type' is not a valid HeartbeatMessageType value");
        }

        boolean payloadFits = payload != null && payload.length < PAYLOAD_LENGTH_LIMIT;
        if (!payloadFits)
        {
            throw new IllegalArgumentException("'payload' must have length < 2^16");
        }

        if (paddingLength < MIN_PADDING_LENGTH)
        {
            throw new IllegalArgumentException("'paddingLength' must be at least 16");
        }

        this.type = type;
        this.payload = payload;
        this.paddingLength = paddingLength;
    }

    /**
     * Encode this {@link HeartbeatMessage} to an {@link OutputStream}.
     *
     * <p>The output consists of the type as a uint8, the payload length as a uint16, the payload
     * bytes, and then {@code paddingLength} freshly generated random bytes.
     *
     * @param context
     *            the {@link TlsContext} whose {@code getSecureRandom()} source fills the padding.
     * @param output
     *            the {@link OutputStream} to encode to.
     * @throws IOException
     *             if writing to {@code output} fails.
     */
    public void encode(TlsContext context, OutputStream output) throws IOException
    {
        TlsUtils.writeUint8(type, output);

        int payloadLength = payload.length;
        TlsUtils.checkUint16(payloadLength);
        TlsUtils.writeUint16(payloadLength, output);
        output.write(payload);

        // A new random padding is generated on every encode; it is never stored in the message.
        byte[] randomPadding = new byte[paddingLength];
        context.getSecureRandom().nextBytes(randomPadding);
        output.write(randomPadding);
    }

    /**
     * Parse a {@link HeartbeatMessage} from an {@link InputStream}.
     *
     * <p>Reads a uint8 type and a uint16 payload length, then consumes {@code input} to end of
     * stream. The bytes after the payload are treated as padding. Their contents are discarded and
     * only their count is kept.
     *
     * @param input
     *            the {@link InputStream} to parse from.
     * @return the parsed message, or {@code null} if fewer than payload_length + 16 bytes follow
     *         the length field (such a message is to be discarded silently).
     * @throws TlsFatalAlert
     *             with {@link AlertDescription#illegal_parameter} if the type is not a valid
     *             {@link HeartbeatMessageType} value.
     * @throws IOException
     *             if reading from {@code input} fails.
     */
    public static HeartbeatMessage parse(InputStream input) throws IOException
    {
        short msgType = TlsUtils.readUint8(input);
        if (!HeartbeatMessageType.isValid(msgType))
        {
            throw new TlsFatalAlert(AlertDescription.illegal_parameter);
        }

        int declaredPayloadLength = TlsUtils.readUint16(input);

        // Everything after the length field is payload followed by padding.
        PayloadBuffer remainder = new PayloadBuffer();
        Streams.pipeAll(input, remainder);

        byte[] payload = remainder.toTruncatedByteArray(declaredPayloadLength);
        if (payload != null)
        {
            int paddingCount = remainder.size() - payload.length;
            return new HeartbeatMessage(msgType, payload, paddingCount);
        }

        /*
         * RFC 6520 4. If the payload_length of a received HeartbeatMessage is too large, the
         * received HeartbeatMessage MUST be discarded silently.
         */
        return null;
    }

    /**
     * Buffers the bytes that follow the payload_length field so that the payload can be sliced
     * off the front.
     */
    static class PayloadBuffer extends ByteArrayOutputStream
    {
        /**
         * Returns a copy of the first {@code payloadLength} buffered bytes, or {@code null} when
         * fewer than {@code payloadLength + 16} bytes have been buffered.
         *
         * @param payloadLength
         *            the declared payload length.
         * @return the payload copy, or {@code null} if the mandatory padding would not fit.
         */
        byte[] toTruncatedByteArray(int payloadLength)
        {
            // RFC 6520 4. The padding_length MUST be at least 16.
            boolean roomForPadding = count >= payloadLength + MIN_PADDING_LENGTH;
            return roomForPadding ? Arrays.copyOf(buf, payloadLength) : null;
        }
    }
}