/*******************************************************************************
* This file is part of OpenNMS(R).
*
* Copyright (C) 2011 The OpenNMS Group, Inc.
* OpenNMS(R) is Copyright (C) 1999-2011 The OpenNMS Group, Inc.
*
* OpenNMS(R) is a registered trademark of The OpenNMS Group, Inc.
*
* This file is a derivative work, containing both original code, included code,
* and modified code that was published under the GNU General Public License.
*
* Original code Copyright (c) 1999-2004 Brian Wellington (bwelling@xbill.org)
*
* Refactored from DNSServer in the JDNSS server
* http://sourceforge.net/projects/jdnss/
*
* Project site for JDNSS says "BSD and GPL license" but this file had no
* specifics about which license it's specifically under, so assume the more
* restrictive GPL until we can get more details.
*
* OpenNMS(R) is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published
* by the Free Software Foundation, either version 3 of the License,
* or (at your option) any later version.
*
* OpenNMS(R) is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with OpenNMS(R). If not, see:
* http://www.gnu.org/licenses/
*
* For more information contact:
* OpenNMS(R) Licensing <license@opennms.org>
* http://www.opennms.org/
* http://www.opennms.com/
*******************************************************************************/
package org.opennms.core.test.dns;
import java.io.*;
import java.net.*;
import java.util.*;
import java.util.concurrent.CountDownLatch;
import org.opennms.core.utils.InetAddressUtils;
import org.opennms.core.utils.LogUtils;
import org.xbill.DNS.*;
/**
 * A small DNS server, refactored from the JDNSS {@code DNSServer}, that answers
 * queries from locally configured zones and per-class caches.
 * <p>
 * For every configured address/port pair, {@link #start()} spawns one UDP and
 * one TCP listener thread; {@link #stop()} signals each listener and waits for
 * it to exit. If no port or address has been configured when {@link #start()}
 * runs, port 53 on 0.0.0.0 is used.
 */
public class DNSServer {
    /** Socket timeout in milliseconds; lets the listener loops periodically re-check their stop flag. */
    private static final int DEFAULT_SOCKET_TIMEOUT = 100;

    /** Size of the UDP receive buffer (the non-EDNS DNS-over-UDP message size used in generateReply). */
    private static final int UDP_BUFFER_SIZE = 512;

    /** {@link #addRRset} flag: the client set the EDNS DO bit, so signatures are included. */
    static final int FLAG_DNSSECOK = 1;
    /** {@link #addRRset} flag: only signatures are wanted (the query type was SIG or RRSIG). */
    static final int FLAG_SIGONLY = 2;

    /** Caches keyed by DNS class; missing entries are created by {@link #getCache(int)}. */
    final Map<Integer, Cache> m_caches = new HashMap<Integer, Cache>();
    /** Authoritative zones keyed by their origin name. */
    final Map<Name, Zone> m_znames = new HashMap<Name, Zone>();
    /** TSIG keys keyed by key name, used to verify and sign messages. */
    final Map<Name, TSIG> m_TSIGs = new HashMap<Name, TSIG>();
    /** Ports to listen on. */
    final List<Integer> m_ports = new ArrayList<Integer>();
    /** Local addresses to listen on. */
    final List<InetAddress> m_addresses = new ArrayList<InetAddress>();
    /** Listeners started by {@link #start()} that have not yet been stopped. */
    final List<Stoppable> m_activeListeners = new ArrayList<Stoppable>();

    /**
     * Accepts TCP connections on one address/port and answers each
     * length-prefixed DNS message on a dedicated thread.
     */
    private final class TCPListener implements Stoppable {
        private final int m_port;
        private final InetAddress m_addr;
        private volatile boolean m_stopped = false;
        /** Released once {@link #run()} has finished and closed its server socket. */
        private final CountDownLatch m_latch = new CountDownLatch(1);

        private TCPListener(final int port, final InetAddress addr) {
            m_port = port;
            m_addr = addr;
        }

        /**
         * Accept loop. The server socket uses a short accept timeout so the loop
         * can notice {@link #stop()}; every accepted connection is handed to a
         * new thread running {@link #handleConnection(Socket)}.
         */
        public void run() {
            ServerSocket serverSocket = null;
            try {
                serverSocket = new ServerSocket(m_port, 128, m_addr);
                serverSocket.setSoTimeout(DEFAULT_SOCKET_TIMEOUT);
                while (!m_stopped) {
                    final Socket s;
                    try {
                        s = serverSocket.accept();
                    } catch (final SocketTimeoutException e) {
                        if (LogUtils.isTraceEnabled(this)) {
                            LogUtils.tracef(this, e, "timed out waiting for request");
                        }
                        continue;
                    }
                    final Thread t = new Thread(new Runnable() {
                        public void run() {
                            handleConnection(s);
                        }
                    });
                    t.start();
                }
            } catch (final IOException e) {
                LogUtils.warnf(this, e, "unable to serve socket on %s", addrport(m_addr, m_port));
            } finally {
                // serverSocket is null when the constructor failed (e.g. the port is in use).
                // Closing it unconditionally threw an NPE that skipped countDown(), so stop() hung.
                if (serverSocket != null) {
                    try {
                        serverSocket.close();
                    } catch (final IOException e) {
                        LogUtils.debugf(this, e, "error while closing socket");
                    }
                }
                m_latch.countDown();
            }
        }

        /**
         * Reads one DNS message framed with a two-byte length prefix from {@code s},
         * writes the reply (if any) with the same framing, and always closes the socket.
         */
        private void handleConnection(final Socket s) {
            try {
                final DataInputStream dataIn = new DataInputStream(s.getInputStream());
                final byte[] in = new byte[dataIn.readUnsignedShort()];
                dataIn.readFully(in);
                byte[] response;
                try {
                    final Message query = new Message(in);
                    LogUtils.debugf(this, "received query: %s", query);
                    response = generateReply(query, in, in.length, s);
                } catch (final IOException e) {
                    response = formerrMessage(in);
                }
                // A null response means nothing should be written (e.g. AXFR already streamed its reply).
                if (response != null) {
                    final DataOutputStream dataOut = new DataOutputStream(s.getOutputStream());
                    dataOut.writeShort(response.length);
                    dataOut.write(response);
                    dataOut.flush();
                    // Re-parsing the reply is only for diagnostics, so it happens after the write
                    // and can no longer prevent the reply from being sent.
                    LogUtils.debugf(this, "returned response: %s", new Message(response));
                }
            } catch (final SocketTimeoutException e) {
                if (LogUtils.isTraceEnabled(this)) {
                    LogUtils.tracef(this, e, "timed out waiting for request");
                }
            } catch (final IOException e) {
                LogUtils.warnf(this, e, "error while processing socket");
            } finally {
                try {
                    s.close();
                } catch (final IOException e) {
                    LogUtils.warnf(this, e, "unable to close TCP socket");
                }
            }
        }

        /** Signals the accept loop to exit and blocks until it has closed its socket. */
        public void stop() {
            m_stopped = true;
            try {
                m_latch.await();
            } catch (final InterruptedException e) {
                LogUtils.warnf(this, e, "interrupted while stopping TCP listener");
                Thread.currentThread().interrupt();
            }
        }

        @Override
        public String toString() {
            return "TCPListener[" + addrport(m_addr, m_port) + "]";
        }
    }

    /**
     * Receives DNS queries as UDP datagrams on one address/port and sends each
     * reply back to the datagram's source address and port.
     */
    private final class UDPListener implements Stoppable {
        private final int m_port;
        private final InetAddress m_addr;
        private volatile boolean m_stopped = false;
        /** Released once {@link #run()} has finished and closed its socket. */
        private final CountDownLatch m_latch = new CountDownLatch(1);

        private UDPListener(final int port, final InetAddress addr) {
            m_port = port;
            m_addr = addr;
        }

        /**
         * Receive loop. The socket uses a short receive timeout so the loop can
         * notice {@link #stop()}. A malformed packet or a failed send only affects
         * that one exchange; the listener keeps serving.
         */
        public void run() {
            DatagramSocket sock = null;
            try {
                sock = new DatagramSocket(m_port, m_addr);
                sock.setSoTimeout(DEFAULT_SOCKET_TIMEOUT);
                final byte[] in = new byte[UDP_BUFFER_SIZE];
                final DatagramPacket indp = new DatagramPacket(in, in.length);
                while (!m_stopped) {
                    indp.setLength(in.length);
                    try {
                        sock.receive(indp);
                    } catch (final InterruptedIOException e) {
                        // Receive timed out; loop around and re-check m_stopped.
                        continue;
                    }
                    // Only the bytes actually received belong to this query; the shared
                    // buffer may still hold stale bytes from an earlier, longer packet.
                    final byte[] packet = Arrays.copyOf(in, indp.getLength());
                    final byte[] response = buildResponse(packet);
                    if (response == null) {
                        continue;
                    }
                    final DatagramPacket outdp = new DatagramPacket(response, response.length, indp.getAddress(), indp.getPort());
                    try {
                        sock.send(outdp);
                    } catch (final IOException e) {
                        LogUtils.warnf(this, e, "unable to send UDP response to %s", addrport(indp.getAddress(), indp.getPort()));
                    }
                }
            } catch (final IOException e) {
                LogUtils.warnf(this, e, "error in the UDP listener: %s", addrport(m_addr, m_port));
            } finally {
                if (sock != null) {
                    try {
                        sock.close();
                    } catch (final Exception e) {
                        LogUtils.debugf(this, e, "error while closing socket");
                    }
                }
                m_latch.countDown();
            }
        }

        /**
         * Produces the wire-format reply for one received datagram, or null when
         * nothing should be sent (no reply is generated, or not even a FORMERR
         * reply can be built because the header is unreadable).
         */
        private byte[] buildResponse(final byte[] packet) {
            try {
                return generateReply(new Message(packet), packet, packet.length, null);
            } catch (final IOException e) {
                return formerrMessage(packet);
            }
        }

        /** Signals the receive loop to exit and blocks until it has closed its socket. */
        public void stop() {
            m_stopped = true;
            try {
                m_latch.await();
            } catch (final InterruptedException e) {
                LogUtils.warnf(this, e, "interrupted while waiting for server to stop");
                Thread.currentThread().interrupt();
            }
        }

        @Override
        public String toString() {
            return "UDPListener[" + addrport(m_addr, m_port) + "]";
        }
    }

    /** Formats an address/port pair as {@code address#port} for log messages. */
    private static String addrport(final InetAddress addr, final int port) {
        return InetAddressUtils.str(addr) + "#" + port;
    }

    /**
     * Creates a server configured from {@code conffile}.
     *
     * @param conffile path of a configuration file in the format read by {@link #parseConfiguration(String)}
     */
    public DNSServer(final String conffile) throws IOException, ZoneTransferException, ConfigurationException {
        parseConfiguration(conffile);
    }

    /**
     * Creates an unconfigured server. Use the add/set methods to configure it,
     * or rely on the defaults that {@link #start()} applies.
     */
    public DNSServer() throws UnknownHostException {
    }

    /**
     * Applies the defaults, then starts one UDP and one TCP listener thread for
     * every configured address/port pair.
     */
    public void start() throws UnknownHostException {
        initializeDefaults();
        for (final InetAddress addr : m_addresses) {
            for (final Integer port : m_ports) {
                startListener(new UDPListener(port, addr));
                startListener(new TCPListener(port, addr));
                LogUtils.infof(this, "listening on %s", addrport(addr, port));
            }
        }
        LogUtils.debugf(this, "finished starting up");
    }

    /** Runs {@code listener} on a new thread and records it so that {@link #stop()} can reach it. */
    private void startListener(final Stoppable listener) {
        new Thread(listener).start();
        m_activeListeners.add(listener);
    }

    /** Stops every active listener, blocking until each has shut down, then forgets them. */
    public void stop() {
        for (final Stoppable listener : m_activeListeners) {
            LogUtils.debugf(this, "stopping %s", listener);
            listener.stop();
            LogUtils.debugf(this, "stopped %s", listener);
        }
        m_activeListeners.clear();
    }

    /**
     * Reads a whitespace-separated configuration file. Blank lines and lines whose
     * first token starts with {@code #} are ignored. Recognized keywords:
     * {@code primary <zone> <file>}, {@code secondary <zone> <remote>},
     * {@code cache <file>}, {@code key [<algorithm>] <name> <secret>} (the algorithm
     * defaults to hmac-md5), {@code port <n>} and {@code address <addr>}.
     * Lines with missing arguments, and unknown keywords, are logged and skipped.
     *
     * @throws ConfigurationException if the file cannot be opened
     */
    protected void parseConfiguration(final String conffile) throws ConfigurationException, IOException,
            ZoneTransferException, UnknownHostException {
        final BufferedReader br;
        try {
            br = new BufferedReader(new InputStreamReader(new FileInputStream(conffile)));
        } catch (final Exception e) {
            LogUtils.errorf(this, e, "Cannot open %s", conffile);
            throw new ConfigurationException("unable to read from " + conffile, e);
        }
        try {
            String line;
            while ((line = br.readLine()) != null) {
                final StringTokenizer st = new StringTokenizer(line);
                if (!st.hasMoreTokens()) {
                    continue;
                }
                final String keyword = st.nextToken();
                // Comments are checked first so a one-token comment does not trigger a parse warning.
                if (keyword.charAt(0) == '#') {
                    continue;
                }
                if (!st.hasMoreTokens()) {
                    LogUtils.warnf(this, "unable to parse line: %s", line);
                    continue;
                }
                try {
                    applyConfigurationKeyword(keyword, st);
                } catch (final NoSuchElementException e) {
                    // A keyword that needs two arguments was given only one.
                    LogUtils.warnf(this, "unable to parse line: %s", line);
                }
            }
        } finally {
            br.close();
        }
    }

    /** Applies one configuration directive whose remaining arguments are in {@code st}. */
    private void applyConfigurationKeyword(final String keyword, final StringTokenizer st) throws IOException, ZoneTransferException {
        if (keyword.equals("primary")) {
            final String zone = st.nextToken();
            addPrimaryZone(zone, st.nextToken());
        } else if (keyword.equals("secondary")) {
            final String zone = st.nextToken();
            addSecondaryZone(zone, st.nextToken());
        } else if (keyword.equals("cache")) {
            m_caches.put(DClass.IN, new Cache(st.nextToken()));
        } else if (keyword.equals("key")) {
            final String s1 = st.nextToken();
            final String s2 = st.nextToken();
            if (st.hasMoreTokens()) {
                addTSIG(s1, s2, st.nextToken());
            } else {
                addTSIG("hmac-md5", s1, s2);
            }
        } else if (keyword.equals("port")) {
            m_ports.add(Integer.valueOf(st.nextToken()));
        } else if (keyword.equals("address")) {
            m_addresses.add(Address.getByAddress(st.nextToken()));
        } else {
            LogUtils.warnf(this, "unknown keyword: %s", keyword);
        }
    }

    /** Falls back to port 53 and address 0.0.0.0 when none have been configured. */
    protected void initializeDefaults() throws UnknownHostException {
        if (m_ports.isEmpty()) {
            m_ports.add(53);
        }
        if (m_addresses.isEmpty()) {
            m_addresses.add(Address.getByAddress("0.0.0.0"));
        }
    }

    /** Adds a port to listen on. */
    public void addPort(final int port) {
        m_ports.add(port);
    }

    /** Replaces the set of ports to listen on. */
    public void setPorts(final List<Integer> ports) {
        m_ports.clear();
        m_ports.addAll(ports);
    }

    /** Adds a local address to listen on. */
    public void addAddress(final InetAddress address) {
        m_addresses.add(address);
    }

    /** Replaces the set of local addresses to listen on. */
    public void setAddresses(final List<InetAddress> addresses) {
        m_addresses.clear();
        m_addresses.addAll(addresses);
    }

    /** Serves {@code zone} authoritatively, replacing any zone with the same origin. */
    public void addZone(final Zone zone) {
        m_znames.put(zone.getOrigin(), zone);
    }

    /**
     * Loads a primary zone from {@code zonefile}. When {@code zname} is null, no
     * origin is passed to the {@code Zone} constructor.
     */
    public void addPrimaryZone(final String zname, final String zonefile) throws IOException {
        Name origin = null;
        if (zname != null) {
            origin = Name.fromString(zname, Name.root);
        }
        final Zone newzone = new Zone(origin, zonefile);
        m_znames.put(newzone.getOrigin(), newzone);
    }

    /** Creates a secondary (class IN) zone for {@code zone} from {@code remote}. */
    public void addSecondaryZone(final String zone, final String remote) throws IOException, ZoneTransferException {
        final Name zname = Name.fromString(zone, Name.root);
        final Zone newzone = new Zone(zname, DClass.IN, remote);
        m_znames.put(zname, newzone);
    }

    /** Registers a TSIG key under {@code namestr} with the given algorithm and secret. */
    public void addTSIG(final String algstr, final String namestr, final String key) throws IOException {
        final Name name = Name.fromString(namestr, Name.root);
        m_TSIGs.put(name, new TSIG(algstr, namestr, key));
    }

    /** Returns the cache for {@code dclass}, creating and registering an empty one on first use. */
    public Cache getCache(final int dclass) {
        Cache c = m_caches.get(dclass);
        if (c == null) {
            c = new Cache(dclass);
            m_caches.put(dclass, c);
        }
        return c;
    }

    /**
     * Finds the configured zone for {@code name}: an exact origin match first,
     * then each name formed by stripping leading labels one at a time.
     *
     * @return the first zone found, or null if no configured zone matches
     */
    public Zone findBestZone(final Name name) {
        final Zone exact = m_znames.get(name);
        if (exact != null) {
            return exact;
        }
        final int labels = name.labels();
        for (int i = 1; i < labels; i++) {
            final Zone zone = m_znames.get(new Name(name, i));
            if (zone != null) {
                return zone;
            }
        }
        return null;
    }

    /**
     * Looks up an RRset of {@code type} for {@code name}, from the matching zone
     * if there is one, otherwise from the cache for {@code dclass}.
     *
     * @param glue when consulting the cache, also accept records of any credibility ({@code findAnyRecords})
     * @return the RRset, or null if none is found
     */
    public RRset findExactMatch(final Name name, final int type, final int dclass, final boolean glue) {
        final Zone zone = findBestZone(name);
        if (zone != null) {
            return zone.findExactMatch(name, type);
        }
        final Cache cache = getCache(dclass);
        final RRset[] rrsets = glue ? cache.findAnyRecords(name, type) : cache.findRecords(name, type);
        if (rrsets == null || rrsets.length == 0) {
            return null;
        }
        return rrsets[0]; /* not quite right */
    }

    /**
     * Copies {@code rrset} into {@code section} of {@code response}, unless an
     * RRset of the same name and type is already present in sections 1 through
     * {@code section}. Records are added unless FLAG_SIGONLY is set; signatures
     * are added when FLAG_SIGONLY or FLAG_DNSSECOK is set. Wildcard-owned records
     * are renamed to {@code name}.
     */
    void addRRset(final Name name, final Message response, final RRset rrset, final int section, final int flags) {
        for (int s = 1; s <= section; s++) {
            if (response.findRRset(name, rrset.getType(), s)) {
                return;
            }
        }
        if ((flags & FLAG_SIGONLY) == 0) {
            @SuppressWarnings("unchecked")
            final Iterator<Record> records = rrset.rrs();
            addRecords(name, response, records, section);
        }
        if ((flags & (FLAG_SIGONLY | FLAG_DNSSECOK)) != 0) {
            @SuppressWarnings("unchecked")
            final Iterator<Record> sigs = rrset.sigs();
            addRecords(name, response, sigs, section);
        }
    }

    /** Adds every record to {@code section}, replacing a wildcard owner with {@code name} unless {@code name} is itself a wildcard. */
    private static void addRecords(final Name name, final Message response, final Iterator<Record> records, final int section) {
        while (records.hasNext()) {
            final Record r = records.next();
            if (r.getName().isWild() && !name.isWild()) {
                response.addRecord(r.withName(name), section);
            } else {
                response.addRecord(r, section);
            }
        }
    }

    /** Adds the zone's SOA record to the AUTHORITY section. */
    private void addSOA(final Message response, final Zone zone) {
        response.addRecord(zone.getSOA(), Section.AUTHORITY);
    }

    /** Adds the zone's NS RRset to the AUTHORITY section. */
    private void addNS(final Message response, final Zone zone, final int flags) {
        final RRset nsRecords = zone.getNS();
        addRRset(nsRecords.getName(), response, nsRecords, Section.AUTHORITY, flags);
    }

    /** If the cache holds a delegation for {@code name}, copies its NS records into AUTHORITY. */
    private void addCacheNS(final Message response, final Cache cache, final Name name) {
        final SetResponse sr = cache.lookupRecords(name, Type.NS, Credibility.HINT);
        if (!sr.isDelegation()) {
            return;
        }
        final RRset nsRecords = sr.getNS();
        @SuppressWarnings("unchecked")
        final Iterator<Record> it = nsRecords.rrs();
        while (it.hasNext()) {
            response.addRecord(it.next(), Section.AUTHORITY);
        }
    }

    /** Adds class-IN A records for {@code name} (if any are known) to the ADDITIONAL section. */
    private void addGlue(final Message response, final Name name, final int flags) {
        final RRset a = findExactMatch(name, Type.A, DClass.IN, true);
        if (a == null) {
            return;
        }
        addRRset(name, response, a, Section.ADDITIONAL, flags);
    }

    /** Adds glue for the additional name referenced by each record in {@code section}. */
    private void addAdditional2(final Message response, final int section, final int flags) {
        for (final Record r : response.getSectionArray(section)) {
            final Name glueName = r.getAdditionalName();
            if (glueName != null) {
                addGlue(response, glueName, flags);
            }
        }
    }

    /** Adds glue for names referenced from the ANSWER and AUTHORITY sections. */
    private void addAdditional(final Message response, final int flags) {
        addAdditional2(response, Section.ANSWER, flags);
        addAdditional2(response, Section.AUTHORITY, flags);
    }

    /**
     * Answers {@code name}/{@code type}/{@code dclass} into {@code response} from
     * the matching zone, or from the cache if no zone matches. CNAME and DNAME
     * targets are followed recursively; once {@code iterations} exceeds 6 the
     * chain is abandoned with NOERROR. The AA flag is set only for zone data
     * found at iteration 0. SIG/RRSIG queries are served as ANY with FLAG_SIGONLY.
     *
     * @return the rcode for the response: NOERROR, NXDOMAIN, or YXDOMAIN when a
     *         DNAME substitution would produce a name that is too long
     */
    byte addAnswer(final Message response, final Name name, int type, int dclass, int iterations, int flags) {
        if (iterations > 6) {
            return Rcode.NOERROR;
        }
        if (type == Type.SIG || type == Type.RRSIG) {
            type = Type.ANY;
            flags |= FLAG_SIGONLY;
        }
        byte rcode = Rcode.NOERROR;
        final Zone zone = findBestZone(name);
        final SetResponse sr;
        if (zone != null) {
            sr = zone.findRecords(name, type);
        } else {
            sr = getCache(dclass).lookupRecords(name, type, Credibility.NORMAL);
        }
        if (sr.isUnknown()) {
            addCacheNS(response, getCache(dclass), name);
        }
        if (sr.isNXDOMAIN()) {
            response.getHeader().setRcode(Rcode.NXDOMAIN);
            if (zone != null) {
                addSOA(response, zone);
                if (iterations == 0) {
                    response.getHeader().setFlag(Flags.AA);
                }
            }
            rcode = Rcode.NXDOMAIN;
        } else if (sr.isNXRRSET()) {
            if (zone != null) {
                addSOA(response, zone);
                if (iterations == 0) {
                    response.getHeader().setFlag(Flags.AA);
                }
            }
        } else if (sr.isDelegation()) {
            final RRset nsRecords = sr.getNS();
            addRRset(nsRecords.getName(), response, nsRecords, Section.AUTHORITY, flags);
        } else if (sr.isCNAME()) {
            final CNAMERecord cname = sr.getCNAME();
            addRRset(name, response, new RRset(cname), Section.ANSWER, flags);
            if (zone != null && iterations == 0) {
                response.getHeader().setFlag(Flags.AA);
            }
            rcode = addAnswer(response, cname.getTarget(), type, dclass, iterations + 1, flags);
        } else if (sr.isDNAME()) {
            final DNAMERecord dname = sr.getDNAME();
            addRRset(name, response, new RRset(dname), Section.ANSWER, flags);
            final Name newname;
            try {
                newname = name.fromDNAME(dname);
            } catch (final NameTooLongException e) {
                return Rcode.YXDOMAIN;
            }
            // Synthesize the CNAME implied by the DNAME substitution.
            addRRset(name, response, new RRset(new CNAMERecord(name, dclass, 0, newname)), Section.ANSWER, flags);
            if (zone != null && iterations == 0) {
                response.getHeader().setFlag(Flags.AA);
            }
            rcode = addAnswer(response, newname, type, dclass, iterations + 1, flags);
        } else if (sr.isSuccessful()) {
            for (final RRset rrset : sr.answers()) {
                addRRset(name, response, rrset, Section.ANSWER, flags);
            }
            if (zone != null) {
                addNS(response, zone, flags);
                if (iterations == 0) {
                    response.getHeader().setFlag(Flags.AA);
                }
            } else {
                addCacheNS(response, getCache(dclass), name);
            }
        }
        return rcode;
    }

    /**
     * Streams a zone transfer of the zone whose origin is exactly {@code name}
     * over {@code s}: one length-prefixed message per RRset, TSIG-signed when
     * {@code tsig} is non-null. Closes {@code s} when finished.
     *
     * @return a REFUSED error message if no such zone exists (the socket is left
     *         open for the caller to reply), otherwise null
     */
    byte[] doAXFR(final Name name, final Message query, final TSIG tsig, TSIGRecord qtsig, final Socket s) {
        final Zone zone = m_znames.get(name);
        if (zone == null) {
            return errorMessage(query, Rcode.REFUSED);
        }
        @SuppressWarnings("unchecked")
        final Iterator<RRset> it = zone.AXFR();
        boolean first = true;
        try {
            final DataOutputStream dataOut = new DataOutputStream(s.getOutputStream());
            final int id = query.getHeader().getID();
            while (it.hasNext()) {
                final RRset rrset = it.next();
                final Message response = new Message(id);
                final Header header = response.getHeader();
                header.setFlag(Flags.QR);
                header.setFlag(Flags.AA);
                addRRset(rrset.getName(), response, rrset, Section.ANSWER, FLAG_DNSSECOK);
                if (tsig != null) {
                    tsig.applyStream(response, qtsig, first);
                    qtsig = response.getTSIG();
                }
                first = false;
                final byte[] out = response.toWire();
                dataOut.writeShort(out.length);
                dataOut.write(out);
            }
        } catch (final IOException ex) {
            LogUtils.warnf(this, ex, "AXFR failed");
        } finally {
            try {
                s.close();
            } catch (final IOException ex) {
                LogUtils.warnf(this, ex, "error closing socket");
            }
        }
        return null;
    }

    /**
     * Builds the wire-format reply to {@code query}.
     * <p>
     * Note: a null return value means that the caller doesn't need to do
     * anything. This happens when {@code query} is itself a response (QR set),
     * or for an AXFR/IXFR request over TCP, whose reply {@link #doAXFR} writes
     * directly to the socket.
     *
     * @param in     raw bytes of the query (used for TSIG verification and FORMERR replies)
     * @param length number of valid bytes in {@code in}
     * @param s      the TCP socket the query arrived on, or null for UDP
     */
    byte[] generateReply(final Message query, final byte[] in, final int length, final Socket s) throws IOException {
        final Header header = query.getHeader();
        if (header.getFlag(Flags.QR)) {
            return null;
        }
        if (header.getRcode() != Rcode.NOERROR) {
            return errorMessage(query, Rcode.FORMERR);
        }
        if (header.getOpcode() != Opcode.QUERY) {
            return errorMessage(query, Rcode.NOTIMP);
        }
        final Record queryRecord = query.getQuestion();
        if (queryRecord == null) {
            // A query without a question is malformed; previously this caused an NPE
            // that terminated the UDP listener thread.
            return errorMessage(query, Rcode.FORMERR);
        }
        final TSIGRecord queryTSIG = query.getTSIG();
        TSIG tsig = null;
        if (queryTSIG != null) {
            tsig = m_TSIGs.get(queryTSIG.getName());
            if (tsig == null || tsig.verify(query, in, length, null) != Rcode.NOERROR) {
                return formerrMessage(in);
            }
        }
        final OPTRecord queryOPT = query.getOPT();
        // TCP replies may use the full 16-bit length; UDP honors the EDNS payload size (at least 512).
        final int maxLength;
        if (s != null) {
            maxLength = 65535;
        } else if (queryOPT != null) {
            maxLength = Math.max(queryOPT.getPayloadSize(), 512);
        } else {
            maxLength = 512;
        }
        final int flags = (queryOPT != null && (queryOPT.getFlags() & ExtendedFlags.DO) != 0) ? FLAG_DNSSECOK : 0;

        final Message response = new Message(header.getID());
        response.getHeader().setFlag(Flags.QR);
        if (header.getFlag(Flags.RD)) {
            response.getHeader().setFlag(Flags.RD);
        }
        response.addRecord(queryRecord, Section.QUESTION);

        final Name name = queryRecord.getName();
        final int type = queryRecord.getType();
        final int dclass = queryRecord.getDClass();
        if ((type == Type.AXFR || type == Type.IXFR) && s != null) {
            return doAXFR(name, query, tsig, queryTSIG, s);
        }
        if (!Type.isRR(type) && type != Type.ANY) {
            return errorMessage(query, Rcode.NOTIMP);
        }
        final byte rcode = addAnswer(response, name, type, dclass, 0, flags);
        if (rcode != Rcode.NOERROR && rcode != Rcode.NXDOMAIN) {
            return errorMessage(query, rcode);
        }
        addAdditional(response, flags);
        if (queryOPT != null) {
            final int optflags = (flags == FLAG_DNSSECOK) ? ExtendedFlags.DO : 0;
            final OPTRecord opt = new OPTRecord((short) 4096, rcode, (byte) 0, optflags);
            response.addRecord(opt, Section.ADDITIONAL);
        }
        response.setTSIG(tsig, Rcode.NOERROR, queryTSIG);
        return response.toWire(maxLength);
    }

    /**
     * Builds an error reply from {@code header} with the given {@code rcode}.
     * All sections are emptied; only a SERVFAIL reply echoes {@code question}.
     * Note that {@code header} itself is modified.
     */
    byte[] buildErrorMessage(final Header header, final int rcode, final Record question) {
        final Message response = new Message();
        response.setHeader(header);
        for (int i = 0; i < 4; i++) {
            response.removeAllRecords(i);
        }
        if (rcode == Rcode.SERVFAIL && question != null) {
            response.addRecord(question, Section.QUESTION);
        }
        header.setRcode(rcode);
        // The header is taken from the query, so it may lack QR; every reply must be marked as a response.
        header.setFlag(Flags.QR);
        return response.toWire();
    }

    /**
     * Builds a FORMERR reply using the header parsed from {@code in}.
     *
     * @return the reply, or null if not even the header can be parsed
     */
    public byte[] formerrMessage(final byte[] in) {
        try {
            return buildErrorMessage(new Header(in), Rcode.FORMERR, null);
        } catch (final IOException e) {
            LogUtils.debugf(this, e, "unable to build error message");
            return null;
        }
    }

    /** Builds an error reply to {@code query} with the given {@code rcode}. */
    public byte[] errorMessage(final Message query, final int rcode) {
        return buildErrorMessage(query.getHeader(), rcode, query.getQuestion());
    }
}