/* The MIT License (MIT)
* Copyright (c) 2015 YouView Ltd
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of
* this software and associated documentation files (the "Software"), to deal in
* the Software without restriction, including without limitation the rights to
* use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
* the Software, and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
* CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
package com.ruesga.android.wallpapers.photophase.cast.mdsn;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.OutputStream;
import java.net.DatagramPacket;
import java.net.InetAddress;
import java.net.MulticastSocket;
import java.net.SocketTimeoutException;
import java.util.HashMap;
import java.util.Map;
/**
* Low-level functionality for handling questions and answers over the Multicast DNS Protocol
* (RFC 6762), in conjunction with DNS Service Discovery (DNS-SD, RFC 6763).
*/
public class MDNSDiscover {
private static final short QTYPE_A = 0x0001;
static final short QTYPE_PTR = 0x000c;
static final short QTYPE_TXT = 0x0010;
static final short QTYPE_SRV = 0x0021;
static final short QCLASS_INTERNET = 0x0001;
@SuppressWarnings("unused")
static final short CLASS_FLAG_MULTICAST = 0, CLASS_FLAG_UNICAST = (short) 0x8000;
private static final int PORT = 5353;
private static final String MULTICAST_GROUP_ADDRESS = "224.0.0.251";
private static final boolean DEBUG = false;
/**
* @see #discover(String, Callback, int)
*/
public interface Callback {
void onResult(Result result);
}
private static byte[] discoverPacket(String serviceType) throws IOException {
return queryPacket(serviceType, QCLASS_INTERNET | CLASS_FLAG_UNICAST, QTYPE_PTR);
}
static byte[] queryPacket(String serviceName, int qclass, int... qtypes) throws IOException {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(bos);
dos.writeInt(0);
dos.writeShort(qtypes.length); // questions
dos.writeShort(0); // answers
dos.writeShort(0); // nscount
dos.writeShort(0); // arcount
int fqdnPtr = -1;
for (int qtype : qtypes) {
if (fqdnPtr == -1) {
fqdnPtr = dos.size();
writeFQDN(serviceName, dos);
} else {
// packet compression, string is just a pointer to previous occurrence
dos.write(0xc0 | (fqdnPtr >> 8));
dos.write(fqdnPtr & 0xFF);
}
dos.writeShort(qtype);
dos.writeShort(qclass);
}
dos.close();
return bos.toByteArray();
}
/**
* Sends a discovery packet for the specified service and listens for reply packets, notifying
* a callback as services are discovered.
* @param serviceType the type of service to query in mDNS, e.g. {@code "_example._tcp.local"}
* @param callback receives callbacks with {@link Result} objects as answers are decoded from
* incoming reply packets.
* @param timeout duration in milliseconds to wait for answer packets. If {@code 0}, this method
* will listen forever.
* @throws IOException
*/
@SuppressWarnings("unused")
public static void discover(String serviceType, Callback callback, int timeout) throws IOException {
if (timeout < 0) throw new IllegalArgumentException();
InetAddress group = InetAddress.getByName(MULTICAST_GROUP_ADDRESS);
MulticastSocket sock = new MulticastSocket(); // binds to a random free source port
if (DEBUG) System.out.println("Source port is " + sock.getLocalPort());
byte[] data = discoverPacket(serviceType);
if (DEBUG) System.out.println("Query packet:");
if (DEBUG) hexdump(data, 0, data.length);
DatagramPacket packet = new DatagramPacket(data, data.length, group, PORT);
sock.setTimeToLive(255);
sock.send(packet);
byte[] buf = new byte[1024];
packet = new DatagramPacket(buf, buf.length);
long endTime = 0;
if (timeout != 0) {
endTime = System.currentTimeMillis() + timeout;
}
while (true) {
if (timeout != 0) {
int remaining = (int) (endTime - System.currentTimeMillis());
if (remaining <= 0) {
break;
}
sock.setSoTimeout(remaining);
}
try {
sock.receive(packet);
} catch (SocketTimeoutException e) {
break;
}
if (DEBUG) System.out.println("\n\nIncoming packet:");
if (DEBUG) hexdump(packet.getData(), 0, packet.getLength());
Result result = decode(packet.getData(), packet.getLength());
if (callback != null) {
callback.onResult(result);
}
}
}
/**
* Ask for the A, SRV and TXT records of a particular service.
* @param serviceName the name of service to query in mDNS, e.g.
* {@code "device-1234._example._tcp.local"}
* @param timeout duration in milliseconds to wait for an answer packet. If {@code 0}, this
* method will listen forever.
* @return the reply packet's decoded answer data
* @throws IOException
*/
public static Result resolve(String serviceName, int timeout) throws IOException {
if (timeout < 0) throw new IllegalArgumentException();
InetAddress group = InetAddress.getByName(MULTICAST_GROUP_ADDRESS);
MulticastSocket sock = new MulticastSocket(); // binds to a random free source port
if (DEBUG) System.out.println("Source port is " + sock.getLocalPort());
if (DEBUG) System.out.println("Query packet:");
byte[] data = queryPacket(serviceName, QCLASS_INTERNET | CLASS_FLAG_UNICAST, QTYPE_A, QTYPE_SRV, QTYPE_TXT);
if (DEBUG) hexdump(data, 0, data.length);
DatagramPacket packet = new DatagramPacket(data, data.length, group, PORT);
sock.setTimeToLive(255);
sock.send(packet);
byte[] buf = new byte[1024];
packet = new DatagramPacket(buf, buf.length);
Result result = new Result();
long endTime = 0;
if (timeout != 0) {
endTime = System.currentTimeMillis() + timeout;
}
// records could be returned in different packets, so we have to loop
// timeout applies to the acquisition of ALL packets
while (result.a == null || result.srv == null || result.txt == null) {
if (timeout != 0) {
int remaining = (int) (endTime - System.currentTimeMillis());
if (remaining <= 0) {
break;
}
sock.setSoTimeout(remaining);
}
sock.receive(packet);
if (DEBUG) System.out.println("\n\nIncoming packet:");
if (DEBUG) hexdump(packet.getData(), 0, packet.getLength());
decode(packet.getData(), packet.getLength(), result);
}
return result;
}
private static void writeFQDN(String name, OutputStream out) throws IOException {
for (String part : name.split("\\.")) {
out.write(part.length());
out.write(part.getBytes());
}
out.write(0);
}
private static void hexdump(byte[] data, int offset, int length) {
while (offset < length) {
System.out.printf("%08x", offset);
int origOffset = offset;
int col;
for (col = 0; col < 16 && offset < length; col++, offset++) {
System.out.printf(" %02x", data[offset] & 0xFF);
}
for (; col < 16; col++) {
System.out.printf(" ");
}
System.out.print(" ");
offset = origOffset;
for (col = 0; col < 16 && offset < length; col++, offset++) {
byte val = data[offset];
char c;
if (val >= 32 && val < 127) {
c = (char) val;
} else {
c = '.';
}
System.out.printf("%c", c);
}
System.out.println();
}
}
public static class Record {
/** Fully-Qualified Domain Name of the record. */
public String fqdn;
/** Time-to-live of the record, in seconds. */
public int ttl;
}
/** DNS A record */
public static class A extends Record {
/** The IPv4 address in dot-decimal notation, e.g. {@code "192.168.1.100"} */
public String ipaddr;
}
public static class SRV extends Record {
public int priority, weight, port;
/** Fully-Qualified Domain Name of the target service. */
public String target;
}
public static class TXT extends Record {
/** The content of the TXT record's key-value store decoded as a {@link Map} */
public Map<String, String> dict;
}
/**
* Represents the decoded content of the answer sections of an incoming packet.
* When the corresponding data is present in an answer, fields will be initialized with
* populated data structures. When no such answer is present in the packet, fields will be
* {@code null}.
*/
public static class Result {
public A a;
public SRV srv;
public TXT txt;
}
static Result decode(byte[] packet, int packetLength) throws IOException {
Result result = new Result();
decode(packet, packetLength, result);
return result;
}
@SuppressWarnings("unused")
static void decode(byte[] packet, int packetLength, Result result) throws IOException {
DataInputStream dis = new DataInputStream(new ByteArrayInputStream(packet, 0, packetLength));
short transactionID = dis.readShort();
short flags = dis.readShort();
int questions = dis.readUnsignedShort();
int answers = dis.readUnsignedShort();
int authorityRRs = dis.readUnsignedShort();
int additionalRRs = dis.readUnsignedShort();
// decode the queries
for (int i = 0; i < questions; i++) {
String fqdn = decodeFQDN(dis, packet, packetLength);
short type = dis.readShort();
short qclass = dis.readShort();
}
// decode the answers
for (int i = 0; i < answers + authorityRRs + additionalRRs; i++) {
String fqdn = decodeFQDN(dis, packet, packetLength);
short type = dis.readShort();
short aclass = dis.readShort();
if (DEBUG) System.out.printf("%s record%n", typeString(type));
if (DEBUG) System.out.println("Name: " + fqdn);
int ttl = dis.readInt();
int length = dis.readUnsignedShort();
byte[] data = new byte[length];
dis.readFully(data);
Record record = null;
switch (type) {
case QTYPE_A:
record = result.a = decodeA(data);
break;
case QTYPE_SRV:
record = result.srv = decodeSRV(data, packet, packetLength);
break;
case QTYPE_PTR:
decodePTR(data, packet, packetLength);
break;
case QTYPE_TXT:
record = result.txt = decodeTXT(data);
break;
default:
if (DEBUG) hexdump(data, 0, data.length);
break;
}
if (record != null) {
record.fqdn = fqdn;
record.ttl = ttl;
}
}
}
private static SRV decodeSRV(byte[] srvData, byte[] packetData, int packetLength) throws IOException {
DataInputStream dis = new DataInputStream(new ByteArrayInputStream(srvData));
SRV srv = new SRV();
srv.priority = dis.readUnsignedShort();
srv.weight = dis.readUnsignedShort();
srv.port = dis.readUnsignedShort();
srv.target = decodeFQDN(dis, packetData, packetLength);
if (DEBUG) System.out.printf("Priority: %d Weight: %d Port: %d Target: %s%n", srv.priority, srv.weight, srv.port, srv.target);
return srv;
}
private static String typeString(short type) {
switch (type) {
case QTYPE_A:
return "A";
case QTYPE_PTR:
return "PTR";
case QTYPE_SRV:
return "SRV";
case QTYPE_TXT:
return "TXT";
default:
return "Unknown";
}
}
private static String decodePTR(byte[] ptrData, byte[] packet, int packetLength) throws IOException {
DataInputStream dis = new DataInputStream(new ByteArrayInputStream(ptrData));
String fqdn = decodeFQDN(dis, packet, packetLength);
if (DEBUG) System.out.println(fqdn);
return fqdn;
}
private static A decodeA(byte[] data) throws IOException {
if (data.length < 4) throw new IOException("expected 4 bytes for IPv4 addr");
A a = new A();
a.ipaddr = (data[0] & 0xFF) + "." + (data[1] & 0xFF) + "." + (data[2] & 0xFF) + "." + (data[3] & 0xFF);
if (DEBUG) System.out.println("Ipaddr: " + a.ipaddr);
return a;
}
private static TXT decodeTXT(byte[] data) throws IOException {
TXT txt = new TXT();
txt.dict = new HashMap<>();
DataInputStream dis = new DataInputStream(new ByteArrayInputStream(data));
while (true) {
int length;
try {
length = dis.readUnsignedByte();
} catch (EOFException e) {
return txt;
}
byte[] segmentBytes = new byte[length];
dis.readFully(segmentBytes);
String segment = new String(segmentBytes);
int pos = segment.indexOf('=');
String key, value = null;
if (pos != -1) {
key = segment.substring(0, pos);
value = segment.substring(pos + 1);
} else {
key = segment;
}
if (DEBUG) System.out.println(key + "=" + value);
if (!txt.dict.containsKey(key)) {
// from RFC6763
// If a client receives a TXT record containing the same key more than once, then
// the client MUST silently ignore all but the first occurrence of that attribute."
txt.dict.put(key, value);
}
}
}
private static String decodeFQDN(DataInputStream dis, byte[] packet, int packetLength) throws IOException {
StringBuilder result = new StringBuilder();
boolean dot = false;
while (true) {
int pointerHopCount = 0;
int length;
while (true) {
length = dis.readUnsignedByte();
if (length == 0) return result.toString();
if ((length & 0xc0) == 0xc0) {
// this is a compression method, the remainder of the string is a pointer to elsewhere in the packet
// adjust the stream boundary and repeat processing
if ((++pointerHopCount) * 2 >= packetLength) {
// We must have visited one of the possible pointers more than once => cycle
// this doesn't add to the domain length, but decoding would be non-terminating
throw new IOException("cyclic empty references in domain name");
}
length &= 0x3f;
int offset = (length << 8) | dis.readUnsignedByte();
dis = new DataInputStream(new ByteArrayInputStream(packet, offset, packetLength - offset));
} else {
break;
}
}
byte[] segment = new byte[length];
dis.readFully(segment);
if (dot) result.append('.');
dot = true;
result.append(new String(segment));
if (result.length() > packetLength) {
// If we get here, we must be following cyclic references, since non-cyclic
// references can't encode a domain name longer than the total length of the packet.
// The domain name would be infinitely long, so abort now rather than consume
// maximum heap.
throw new IOException("cyclic non-empty references in domain name");
}
}
}
}