/* This file is part of VoltDB. * Copyright (C) 2008-2010 VoltDB L.L.C. * * VoltDB is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * VoltDB is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with VoltDB. If not, see <http://www.gnu.org/licenses/>. */ package org.voltdb.client; import java.io.EOFException; import java.io.IOException; import java.net.Inet4Address; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.NetworkInterface; import java.net.SocketException; import java.net.UnknownHostException; import java.nio.ByteBuffer; import java.nio.channels.SocketChannel; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.Enumeration; import java.util.HashMap; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import org.apache.log4j.Logger; import org.voltdb.ClientResponseImpl; import org.voltdb.StoredProcedureInvocation; import org.voltdb.messaging.FastDeserializer; import org.voltdb.messaging.FastSerializer; import org.voltdb.utils.DBBPool.BBContainer; /** * A utility class for opening a connection to a Volt server and authenticating as well * as sending invocations and receiving responses. It is safe to queue multiple requests * @author aweisberg * */ public class ConnectionUtil { private static final Logger LOG = Logger.getLogger(ConnectionUtil.class); private static class TF implements ThreadFactory { @Override public Thread newThread(Runnable r) { return new Thread(null, r, "Yet another thread", 65536); } } private static final TF m_tf = new TF(); public static class ExecutorPair { public final ExecutorService m_writeExecutor; public final ExecutorService m_readExecutor; public ExecutorPair() { m_writeExecutor = Executors.newSingleThreadExecutor(m_tf); m_readExecutor = Executors.newSingleThreadExecutor(m_tf); } private void shutdown() throws InterruptedException { m_readExecutor.shutdownNow(); m_writeExecutor.shutdownNow(); m_readExecutor.awaitTermination(1, TimeUnit.DAYS); m_writeExecutor.awaitTermination(1, TimeUnit.DAYS); } } private static final HashMap<SocketChannel, ExecutorPair> m_executors = new HashMap<SocketChannel, ExecutorPair>(); private static final AtomicLong m_handle = new AtomicLong(Long.MIN_VALUE); /** * Create a connection to a Volt server and authenticate the connection. * @param host * @param username * @param password * @param port * @throws IOException * @returns An array of objects. The first is an * authenticated socket channel, the second. is an array of 4 longs - * Integer hostId, Long connectionId, Long timestamp (part of instanceId), Int leaderAddress (part of instanceId). * The last object is the build string */ public static Object[] getAuthenticatedConnection( String host, String username, String password, int port) throws IOException { return getAuthenticatedConnection("database", host, username, password, port); } /** * Create a connection to a Volt server for export and authenticate the connection. * @param host * @param username * @param password * @param port * @throws IOException * @returns An array of objects. The first is an * authenticated socket channel, the second. is an array of 4 longs - * Integer hostId, Long connectionId, Long timestamp (part of instanceId), Int leaderAddress (part of instanceId). * The last object is the build string */ public static Object[] getAuthenticatedExportConnection( String host, String username, String password, int port) throws IOException { return getAuthenticatedConnection("export", host, username, password, port); } private static Object[] getAuthenticatedConnection( String service, String host, String username, String password, int port) throws IOException { LOG.debug("Ok, so now we're looking for an authenticated connection"); LOG.debug("[service=" + service + ", host=" + host + ", user=" + username + ", pass=" + password + ", port=" + port + "]"); Object returnArray[] = new Object[3]; boolean success = false; InetSocketAddress addr = new InetSocketAddress(host, port); SocketChannel aChannel = SocketChannel.open(addr); returnArray[0] = aChannel; assert(aChannel.isConnected()); if (!aChannel.isConnected()) { // TODO Can open() be asynchronous if configureBlocking(true)? throw new IOException("Failed to open host " + host); } final long retvals[] = new long[4]; returnArray[1] = retvals; try { /* * Send login info */ aChannel.configureBlocking(true); aChannel.socket().setTcpNoDelay(true); MessageDigest md = null; try { md = MessageDigest.getInstance("SHA-1"); } catch (NoSuchAlgorithmException e) { e.printStackTrace(); System.exit(-1); } byte passwordHash[] = md.digest(password.getBytes()); FastSerializer fs = new FastSerializer(); fs.writeInt(0); // placeholder for length fs.writeByte(0); // version fs.writeString(service); // data service (export|database) fs.writeString(username); fs.write(passwordHash); final ByteBuffer fsBuffer = fs.getBuffer(); final ByteBuffer b = ByteBuffer.allocate(fsBuffer.remaining()); b.put(fsBuffer); final int size = fsBuffer.limit() - 4; b.flip(); b.putInt(size); b.position(0); boolean successfulWrite = false; IOException writeException = null; try { for (int ii = 0; ii < 4 && b.hasRemaining(); ii++) { aChannel.write(b); } if (!b.hasRemaining()) { successfulWrite = true; } } catch (IOException e) { writeException = e; } ByteBuffer lengthBuffer = ByteBuffer.allocate(4); int read = aChannel.read(lengthBuffer); if (read == -1) { if (writeException != null) { throw writeException; } if (!successfulWrite) { throw new IOException("Unable to write authentication info to serer"); } throw new IOException("Authentication rejected"); } else { lengthBuffer.flip(); } ByteBuffer loginResponse = ByteBuffer.allocate(lengthBuffer.getInt());//Read version and length etc. read = aChannel.read(loginResponse); byte loginResponseCode = 0; if (read == -1) { if (writeException != null) { throw writeException; } if (!successfulWrite) { throw new IOException("Unable to write authentication info to serer"); } throw new IOException("Authentication rejected"); } else { loginResponse.flip(); loginResponse.position(1); loginResponseCode = loginResponse.get(); } if (loginResponseCode != 0) { aChannel.close(); switch (loginResponseCode) { case 1: throw new IOException("Server has too many connections"); case 2: throw new IOException("Connection timed out during authentication. " + "Buy a faster computer and stop using VMWare"); default: throw new IOException("Authentication rejected"); } } retvals[0] = loginResponse.getInt(); retvals[1] = loginResponse.getLong(); retvals[2] = loginResponse.getLong(); retvals[3] = loginResponse.getInt(); int buildStringLength = loginResponse.getInt(); byte buildStringBytes[] = new byte[buildStringLength]; loginResponse.get(buildStringBytes); returnArray[2] = new String(buildStringBytes, "UTF-8"); aChannel.configureBlocking(false); aChannel.socket().setTcpNoDelay(false); aChannel.socket().setKeepAlive(true); success = true; } finally { if (!success) { aChannel.close(); } } return returnArray; } public static void closeConnection(SocketChannel connection) throws InterruptedException, IOException { synchronized (m_executors) { ExecutorPair p = m_executors.remove(connection); assert(p != null); p.shutdown(); } connection.close(); } private static ExecutorPair getExecutorPair(final SocketChannel channel) { synchronized (m_executors) { ExecutorPair p = m_executors.get(channel); if (p == null) { p = new ExecutorPair(); m_executors.put( channel, p); } return p; } } public static Future<Long> sendInvocation(final SocketChannel channel, final String procName,final Object ...parameters) { final ExecutorPair p = getExecutorPair(channel); return sendInvocation(p.m_writeExecutor, channel, procName, parameters); } public static Future<Long> sendInvocation(final ExecutorService executor, final SocketChannel channel, final String procName,final Object ...parameters) { return executor.submit(new Callable<Long>() { @Override public Long call() throws Exception { final long handle = m_handle.getAndIncrement(); final StoredProcedureInvocation invocation = new StoredProcedureInvocation(handle, procName, -1, parameters); final FastSerializer fs = new FastSerializer(); final BBContainer c = fs.writeObjectForMessaging(invocation); do { channel.write(c.b); if (c.b.hasRemaining()) { Thread.yield(); } } while(c.b.hasRemaining()); c.discard(); return handle; } }); } public static Future<ClientResponse> readResponse(final SocketChannel channel) { final ExecutorPair p = getExecutorPair(channel); return readResponse(p.m_readExecutor, channel); } public static Future<ClientResponse> readResponse(final ExecutorService executor, final SocketChannel channel) { return executor.submit(new Callable<ClientResponse>() { @Override public ClientResponse call() throws Exception { ByteBuffer lengthBuffer = ByteBuffer.allocate(4); do { final int read = channel.read(lengthBuffer); if (read == -1) { throw new EOFException(); } if (lengthBuffer.hasRemaining()) { Thread.yield(); } } while (lengthBuffer.hasRemaining()); lengthBuffer.flip(); ByteBuffer message = ByteBuffer.allocate(lengthBuffer.getInt()); do { final int read = channel.read(message); if (read == -1) { throw new EOFException(); } if (lengthBuffer.hasRemaining()) { Thread.yield(); } } while (message.hasRemaining()); message.flip(); FastDeserializer fds = new FastDeserializer(message); ClientResponseImpl response = fds.readObject(ClientResponseImpl.class); return response; } }); } public static String getHostnameOrAddress() { try { final InetAddress addr = InetAddress.getLocalHost(); return addr.getHostName(); } catch (UnknownHostException e) { try { Enumeration<NetworkInterface> interfaces = NetworkInterface.getNetworkInterfaces(); if (interfaces == null) { return ""; } NetworkInterface intf = interfaces.nextElement(); Enumeration<InetAddress> addresses = intf.getInetAddresses(); while (addresses.hasMoreElements()) { InetAddress address = addresses.nextElement(); if (address instanceof Inet4Address) { return address.getHostAddress(); } } interfaces = NetworkInterface.getNetworkInterfaces(); while (addresses.hasMoreElements()) { return addresses.nextElement().getHostAddress(); } return ""; } catch (SocketException e1) { return ""; } } } public static InetAddress getLocalAddress() { try { final InetAddress addr = InetAddress.getLocalHost(); return addr; } catch (UnknownHostException e) { try { Enumeration<NetworkInterface> interfaces = NetworkInterface.getNetworkInterfaces(); if (interfaces == null) { return null; } NetworkInterface intf = interfaces.nextElement(); Enumeration<InetAddress> addresses = intf.getInetAddresses(); while (addresses.hasMoreElements()) { InetAddress address = addresses.nextElement(); if (address instanceof Inet4Address) { return address; } } interfaces = NetworkInterface.getNetworkInterfaces(); while (addresses.hasMoreElements()) { return addresses.nextElement(); } return null; } catch (SocketException e1) { return null; } } } }