/*
* TeleStax, Open Source Cloud Communications
* Copyright 2011-2015, Telestax Inc and individual contributors
* by the @authors tag.
*
* This is free software; you can redistribute it and/or modify it
* under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of
* the License, or (at your option) any later version.
*
* This software is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this software; if not, write to the Free
* Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
* 02110-1301 USA, or see the FSF site: http://www.fsf.org.
*/
package org.restcomm.media.rtp.secure;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.channels.DatagramChannel;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.log4j.Logger;
import org.bouncycastle.crypto.tls.DTLSServerProtocol;
import org.bouncycastle.crypto.tls.DatagramTransport;
import org.restcomm.media.network.deprecated.channel.PacketHandler;
import org.restcomm.media.network.deprecated.channel.PacketHandlerException;
import org.restcomm.media.rtp.crypto.DtlsSrtpServer;
import org.restcomm.media.rtp.crypto.DtlsSrtpServerProvider;
import org.restcomm.media.rtp.crypto.PacketTransformer;
import org.restcomm.media.rtp.crypto.SRTPPolicy;
import org.restcomm.media.rtp.crypto.SRTPTransformEngine;
/**
* Handler to process DTLS packets.
*
* @author Henrique Rosa (henrique.rosa@telestax.com)
*
*/
/**
 * Handler to process DTLS packets.
 * <p>
 * Acts both as a {@link PacketHandler}, which queues incoming DTLS records handed over by the media channel, and as a
 * Bouncy Castle {@link DatagramTransport}, which feeds those queued records to a {@link DTLSServerProtocol} handshake
 * running on a dedicated worker thread. Once the handshake succeeds, SRTP and SRTCP transformers are derived from the
 * negotiated keying material and the registered {@link DtlsListener}s are notified.
 * </p>
 *
 * @author Henrique Rosa (henrique.rosa@telestax.com)
 *
 */
public class DtlsHandler implements PacketHandler, DatagramTransport {

    /** Counter used to give every handshake worker thread a unique name. */
    private static final AtomicLong THREAD_COUNTER = new AtomicLong(0);

    private static final Logger logger = Logger.getLogger(DtlsHandler.class);

    public static final int DEFAULT_MTU = 1500;

    private final static int MIN_IP_OVERHEAD = 20;
    private final static int MAX_IP_OVERHEAD = MIN_IP_OVERHEAD + 64;
    private final static int UDP_OVERHEAD = 8;

    /** Maximum time, in milliseconds, a handshake may run before it is aborted (see MEDIA-48). */
    public final static int MAX_DELAY = 10000;

    // Packet Handler properties
    private int pipelinePriority;

    // Network properties
    private int mtu;
    private final int receiveLimit;
    private final int sendLimit;

    // DTLS Handshake properties
    private DtlsSrtpServer server;
    private DatagramChannel channel;
    // Blocking queue so the handshake thread can wait for incoming records without busy polling
    private final BlockingQueue<ByteBuffer> rxQueue;
    private volatile boolean handshakeComplete;
    private volatile boolean handshakeFailed;
    private volatile boolean handshaking;
    private Thread worker;
    private String localHashFunction;
    private String remoteHashFunction;
    private String remoteFingerprint;
    private String localFingerprint;
    private long startTime;
    // Iterated by the handshake thread while other threads may add or clear listeners
    private final CopyOnWriteArrayList<DtlsListener> listeners;

    // SRTP properties
    // http://tools.ietf.org/html/rfc5764#section-4.2
    private PacketTransformer srtpEncoder;
    private PacketTransformer srtpDecoder;
    private PacketTransformer srtcpEncoder;
    private PacketTransformer srtcpDecoder;

    private DtlsSrtpServerProvider tlsServerProvider;

    /**
     * Creates a handler whose DTLS server is obtained from the given provider.
     *
     * @param tlsServerProvider Provider used now, and on every {@link #reset()}, to create a fresh DTLS-SRTP server.
     */
    public DtlsHandler(DtlsSrtpServerProvider tlsServerProvider) {
        this.pipelinePriority = 0;

        // Network properties
        this.mtu = DEFAULT_MTU;
        // Datagram payload budget: MTU minus IP overhead (min for receive, max for send) and the UDP header
        this.receiveLimit = Math.max(0, mtu - MIN_IP_OVERHEAD - UDP_OVERHEAD);
        this.sendLimit = Math.max(0, mtu - MAX_IP_OVERHEAD - UDP_OVERHEAD);

        // Handshake properties
        this.server = tlsServerProvider.provide();
        this.rxQueue = new LinkedBlockingQueue<>();
        this.handshakeComplete = false;
        this.handshakeFailed = false;
        this.handshaking = false;
        this.localHashFunction = "SHA-256";
        this.remoteHashFunction = "";
        this.remoteFingerprint = "";
        this.localFingerprint = "";
        this.startTime = 0L;
        this.listeners = new CopyOnWriteArrayList<>();
        this.tlsServerProvider = tlsServerProvider;
    }

    /**
     * Sets the channel used to send outgoing DTLS records. Sends are only performed while it is open and connected.
     *
     * @param channel The datagram channel shared with the media stream.
     */
    public void setChannel(DatagramChannel channel) {
        this.channel = channel;
    }

    /**
     * Registers a listener for handshake outcome notifications. Registering the same listener twice has no effect.
     *
     * @param listener The listener to register.
     */
    public void addListener(DtlsListener listener) {
        this.listeners.addIfAbsent(listener);
    }

    public boolean isHandshakeComplete() {
        return handshakeComplete;
    }

    public boolean isHandshakeFailed() {
        return handshakeFailed;
    }

    public boolean isHandshaking() {
        return handshaking;
    }

    /**
     * Gets the fingerprint of the local certificate, generating it lazily with the local hash function on first use
     * (or after {@link #resetLocalFingerprint()}).
     *
     * @return The local fingerprint.
     */
    public String getLocalFingerprint() {
        if (this.localFingerprint == null || this.localFingerprint.isEmpty()) {
            this.localFingerprint = this.server.generateFingerprint(this.localHashFunction);
        }
        return localFingerprint;
    }

    /** Forces the local fingerprint to be regenerated on the next call to {@link #getLocalFingerprint()}. */
    public void resetLocalFingerprint() {
        this.localFingerprint = "";
    }

    public String getLocalHashFunction() {
        return localHashFunction;
    }

    public String getRemoteHashFunction() {
        return remoteHashFunction;
    }

    public String getRemoteFingerprintValue() {
        return remoteFingerprint;
    }

    /**
     * @return The remote hash function and fingerprint value, separated by a single space.
     */
    public String getRemoteFingerprint() {
        return remoteHashFunction + " " + remoteFingerprint;
    }

    /**
     * Stores the fingerprint announced by the remote peer.
     *
     * @param hashFunction The hash function name of the remote fingerprint.
     * @param fingerprint The remote fingerprint value.
     */
    public void setRemoteFingerprint(String hashFunction, String fingerprint) {
        this.remoteHashFunction = hashFunction;
        this.remoteFingerprint = fingerprint;
    }

    private byte[] getMasterServerKey() {
        return server.getSrtpMasterServerKey();
    }

    private byte[] getMasterServerSalt() {
        return server.getSrtpMasterServerSalt();
    }

    private byte[] getMasterClientKey() {
        return server.getSrtpMasterClientKey();
    }

    private byte[] getMasterClientSalt() {
        return server.getSrtpMasterClientSalt();
    }

    private SRTPPolicy getSrtpPolicy() {
        return server.getSrtpPolicy();
    }

    private SRTPPolicy getSrtcpPolicy() {
        return server.getSrtcpPolicy();
    }

    public PacketTransformer getSrtpDecoder() {
        return srtpDecoder;
    }

    public PacketTransformer getSrtpEncoder() {
        return srtpEncoder;
    }

    public PacketTransformer getSrtcpDecoder() {
        return srtcpDecoder;
    }

    public PacketTransformer getSrtcpEncoder() {
        return srtcpEncoder;
    }

    /**
     * Generates an SRTP encoder for outgoing RTP packets using keying material from the DTLS handshake.
     */
    private PacketTransformer generateRtpEncoder() {
        return new SRTPTransformEngine(getMasterServerKey(), getMasterServerSalt(), getSrtpPolicy(), getSrtcpPolicy())
                .getRTPTransformer();
    }

    /**
     * Generates an SRTP decoder for incoming RTP packets using keying material from the DTLS handshake.
     */
    private PacketTransformer generateRtpDecoder() {
        return new SRTPTransformEngine(getMasterClientKey(), getMasterClientSalt(), getSrtpPolicy(), getSrtcpPolicy())
                .getRTPTransformer();
    }

    /**
     * Generates an SRTCP encoder for outgoing RTCP packets using keying material from the DTLS handshake.
     */
    private PacketTransformer generateRtcpEncoder() {
        return new SRTPTransformEngine(getMasterServerKey(), getMasterServerSalt(), getSrtpPolicy(), getSrtcpPolicy())
                .getRTCPTransformer();
    }

    /**
     * Generates an SRTCP decoder for incoming RTCP packets using keying material from the DTLS handshake.
     */
    private PacketTransformer generateRtcpDecoder() {
        return new SRTPTransformEngine(getMasterClientKey(), getMasterClientSalt(), getSrtpPolicy(), getSrtcpPolicy())
                .getRTCPTransformer();
    }

    /**
     * Decodes an SRTP packet. Must only be called after the handshake completed, since the decoder is created by the
     * handshake worker.
     *
     * @param packet The buffer holding the encoded RTP packet
     * @param offset Offset of the packet within the buffer
     * @param length Length of the packet in bytes
     * @return The decoded RTP packet. Returns null if the packet is not valid.
     */
    public byte[] decodeRTP(byte[] packet, int offset, int length) {
        return this.srtpDecoder.reverseTransform(packet, offset, length);
    }

    /**
     * Encodes an RTP packet into SRTP. Must only be called after the handshake completed.
     *
     * @param packet The buffer holding the decoded RTP packet
     * @param offset Offset of the packet within the buffer
     * @param length Length of the packet in bytes
     * @return The encoded RTP packet
     */
    public byte[] encodeRTP(byte[] packet, int offset, int length) {
        return this.srtpEncoder.transform(packet, offset, length);
    }

    /**
     * Decodes an SRTCP packet. Must only be called after the handshake completed.
     *
     * @param packet The buffer holding the encoded RTCP packet
     * @param offset Offset of the packet within the buffer
     * @param length Length of the packet in bytes
     * @return The decoded RTCP packet. Returns null if the packet is not valid.
     */
    public byte[] decodeRTCP(byte[] packet, int offset, int length) {
        return this.srtcpDecoder.reverseTransform(packet, offset, length);
    }

    /**
     * Encodes an RTCP packet into SRTCP. Must only be called after the handshake completed.
     *
     * @param packet The buffer holding the decoded RTCP packet
     * @param offset Offset of the packet within the buffer
     * @param length Length of the packet in bytes
     * @return The encoded RTCP packet
     */
    public byte[] encodeRTCP(byte[] packet, int offset, int length) {
        return this.srtcpEncoder.transform(packet, offset, length);
    }

    /**
     * Starts the DTLS handshake on a new worker thread, unless a handshake is already running or has completed. The
     * handshake timeout ({@link #MAX_DELAY}) is measured from this call.
     */
    public void handshake() {
        if (!handshaking && !handshakeComplete) {
            this.handshaking = true;
            this.startTime = System.currentTimeMillis();
            this.worker = new Thread(new HandshakeWorker(), "DTLS-Server-" + THREAD_COUNTER.incrementAndGet());
            this.worker.start();
        }
    }

    /** Notifies every listener of a successful handshake; a failing listener does not prevent the others. */
    private void fireHandshakeComplete() {
        for (DtlsListener listener : this.listeners) {
            try {
                listener.onDtlsHandshakeComplete();
            } catch (RuntimeException e) {
                logger.error("DTLS listener failed to process handshake completion.", e);
            }
        }
    }

    /** Notifies every listener of a failed handshake; a failing listener does not prevent the others. */
    private void fireHandshakeFailed(Throwable e) {
        for (DtlsListener listener : this.listeners) {
            try {
                listener.onDtlsHandshakeFailed(e);
            } catch (RuntimeException listenerError) {
                logger.error("DTLS listener failed to process handshake failure.", listenerError);
            }
        }
    }

    /**
     * Restores the handler to its initial state: a new DTLS server is created, the channel, transformers and
     * fingerprints are dropped, handshake flags are cleared and all listeners are removed.
     */
    public void reset() {
        // XXX try not to create the server every time!
        this.server = this.tlsServerProvider.provide();
        this.channel = null;
        this.srtcpDecoder = null;
        this.srtcpEncoder = null;
        this.srtpDecoder = null;
        this.srtpEncoder = null;
        this.remoteHashFunction = "";
        this.remoteFingerprint = "";
        this.localFingerprint = "";
        this.handshakeComplete = false;
        this.handshakeFailed = false;
        this.handshaking = false;
        this.startTime = 0L;
        this.listeners.clear();
    }

    @Override
    public int compareTo(PacketHandler o) {
        if (o == null) {
            return 1;
        }
        // Integer.compare avoids the overflow that subtracting priorities can cause
        return Integer.compare(this.getPipelinePriority(), o.getPipelinePriority());
    }

    @Override
    public boolean canHandle(byte[] packet) {
        return packet != null && canHandle(packet, packet.length, 0);
    }

    /**
     * Checks whether a packet is a DTLS record, based on its first byte.
     *
     * @param packet The buffer holding the packet
     * @param dataLength Length of the packet in bytes
     * @param offset Offset of the packet within the buffer
     * @return true if the first byte lies in the DTLS range [20, 63]; false otherwise, including for empty input.
     */
    @Override
    public boolean canHandle(byte[] packet, int dataLength, int offset) {
        // https://tools.ietf.org/html/rfc5764#section-5.1.2
        if (packet == null || dataLength <= 0 || offset < 0 || offset >= packet.length) {
            return false;
        }
        int contentType = packet[offset] & 0xff;
        return (contentType > 19 && contentType < 64);
    }

    @Override
    public byte[] handle(byte[] packet, InetSocketAddress localPeer, InetSocketAddress remotePeer)
            throws PacketHandlerException {
        return this.handle(packet, packet.length, 0, localPeer, remotePeer);
    }

    /**
     * Queues a DTLS record for consumption by the handshake worker. Nothing is sent back through the pipeline.
     * NOTE(review): the array is queued without copying; this assumes the caller does not reuse it afterwards.
     *
     * @return always null
     */
    @Override
    public byte[] handle(byte[] packet, int dataLength, int offset, InetSocketAddress localPeer, InetSocketAddress remotePeer)
            throws PacketHandlerException {
        this.rxQueue.offer(ByteBuffer.wrap(packet, offset, dataLength));
        return null;
    }

    @Override
    public int getPipelinePriority() {
        return this.pipelinePriority;
    }

    public void setPipelinePriority(int pipelinePriority) {
        this.pipelinePriority = pipelinePriority;
    }

    @Override
    public int getReceiveLimit() throws IOException {
        return this.receiveLimit;
    }

    @Override
    public int getSendLimit() throws IOException {
        return this.sendLimit;
    }

    /**
     * Waits up to {@code waitMillis} for the next queued DTLS record and copies it into {@code buf}.
     *
     * @param buf Destination buffer
     * @param off Offset in {@code buf} where the record is written
     * @param len Maximum number of bytes to copy; longer records are truncated
     * @param waitMillis Maximum time to wait for a record, in milliseconds
     * @return The number of bytes copied
     * @throws SocketTimeoutException if no record arrived in time (the DTLS layer then retransmits its flight)
     * @throws IllegalStateException if the overall handshake exceeded {@link #MAX_DELAY} or the thread was interrupted
     */
    @Override
    public int receive(byte[] buf, int off, int len, int waitMillis) throws IOException {
        // MEDIA-48: DTLS handshake thread does not terminate
        // https://telestax.atlassian.net/browse/MEDIA-48
        if (this.hasTimeout()) {
            close();
            throw new IllegalStateException("Handshake is taking too long! (>" + MAX_DELAY + "ms)");
        }

        final ByteBuffer data;
        try {
            data = this.rxQueue.poll(waitMillis, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            // Keep the interrupt status and abort the handshake instead of silently swallowing the interruption
            Thread.currentThread().interrupt();
            close();
            throw new IllegalStateException("DTLS handshake was interrupted while waiting for data.", e);
        }

        if (data == null) {
            // Throw IO exception if no data was received in this interval. Restarts outbound flight.
            throw new SocketTimeoutException("Could not receive DTLS packet in " + waitMillis);
        }

        // The buffer wraps a slice of the original array (position = offset), so copy only its remaining bytes,
        // truncated to the caller's limit like a datagram socket would.
        int length = Math.min(data.remaining(), len);
        data.get(buf, off, length);
        return length;
    }

    /**
     * Sends a DTLS record to the channel's connected remote address. The send is skipped (with a warning) if the
     * handshake timed out or the channel is missing, closed or not connected.
     */
    @Override
    public void send(byte[] buf, int off, int len) throws IOException {
        if (hasTimeout()) {
            logger.warn("Handler has timed out so send operation will be skipped.");
            return;
        }

        // Read the field once so that close()/reset() cannot null it between the checks and the send
        final DatagramChannel currentChannel = this.channel;
        if (currentChannel != null && currentChannel.isOpen() && currentChannel.isConnected()) {
            currentChannel.send(ByteBuffer.wrap(buf, off, len), currentChannel.getRemoteAddress());
        } else {
            logger.warn("Handler skipped send operation because channel is not open or connected.");
        }
    }

    /**
     * Discards queued records, detaches the channel and resets the handshake start time. With a zero start time,
     * {@link #hasTimeout()} reports a timeout, so subsequent sends are skipped.
     */
    @Override
    public void close() throws IOException {
        this.rxQueue.clear();
        this.startTime = 0L;
        this.channel = null;
    }

    /** @return true if more than {@link #MAX_DELAY} ms elapsed since the handshake start time. */
    private boolean hasTimeout() {
        return (System.currentTimeMillis() - this.startTime) > MAX_DELAY;
    }

    /**
     * Runs the DTLS server handshake, derives the SRTP/SRTCP transformers on success, updates the handshake flags and
     * notifies listeners of the outcome.
     */
    private class HandshakeWorker implements Runnable {

        @Override
        public void run() {
            DtlsHandler.this.rxQueue.clear();
            SecureRandom secureRandom = new SecureRandom();
            DTLSServerProtocol serverProtocol = new DTLSServerProtocol(secureRandom);

            try {
                // Perform the handshake, feeding it records queued by handle()
                serverProtocol.accept(server, DtlsHandler.this);

                // Prepare the shared key to be used in RTP streaming
                server.prepareSrtpSharedSecret();

                // Generate encoders for DTLS traffic. They are assigned before the volatile handshakeComplete flag
                // is set, so readers that observe the flag also observe the transformers.
                srtpDecoder = generateRtpDecoder();
                srtpEncoder = generateRtpEncoder();
                srtcpDecoder = generateRtcpDecoder();
                srtcpEncoder = generateRtcpEncoder();
            } catch (Exception e) {
                logger.error("DTLS handshake failed. Reason:", e);

                // Declare handshake as failed
                handshakeComplete = false;
                handshakeFailed = true;
                handshaking = false;

                // Warn listeners handshake failed
                fireHandshakeFailed(e);
                return;
            }

            // Declare handshake as complete
            handshakeComplete = true;
            handshakeFailed = false;
            handshaking = false;

            // Warn listeners handshake completed. Done outside the try block so a listener error
            // cannot turn a successful handshake into a reported failure.
            fireHandshakeComplete();
        }
    }

}