// Copyright (c) 2005 Brian Wellington (bwelling@xbill.org) package io.milton.dns.record; import java.io.*; import java.net.*; import java.nio.*; import java.nio.channels.*; final class TCPClient extends Client { public TCPClient(long endTime) throws IOException { super(SocketChannel.open(), endTime); } void bind(SocketAddress addr) throws IOException { SocketChannel channel = (SocketChannel) key.channel(); channel.socket().bind(addr); } void connect(SocketAddress addr) throws IOException { SocketChannel channel = (SocketChannel) key.channel(); if (channel.connect(addr)) return; key.interestOps(SelectionKey.OP_CONNECT); try { while (!channel.finishConnect()) { if (!key.isConnectable()) blockUntil(key, endTime); } } finally { if (key.isValid()) key.interestOps(0); } } void send(byte [] data) throws IOException { SocketChannel channel = (SocketChannel) key.channel(); verboseLog("TCP write", data); byte [] lengthArray = new byte[2]; lengthArray[0] = (byte)(data.length >>> 8); lengthArray[1] = (byte)(data.length & 0xFF); ByteBuffer [] buffers = new ByteBuffer[2]; buffers[0] = ByteBuffer.wrap(lengthArray); buffers[1] = ByteBuffer.wrap(data); int nsent = 0; key.interestOps(SelectionKey.OP_WRITE); try { while (nsent < data.length + 2) { if (key.isWritable()) { long n = channel.write(buffers); if (n < 0) throw new EOFException(); nsent += (int) n; if (nsent < data.length + 2 && System.currentTimeMillis() > endTime) throw new SocketTimeoutException(); } else blockUntil(key, endTime); } } finally { if (key.isValid()) key.interestOps(0); } } private byte [] _recv(int length) throws IOException { SocketChannel channel = (SocketChannel) key.channel(); int nrecvd = 0; byte [] data = new byte[length]; ByteBuffer buffer = ByteBuffer.wrap(data); key.interestOps(SelectionKey.OP_READ); try { while (nrecvd < length) { if (key.isReadable()) { long n = channel.read(buffer); if (n < 0) throw new EOFException(); nrecvd += (int) n; if (nrecvd < length && System.currentTimeMillis() > endTime) throw new SocketTimeoutException(); } else blockUntil(key, endTime); } } finally { if (key.isValid()) key.interestOps(0); } return data; } byte [] recv() throws IOException { byte [] buf = _recv(2); int length = ((buf[0] & 0xFF) << 8) + (buf[1] & 0xFF); byte [] data = _recv(length); verboseLog("TCP read", data); return data; } static byte [] sendrecv(SocketAddress local, SocketAddress remote, byte [] data, long endTime) throws IOException { TCPClient client = new TCPClient(endTime); try { if (local != null) client.bind(local); client.connect(remote); client.send(data); return client.recv(); } finally { client.cleanup(); } } static byte [] sendrecv(SocketAddress addr, byte [] data, long endTime) throws IOException { return sendrecv(null, addr, data, endTime); } }