package javaforce; /** STUN (w/ TURN support) client. * * Created : Nov 25, 2013 * * See RFCs: * http://tools.ietf.org/html/rfc3489 - Classic STUN * http://tools.ietf.org/html/rfc5389 - STUN * http://tools.ietf.org/html/rfc5766 - TURN * * I can't stand reading RFCs. Built this mostly with X-Lite and Wireshark as usual. * Thanks to resiprocate open-source project, would have NEVER figured out the HmacSHA1 stuff. */ import java.net.*; import java.nio.*; import java.util.*; import javax.crypto.*; import javax.crypto.spec.*; public class STUN { public interface Listener { public void stunPublicIP(STUN stun, String ip, int port); public void turnAlloc(STUN stun, String ip, int port, byte token[], int lifetime); public void turnBind(STUN stun); public void turnRefresh(STUN stun, int lifetime); public void turnFailed(STUN stun); public void turnData(STUN stun, byte data[], int offset, int length, short channel); } private DatagramSocket ds; private InetAddress addr; private int StunPort = 3478; //requests private final static short BINDING_REQUEST = 0x0001; private final static short ALLOCATE_REQUEST = 0x0003; private final static short REFRESH_REQUEST = 0x0004; private final static short BIND_REQUEST = 0x0009; //responses private final static short BINDING_RESPONSE = 0x0101; private final static short ALLOCATE_RESPONSE = 0x0103; private final static short REFRESH_RESPONSE = 0x0104; private final static short BIND_RESPONSE = 0x0109; //attrs private final static short MAPPED_ADDRESS = 0x0001; private final static short CHANGE_REQUEST = 0x0003; private final static short USERNAME = 0x0006; private final static short MESSAGE_INTEGRITY = 0x0008; private final static short ERROR_CODE = 0x0009; private final static short CHANNEL_NUMBER = 0x000c; private final static short LIFETIME = 0x000d; private final static short XOR_PEER_ADDRESS = 0x0012; private final static short REALM = 0x0014; private final static short NONCE = 0x0015; private final static short XOR_RELAY_ADDRESS = 0x0016; private final static short DATA_INDICATION = 0x0017; private final static short EVEN_PORT = 0x0018; private final static short TRANSPORT_TYPE = 0x0019; private final static short XOR_MAPPED_ADDRESS = 0x0020; private final static short RESERVATION_TOKEN = 0x0022; private long id1, id2; private Listener listener; private boolean active = true; private String user, pass; private boolean sentAuth = false; private String realm, nonce; private boolean evenPort; private byte token[]; private byte fulldata[]; private ByteBuffer fulldatabb; private int relayPort = -1; private String relayIP = null; private int lifetime = -1; private short lastRequest = -1; /** Connects to STUN server and starts socket listening thread. * @param localport = localport to listen on (-1 = any) */ public boolean start(int localport, String host, String user, String pass, Listener listener) { this.listener = listener; this.user = user; this.pass = pass; try { int idx = host.indexOf(":"); if (idx != -1) { StunPort = JF.atoi(host.substring(idx+1)); host = host.substring(0, idx); } addr = InetAddress.getByName(host); if (localport == -1) { ds = new DatagramSocket(); } else { ds = new DatagramSocket(localport); } new Worker().start(); return true; } catch (Exception e) { JFLog.log(e); return false; } } public int getLocalPort() { return ds.getLocalPort(); } public String getLocalAddr() { return ds.getLocalAddress().getHostAddress(); } public void close() { active = false; if (ds == null) return; try { ds.close(); ds = null; } catch (Exception e) { JFLog.log(e); } } private void genID() { Random r = new Random(); id1 = 0x2112a442; //magic cookie id1 <<= 32; id1 += Math.abs(r.nextInt()); id2 = r.nextLong(); } /** STUN : Request real Public IP */ public void requestPublicIP() { requestPublicIP(false, false); } /** STUN : Request real Public IP from a different IP and or Port */ public void requestPublicIP(boolean change_ip, boolean change_port) { try { lastRequest = BINDING_REQUEST; int packetSize = 20 + 8; byte request[] = new byte[packetSize]; DatagramPacket dp = new DatagramPacket(request, packetSize); ByteBuffer bb = ByteBuffer.wrap(request); bb.order(ByteOrder.BIG_ENDIAN); int offset = 0; bb.putShort(offset, BINDING_REQUEST); offset += 2; bb.putShort(offset, (short)8); //length offset += 2; genID(); bb.putLong(offset, id1); offset += 8; bb.putLong(offset, id2); offset += 8; bb.putShort(offset, CHANGE_REQUEST); offset += 2; bb.putShort(offset, (short)4); //length offset += 2; bb.putInt(offset, (change_ip ? 0x04 : 0) + (change_port ? 0x02: 0)); //flags offset += 4; dp.setAddress(addr); dp.setPort(StunPort); ds.send(dp); } catch (Exception e) { JFLog.log(e); } } /** TURN : Request a UDP data channel. * * Token is optional, used if last request was for even port (to alloc the odd port). */ public void requestAlloc(boolean evenPort, byte token[]) { this.evenPort = evenPort; this.token = token; int lengthOffset; try { lastRequest = ALLOCATE_REQUEST; byte request[] = new byte[1024]; ByteBuffer bb = ByteBuffer.wrap(request); bb.order(ByteOrder.BIG_ENDIAN); int offset = 0; bb.putShort(offset, ALLOCATE_REQUEST); offset += 2; lengthOffset = offset; bb.putShort(offset, (short)0); //(patch later) length offset += 2; genID(); bb.putLong(offset, id1); offset += 8; bb.putLong(offset, id2); offset += 8; if (evenPort) { bb.putShort(offset, EVEN_PORT); offset += 2; bb.putShort(offset, (short)1); //length offset += 2; bb.put(offset, (byte)0x80); //reserve_next=0x80 offset += 1; offset += 3; //padding } bb.putShort(offset, TRANSPORT_TYPE); offset += 2; bb.putShort(offset, (short)4); //length offset += 2; bb.put(offset, (byte)0x11); //UDP offset += 1; offset += 3; //padding if (realm != null && nonce != null) { int strlen = user.length(); bb.putShort(offset, USERNAME); offset += 2; bb.putShort(offset, (short)strlen); offset += 2; System.arraycopy(user.getBytes(), 0, request, offset, strlen); offset += strlen; if ((offset & 3) > 0) { offset += 4 - (offset & 3); //padding } } if (realm != null) { int strlen = realm.length(); bb.putShort(offset, REALM); offset += 2; bb.putShort(offset, (short)strlen); offset += 2; System.arraycopy(realm.getBytes(), 0, request, offset, strlen); offset += strlen; if ((offset & 3) > 0) { offset += 4 - (offset & 3); //padding } } if (nonce != null) { int strlen = nonce.length(); bb.putShort(offset, NONCE); offset += 2; bb.putShort(offset, (short)strlen); offset += 2; System.arraycopy(nonce.getBytes(), 0, request, offset, strlen); offset += strlen; if ((offset & 3) > 0) { offset += 4 - (offset & 3); //padding } } if (token != null) { int strlen = token.length; bb.putShort(offset, RESERVATION_TOKEN); offset += 2; bb.putShort(offset, (short)strlen); offset += 2; System.arraycopy(token, 0, request, offset, strlen); offset += strlen; if ((offset & 3) > 0) { offset += 4 - (offset & 3); //padding } } //message integrity if (realm != null && nonce != null) { //length should include size of message integrity attr (even though it's not filled in yet) bb.putShort(lengthOffset, (short)(offset - 20 + 24)); //patch length byte id[] = calcMsgIntegrity(request, offset, calcKey(user, realm, pass)); int strlen = id.length; bb.putShort(offset, MESSAGE_INTEGRITY); offset += 2; bb.putShort(offset, (short)strlen); offset += 2; System.arraycopy(id, 0, request, offset, strlen); offset += strlen; if ((offset & 3) > 0) { offset += 4 - (offset & 3); //padding } } bb.putShort(lengthOffset, (short)(offset - 20)); //patch length DatagramPacket dp = new DatagramPacket(request, offset); dp.setAddress(addr); dp.setPort(StunPort); ds.send(dp); } catch (Exception e) { JFLog.log(e); } } /** TURN : Bind to host:port (also a keep alive) * @param channel : 0x4000 thru 0x7ffe (you pick one by random I guess) */ public void requestBind(short channel, String host, int port) { int lengthOffset; try { lastRequest = BIND_REQUEST; byte hostaddr[] = InetAddress.getByName(host).getAddress(); byte request[] = new byte[1024]; ByteBuffer bb = ByteBuffer.wrap(request); bb.order(ByteOrder.BIG_ENDIAN); int offset = 0; bb.putShort(offset, BIND_REQUEST); offset += 2; lengthOffset = offset; bb.putShort(offset, (short)0); //length (patch later) offset += 2; genID(); bb.putLong(offset, id1); offset += 8; bb.putLong(offset, id2); offset += 8; if (realm != null && nonce != null) { int strlen = user.length(); bb.putShort(offset, USERNAME); offset += 2; bb.putShort(offset, (short)strlen); offset += 2; System.arraycopy(user.getBytes(), 0, request, offset, strlen); offset += strlen; if ((offset & 3) > 0) { offset += 4 - (offset & 3); //padding } } if (realm != null) { int strlen = realm.length(); bb.putShort(offset, REALM); offset += 2; bb.putShort(offset, (short)strlen); offset += 2; System.arraycopy(realm.getBytes(), 0, request, offset, strlen); offset += strlen; if ((offset & 3) > 0) { offset += 4 - (offset & 3); //padding } } if (nonce != null) { int strlen = nonce.length(); bb.putShort(offset, NONCE); offset += 2; bb.putShort(offset, (short)strlen); offset += 2; System.arraycopy(nonce.getBytes(), 0, request, offset, strlen); offset += strlen; if ((offset & 3) > 0) { offset += 4 - (offset & 3); //padding } } bb.putShort(offset, CHANNEL_NUMBER); offset += 2; bb.putShort(offset, (short)4); offset += 2; bb.putShort(offset, channel); offset += 2; offset += 2; //reserved??? bb.putShort(offset, XOR_PEER_ADDRESS); offset += 2; bb.putShort(offset, (short)8); offset += 2; offset++; //reserved bb.put(offset, (byte)0x01); //IP4 offset++; bb.putShort(offset, (short)(port ^ bb.getShort(4))); offset += 2; for(int a=0;a<4;a++) { request[offset++] = (byte)(hostaddr[a] ^ request[4 + a]); } //message integrity if (realm != null && nonce != null) { //length should include size of message integrity attr (even though it's not filled in yet) bb.putShort(lengthOffset, (short)(offset - 20 + 24)); //patch length byte id[] = calcMsgIntegrity(request, offset, calcKey(user, realm, pass)); int strlen = id.length; bb.putShort(offset, MESSAGE_INTEGRITY); offset += 2; bb.putShort(offset, (short)strlen); offset += 2; System.arraycopy(id, 0, request, offset, strlen); offset += strlen; if ((offset & 3) > 0) { offset += 4 - (offset & 3); //padding } } bb.putShort(lengthOffset, (short)(offset - 20)); //patch length DatagramPacket dp = new DatagramPacket(request, offset); dp.setAddress(addr); dp.setPort(StunPort); ds.send(dp); } catch (Exception e) { JFLog.log(e); } } /** TURN : Refresh a connection (keep alive) */ public void requestRefresh(int seconds) { int lengthOffset; try { lastRequest = REFRESH_REQUEST; byte request[] = new byte[1024]; ByteBuffer bb = ByteBuffer.wrap(request); bb.order(ByteOrder.BIG_ENDIAN); int offset = 0; bb.putShort(offset, REFRESH_REQUEST); offset += 2; lengthOffset = offset; bb.putShort(offset, (short)0); //length (patch later) offset += 2; genID(); bb.putLong(offset, id1); offset += 8; bb.putLong(offset, id2); offset += 8; if (realm != null && nonce != null) { int strlen = user.length(); bb.putShort(offset, USERNAME); offset += 2; bb.putShort(offset, (short)strlen); offset += 2; System.arraycopy(user.getBytes(), 0, request, offset, strlen); offset += strlen; if ((offset & 3) > 0) { offset += 4 - (offset & 3); //padding } } if (realm != null) { int strlen = realm.length(); bb.putShort(offset, REALM); offset += 2; bb.putShort(offset, (short)strlen); offset += 2; System.arraycopy(realm.getBytes(), 0, request, offset, strlen); offset += strlen; if ((offset & 3) > 0) { offset += 4 - (offset & 3); //padding } } if (nonce != null) { int strlen = nonce.length(); bb.putShort(offset, NONCE); offset += 2; bb.putShort(offset, (short)strlen); offset += 2; System.arraycopy(nonce.getBytes(), 0, request, offset, strlen); offset += strlen; if ((offset & 3) > 0) { offset += 4 - (offset & 3); //padding } } bb.putShort(offset, LIFETIME); offset += 2; bb.putShort(offset, (short)4); offset += 2; bb.putInt(offset, seconds); offset += 4; //message integrity if (realm != null && nonce != null) { //length should include size of message integrity attr (even though it's not filled in yet) bb.putShort(lengthOffset, (short)(offset - 20 + 24)); //patch length byte id[] = calcMsgIntegrity(request, offset, calcKey(user, realm, pass)); int strlen = id.length; bb.putShort(offset, MESSAGE_INTEGRITY); offset += 2; bb.putShort(offset, (short)strlen); offset += 2; System.arraycopy(id, 0, request, offset, strlen); offset += strlen; if ((offset & 3) > 0) { offset += 4 - (offset & 3); //padding } } bb.putShort(lengthOffset, (short)(offset - 20)); //patch length DatagramPacket dp = new DatagramPacket(request, offset); dp.setAddress(addr); dp.setPort(StunPort); ds.send(dp); } catch (Exception e) { JFLog.log(e); } } public String getIP() {return relayIP;} public int getPort() {return relayPort;} /** TURN : Send out a UDP packet. */ public void sendData(short channel, byte data[], int offset, int length) { if (fulldata == null || fulldata.length != length + 4) { fulldata = new byte[length + 4]; fulldatabb = ByteBuffer.wrap(fulldata); fulldatabb.order(ByteOrder.BIG_ENDIAN); } fulldatabb.putShort(0, channel); fulldatabb.putShort(2, (short)length); System.arraycopy(data, offset, fulldata, 4, length); DatagramPacket dp = new DatagramPacket(fulldata, length + 4); try { dp.setAddress(addr); dp.setPort(StunPort); ds.send(dp); } catch (Exception e) { JFLog.log(e); } } //for short-term credentials public static byte[] calcKey(String pass) { return pass.getBytes(); } //for long-term credentials public static byte[] calcKey(String user, String realm, String pass) { String msg = user + ":" + realm + ":" + pass; MD5 md5 = new MD5(); md5.init(); md5.add(msg.getBytes(), 0, msg.length()); return md5.done(); } public static byte[] calcMsgIntegrity(byte data[], int length, byte key[]) { try { SecretKeySpec ks = new SecretKeySpec(key, "HmacSHA1"); Mac mac = Mac.getInstance("HmacSHA1"); mac.init(ks); return mac.doFinal(Arrays.copyOfRange(data, 0, length)); } catch (Exception e) { JFLog.log(e); return null; } } //see http://tools.ietf.org/html/rfc5389#section-15.5 public static int calcFingerprint(byte data[], int length) { java.util.zip.CRC32 crc = new java.util.zip.CRC32(); crc.update(data, 0, length); return ((int)crc.getValue()) ^ 0x5354554e; } private class Worker extends Thread { public void run() { DatagramPacket dp; boolean resendAuth; int errcode; int ip[], port; byte response[] = new byte[1500]; ByteBuffer bb = ByteBuffer.wrap(response); bb.order(ByteOrder.BIG_ENDIAN); while (active) { try { resendAuth = false; dp = new DatagramPacket(response, 1500); ds.receive(dp); //TODO : validate packet source int packetLength = dp.getLength(); //decode response int offset = 0; short code = bb.getShort(0); if (code >= 0x4000) { //it's TURN data received back listener.turnData(STUN.this, response, 4, bb.getShort(2), code); continue; } // JFLog.log("STUN:code=0x" + Integer.toString(code, 16)); if (code == BIND_RESPONSE) { listener.turnBind(STUN.this); } if (code == DATA_INDICATION) { continue; } offset += 2; short length = bb.getShort(offset); if (length + 20 != packetLength) { throw new Exception("STUN:bad packet:incorrect length"); } offset += 2; long _id1 = bb.getLong(offset); offset += 8; long _id2 = bb.getLong(offset); offset += 8; if (id1 != _id1 || id2 != _id2) { throw new Exception("STUN:bad packet:id mismatch"); } while (offset < packetLength) { short attr = bb.getShort(offset); offset += 2; length = bb.getShort(offset); offset += 2; switch (attr) { case MAPPED_ADDRESS: //CLASSIC STUN port = ((int)bb.getShort(offset + 2)) & 0xffff; ip = new int[4]; for(int a=0;a<4;a++) { ip[a] = ((int)response[offset + 4 + a]) & 0xff; } if (code == BINDING_RESPONSE) { listener.stunPublicIP(STUN.this, String.format("%d.%d.%d.%d", ip[0], ip[1], ip[2], ip[3]), port); } break; case XOR_MAPPED_ADDRESS: //NEW STUN port = (bb.getShort(offset + 2) ^ bb.getShort(4)) & 0xffff; ip = new int[4]; for(int a=0;a<4;a++) { ip[a] = (response[offset + 4 + a] ^ response[4 + a]) & 0xff; } if (code == BINDING_RESPONSE) { listener.stunPublicIP(STUN.this, String.format("%d.%d.%d.%d", ip[0], ip[1], ip[2], ip[3]), port); } break; case REALM: realm = new String(response, offset, length); JFLog.log("STUN:realm=" + realm); break; case NONCE: nonce = new String(response, offset, length); JFLog.log("STUN:nonce=" + nonce); break; case ERROR_CODE: errcode = bb.getShort(offset + 2); switch (errcode) { case 0x401: if (sentAuth) { listener.turnFailed(STUN.this); JFLog.log("STUN:Error:" + Integer.toString(errcode, 16) + " (Bad Auth)"); } else { resendAuth = true; } break; default: JFLog.log("STUN:Error:" + Integer.toString(errcode, 16)); break; } break; case XOR_RELAY_ADDRESS: relayPort = (bb.getShort(offset + 2) ^ bb.getShort(4)) & 0xffff; ip = new int[4]; for(int a=0;a<4;a++) { ip[a] = (response[offset + 4 + a] ^ response[4 + a]) & 0xff; } relayIP = String.format("%d.%d.%d.%d", ip[0], ip[1], ip[2], ip[3]); if (relayIP.equals("0.0.0.0")) { //use turn host address byte ip4[] = addr.getAddress(); relayIP = String.format("%d.%d.%d.%d", ip4[0], ip4[1], ip4[2], ip4[3]);; } break; case RESERVATION_TOKEN: token = new byte[length]; System.arraycopy(response, offset, token, 0, length); break; case LIFETIME: lifetime = bb.getInt(offset); break; } offset += length; if ((length & 0x3) > 0) { offset += 4 - (length & 0x3); //padding } } if (resendAuth) { if (lastRequest == ALLOCATE_REQUEST) { //resend alloc request with auth sentAuth = true; requestAlloc(evenPort, token); } } if (code == ALLOCATE_RESPONSE) { listener.turnAlloc(STUN.this, relayIP, relayPort, token, lifetime); token = null; } if (code == REFRESH_RESPONSE) { listener.turnRefresh(STUN.this, lifetime); } } catch (Exception e) { if (active) JFLog.log(e); } } } } public enum NAT {Unknown, None, FullCone, RestrictedCone, RestrictedPort, SymmetricFirewall, SymmetricNAT}; public static class Test implements Listener { private volatile String ip; private volatile int port; private volatile boolean ok; public NAT run(int localport, String host1, String host2) { boolean t1, t2, t3, t1b; STUN stun = new STUN(); stun.start(localport, host1, null, null, this); ok = false; stun.requestPublicIP(false, false); JF.sleep(1000); if (!ok) { t1 = false; JFLog.log("STUN:Test I:Failed"); } else { t1 = true; JFLog.log("STUN:Test I:IP=" + ip + ":" + port); } ok = false; stun.requestPublicIP(true, true); JF.sleep(1000); if (!ok) { t2 = false; JFLog.log("STUN:Test II:Failed"); } else { t2 = true; JFLog.log("STUN:Test II:IP=" + ip + ":" + port); } ok = false; stun.requestPublicIP(false, true); JF.sleep(1000); if (!ok) { t3 = false; JFLog.log("STUN:Test III:Failed"); } else { t3 = true; JFLog.log("STUN:Test III:IP=" + ip + ":" + port); } String localIP = stun.getLocalAddr(); stun.close(); String ip1 = ip; int port1 = port; if (host2 != null) { stun = new STUN(); stun.start(localport, host2, null, null, this); ok = false; stun.requestPublicIP(false, false); JF.sleep(1000); if (!ok) { t1b = false; JFLog.log("STUN:Test I(Server #2):Failed"); } else { t1b = true; JFLog.log("STUN:Test I(Server #2):IP=" + ip + ":" + port); } stun.close(); } else { t1b = false; } JFLog.log("STUN:Tests Complete"); if (!t1) return NAT.Unknown; if (localIP.equals(ip1)) { //no NAT if (t2) return NAT.None; else return NAT.SymmetricFirewall; } if (t2) return NAT.FullCone; if (t1b) { if (!ip1.equals(ip) || port1 != port) { return NAT.SymmetricNAT; } } else { JFLog.log("STUN:Test:Warning:2nd STUN server failed or skipped, Symmetric NAT test undetermined."); } if (t3) return NAT.RestrictedCone; else return NAT.RestrictedPort; } public void stunPublicIP(STUN stun, String ip, int port) { this.ip = ip; this.port = port; ok = true; }; public void turnAlloc(STUN stun, String ip, int port, byte token[], int lifetime) {}; public void turnBind(STUN stun) {}; public void turnRefresh(STUN stun, int lifetime) {}; public void turnFailed(STUN stun) {}; public void turnData(STUN stun, byte data[], int offset, int length, short channel) {}; } /** Performs a quick test to determine your firewall type. */ public static NAT doTest(int port, String host1, String host2) { return new Test().run(port, host1, host2); } public static void main(String args[]) { if (args.length < 2) { System.out.println("Desc: Determine your Firewall NAT type."); System.out.println("Usage: javaforce.STUN port server1 [server2]"); System.out.println("Two servers are recommended to detect Symmetric router."); } else { int port = (int)Integer.valueOf(args[0]); String s1 = args[1]; String s2; if (args.length > 2) s2 = args[2]; else s2 = null; System.out.println("Result=" + new Test().run(port, s1, s2)); } } }