/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package libthrift091; import org.apache.log4j.Logger; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.nio.channels.SocketChannel; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.FutureTask; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; /** * This class uses a single thread to set up non-blocking sockets to a set * of remote servers (hostname and port pairs), and sends a same request to * all these servers. It then fetches responses from servers. * * Parameters: * int maxRecvBufBytesPerServer - an upper limit for receive buffer size * per server (in byte). If a response from a server exceeds this limit, the * client will not allocate memory or read response data for it. * * int fetchTimeoutSeconds - time limit for fetching responses from all * servers (in second). After the timeout, the fetch job is stopped and * available responses are returned. * * ByteBuffer requestBuf - request message that is sent to all servers. * * Output: * Responses are stored in an array of ByteBuffers. Index of elements in * this array corresponds to index of servers in the server list. Content in * a ByteBuffer may be in one of the following forms: * 1. First 4 bytes form an integer indicating length of following data, * then followed by the data. * 2. First 4 bytes form an integer indicating length of following data, * then followed by nothing - this happens when the response data size * exceeds maxRecvBufBytesPerServer, and the client will not read any * response data. * 3. No data in the ByteBuffer - this happens when the server does not * return any response within fetchTimeoutSeconds. * * In some special cases (no servers are given, fetchTimeoutSeconds less * than or equal to 0, requestBuf is null), the return is null. * * Note: * It assumes all remote servers are TNonblockingServers and use * TFramedTransport. * */ public class TNonblockingMultiFetchClient { private static final Logger LOG = Logger.getLogger( TNonblockingMultiFetchClient.class); // if the size of the response msg exceeds this limit (in byte), we will // not read the msg private int maxRecvBufBytesPerServer; // time limit for fetching data from all servers (in second) private int fetchTimeoutSeconds; // store request that will be sent to servers private ByteBuffer requestBuf; private ByteBuffer requestBufDuplication; // a list of remote servers private List<InetSocketAddress> servers; // store fetch results private TNonblockingMultiFetchStats stats; private ByteBuffer[] recvBuf; public TNonblockingMultiFetchClient(int maxRecvBufBytesPerServer, int fetchTimeoutSeconds, ByteBuffer requestBuf, List<InetSocketAddress> servers) { this.maxRecvBufBytesPerServer = maxRecvBufBytesPerServer; this.fetchTimeoutSeconds = fetchTimeoutSeconds; this.requestBuf = requestBuf; this.servers = servers; stats = new TNonblockingMultiFetchStats(); recvBuf = null; } public synchronized int getMaxRecvBufBytesPerServer() { return maxRecvBufBytesPerServer; } public synchronized int getFetchTimeoutSeconds() { return fetchTimeoutSeconds; } /** * return a duplication of requestBuf, so that requestBuf will not * be modified by others. */ public synchronized ByteBuffer getRequestBuf() { if (requestBuf == null) { return null; } else { if (requestBufDuplication == null) { requestBufDuplication = requestBuf.duplicate(); } return requestBufDuplication; } } public synchronized List<InetSocketAddress> getServerList() { if (servers == null) { return null; } return Collections.unmodifiableList(servers); } public synchronized TNonblockingMultiFetchStats getFetchStats() { return stats; } /** * main entry function for fetching from servers */ public synchronized ByteBuffer[] fetch() { // clear previous results recvBuf = null; stats.clear(); if (servers == null || servers.size() == 0 || requestBuf == null || fetchTimeoutSeconds <= 0) { return recvBuf; } ExecutorService executor = Executors.newSingleThreadExecutor(); MultiFetch multiFetch = new MultiFetch(); FutureTask<?> task = new FutureTask(multiFetch, null); executor.execute(task); try { task.get(fetchTimeoutSeconds, TimeUnit.SECONDS); } catch(InterruptedException ie) { // attempt to cancel execution of the task. task.cancel(true); LOG.error("interrupted during fetch: "+ie.toString()); } catch(ExecutionException ee) { // attempt to cancel execution of the task. task.cancel(true); LOG.error("exception during fetch: "+ee.toString()); } catch(TimeoutException te) { // attempt to cancel execution of the task. task.cancel(true); LOG.error("timeout for fetch: "+te.toString()); } executor.shutdownNow(); multiFetch.close(); return recvBuf; } /** * Private class that does real fetch job. * Users are not allowed to directly use this class, as its run() * function may run forever. */ private class MultiFetch implements Runnable { private Selector selector; /** * main entry function for fetching. * * Server responses are stored in TNonblocingMultiFetchClient.recvBuf, * and fetch statistics is in TNonblockingMultiFetchClient.stats. * * Sanity check for parameters has been done in * TNonblockingMultiFetchClient before calling this function. */ public void run() { long t1 = System.currentTimeMillis(); int numTotalServers = servers.size(); stats.setNumTotalServers(numTotalServers); // buffer for receiving response from servers recvBuf = new ByteBuffer[numTotalServers]; // buffer for sending request ByteBuffer sendBuf[] = new ByteBuffer[numTotalServers]; long numBytesRead[] = new long[numTotalServers]; int frameSize[] = new int[numTotalServers]; boolean hasReadFrameSize[] = new boolean[numTotalServers]; try { selector = Selector.open(); } catch (IOException e) { LOG.error("selector opens error: "+e.toString()); return; } for (int i = 0; i < numTotalServers; i++) { // create buffer to send request to server. sendBuf[i] = requestBuf.duplicate(); // create buffer to read response's frame size from server recvBuf[i] = ByteBuffer.allocate(4); stats.incTotalRecvBufBytes(4); InetSocketAddress server = servers.get(i); SocketChannel s = null; SelectionKey key = null; try { s = SocketChannel.open(); s.configureBlocking(false); // now this method is non-blocking s.connect(server); key = s.register(selector, s.validOps()); // attach index of the key key.attach(i); } catch (Exception e) { stats.incNumConnectErrorServers(); String err = String.format("set up socket to server %s error: %s", server.toString(), e.toString()); LOG.error(err); // free resource if (s != null) { try {s.close();} catch (Exception ex) {} } if (key != null) { key.cancel(); } } } // wait for events while (stats.getNumReadCompletedServers() + stats.getNumConnectErrorServers() < stats.getNumTotalServers()) { // if the thread is interrupted (e.g., task is cancelled) if (Thread.currentThread().isInterrupted()) { return; } try{ selector.select(); } catch (Exception e) { LOG.error("selector selects error: "+e.toString()); continue; } Iterator<SelectionKey> it = selector.selectedKeys().iterator(); while (it.hasNext()) { SelectionKey selKey = it.next(); it.remove(); // get previously attached index int index = (Integer)selKey.attachment(); if (selKey.isValid() && selKey.isConnectable()) { // if this socket throws an exception (e.g., connection refused), // print error msg and skip it. try { SocketChannel sChannel = (SocketChannel)selKey.channel(); sChannel.finishConnect(); } catch (Exception e) { stats.incNumConnectErrorServers(); String err = String.format("socket %d connects to server %s " + "error: %s", index, servers.get(index).toString(), e.toString()); LOG.error(err); } } if (selKey.isValid() && selKey.isWritable()) { if (sendBuf[index].hasRemaining()) { // if this socket throws an exception, print error msg and // skip it. try { SocketChannel sChannel = (SocketChannel)selKey.channel(); sChannel.write(sendBuf[index]); } catch (Exception e) { String err = String.format("socket %d writes to server %s " + "error: %s", index, servers.get(index).toString(), e.toString()); LOG.error(err); } } } if (selKey.isValid() && selKey.isReadable()) { // if this socket throws an exception, print error msg and // skip it. try { SocketChannel sChannel = (SocketChannel)selKey.channel(); int bytesRead = sChannel.read(recvBuf[index]); if (bytesRead > 0) { numBytesRead[index] += bytesRead; if (!hasReadFrameSize[index] && recvBuf[index].remaining()==0) { // if the frame size has been read completely, then prepare // to read the actual frame. frameSize[index] = recvBuf[index].getInt(0); if (frameSize[index] <= 0) { stats.incNumInvalidFrameSize(); String err = String.format("Read an invalid frame size %d" + " from %s. Does the server use TFramedTransport? ", frameSize[index], servers.get(index).toString()); LOG.error(err); sChannel.close(); continue; } if (frameSize[index] + 4 > stats.getMaxResponseBytes()) { stats.setMaxResponseBytes(frameSize[index]+4); } if (frameSize[index] + 4 > maxRecvBufBytesPerServer) { stats.incNumOverflowedRecvBuf(); String err = String.format("Read frame size %d from %s," + " total buffer size would exceed limit %d", frameSize[index], servers.get(index).toString(), maxRecvBufBytesPerServer); LOG.error(err); sChannel.close(); continue; } // reallocate buffer for actual frame data recvBuf[index] = ByteBuffer.allocate(frameSize[index] + 4); recvBuf[index].putInt(frameSize[index]); stats.incTotalRecvBufBytes(frameSize[index]); hasReadFrameSize[index] = true; } if (hasReadFrameSize[index] && numBytesRead[index] >= frameSize[index]+4) { // has read all data sChannel.close(); stats.incNumReadCompletedServers(); long t2 = System.currentTimeMillis(); stats.setReadTime(t2-t1); } } } catch (Exception e) { String err = String.format("socket %d reads from server %s " + "error: %s", index, servers.get(index).toString(), e.toString()); LOG.error(err); } } } } } /** * dispose any resource allocated */ public void close() { try { if (selector.isOpen()) { Iterator<SelectionKey> it = selector.keys().iterator(); while (it.hasNext()) { SelectionKey selKey = it.next(); SocketChannel sChannel = (SocketChannel)selKey.channel(); sChannel.close(); } selector.close(); } } catch (IOException e) { LOG.error("free resource error: "+e.toString()); } } } }