/* Copyright (c) 2008, 2012, Oracle and/or its affiliates. All rights reserved. The MySQL Connector/J is licensed under the terms of the GPLv2 <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most MySQL Connectors. There are special exceptions to the terms and conditions of the GPLv2 as it is applied to this software, see the FLOSS License Exception <http://www.mysql.com/about/legal/licensing/foss-exception.html>. This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; version 2 of the License. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA */ package testsuite; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.InetAddress; import java.net.Socket; import java.net.SocketAddress; import java.net.SocketException; import java.net.SocketTimeoutException; import java.nio.channels.SocketChannel; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Properties; import java.util.Set; import com.mysql.jdbc.NonRegisteringDriver; import com.mysql.jdbc.SocketFactory; import com.mysql.jdbc.StandardSocketFactory; /** * Configure "socketFactory" to use this class in your JDBC URL, and it will operate * as normal, unless you map some host aliases to actual IP addresses, and then have * the test driver call hangOnConnect/Read/Write() which simulate the given failure * condition for the host with the <b>alias</b> argument, and will honor connect or * socket timeout properties. * * You can also cause a host to be immediately-downed by calling downHost() with an alias. */ public class UnreliableSocketFactory extends StandardSocketFactory { public static final long DEFAULT_TIMEOUT_MILLIS = 10 * 60 * 1000; // ugh private static final Map<String, String> MAPPED_HOSTS = new HashMap<String, String>(); static final Set<String> HUNG_READ_HOSTS = new HashSet<String>(); static final Set<String> HUNG_WRITE_HOSTS = new HashSet<String>(); static final Set<String> HUNG_CONNECT_HOSTS = new HashSet<String>(); static final Set<String> IMMEDIATELY_DOWNED_HOSTS = new HashSet<String>(); private String hostname; private int portNumber; private Properties props; public static void flushAllHostLists(){ IMMEDIATELY_DOWNED_HOSTS.clear(); HUNG_CONNECT_HOSTS.clear(); HUNG_READ_HOSTS.clear(); HUNG_WRITE_HOSTS.clear(); } public static void mapHost(String alias, String orig) { MAPPED_HOSTS.put(alias, orig); } public static void hangOnRead(String hostname) { HUNG_READ_HOSTS.add(hostname); } public static void dontHangOnRead(String hostname) { HUNG_READ_HOSTS.remove(hostname); } public static void hangOnWrite(String hostname) { HUNG_WRITE_HOSTS.add(hostname); } public static void dontHangOnWrite (String hostname) { HUNG_WRITE_HOSTS.remove(hostname); } public static void hangOnConnect(String hostname) { HUNG_CONNECT_HOSTS.add(hostname); } public static void dontHangOnConnect(String hostname) { HUNG_CONNECT_HOSTS.remove(hostname); } public static void downHost(String hostname) { IMMEDIATELY_DOWNED_HOSTS.add(hostname); } public static void dontDownHost(String hostname) { IMMEDIATELY_DOWNED_HOSTS.remove(hostname); } public Socket connect(String host_name, int port_number, Properties prop) throws SocketException, IOException { this.hostname = host_name; this.portNumber = port_number; this.props = prop; return getNewSocket(); } private Socket getNewSocket() throws SocketException, IOException { if (IMMEDIATELY_DOWNED_HOSTS.contains(hostname)) { sleepMillisForProperty(props, "connectTimeout"); throw new SocketTimeoutException(); } String hostnameToConnectTo = MAPPED_HOSTS.get(hostname); if (hostnameToConnectTo == null) { hostnameToConnectTo = hostname; } if (NonRegisteringDriver.isHostPropertiesList(hostnameToConnectTo)) { Properties hostSpecificProps = NonRegisteringDriver.expandHostKeyValues(hostnameToConnectTo); String protocol = hostSpecificProps.getProperty(NonRegisteringDriver.PROTOCOL_PROPERTY_KEY); if ("unix".equalsIgnoreCase(protocol)) { SocketFactory factory; try { factory = (SocketFactory) Class .forName( "org.newsclub.net.mysql.AFUNIXDatabaseSocketFactory") .newInstance(); } catch (InstantiationException e) { throw new SocketException(e.getMessage()); } catch (IllegalAccessException e) { throw new SocketException(e.getMessage()); } catch (ClassNotFoundException e) { throw new SocketException(e.getMessage()); } String path = hostSpecificProps .getProperty(NonRegisteringDriver.PATH_PROPERTY_KEY); if (path != null) { hostSpecificProps.setProperty("junixsocket.file", path); } return new HangingSocket(factory.connect(hostnameToConnectTo, portNumber, hostSpecificProps), props, hostname); } } return new HangingSocket(super.connect(hostnameToConnectTo, portNumber, props), props, hostname); } public Socket afterHandshake() throws SocketException, IOException { return getNewSocket(); } public Socket beforeHandshake() throws SocketException, IOException { return getNewSocket(); } static void sleepMillisForProperty(Properties props, String name) { try { Thread.sleep(Long.parseLong(props.getProperty(name, String .valueOf(DEFAULT_TIMEOUT_MILLIS)))); } catch (NumberFormatException e) { throw new RuntimeException(e); } catch (InterruptedException e) { // ignore } } class HangingSocket extends Socket { public void bind(SocketAddress bindpoint) throws IOException { underlyingSocket.bind(bindpoint); } public synchronized void close() throws IOException { underlyingSocket.close(); } public SocketChannel getChannel() { return underlyingSocket.getChannel(); } public InetAddress getInetAddress() { return underlyingSocket.getInetAddress(); } public InputStream getInputStream() throws IOException { return new HangingInputStream(underlyingSocket.getInputStream(), props, aliasedHostname); } public boolean getKeepAlive() throws SocketException { return underlyingSocket.getKeepAlive(); } public InetAddress getLocalAddress() { return underlyingSocket.getLocalAddress(); } public int getLocalPort() { return underlyingSocket.getLocalPort(); } public SocketAddress getLocalSocketAddress() { return underlyingSocket.getLocalSocketAddress(); } public boolean getOOBInline() throws SocketException { return underlyingSocket.getOOBInline(); } public OutputStream getOutputStream() throws IOException { return new HangingOutputStream(underlyingSocket.getOutputStream(), props, aliasedHostname); } public int getPort() { return underlyingSocket.getPort(); } public synchronized int getReceiveBufferSize() throws SocketException { return underlyingSocket.getReceiveBufferSize(); } public SocketAddress getRemoteSocketAddress() { return underlyingSocket.getRemoteSocketAddress(); } public boolean getReuseAddress() throws SocketException { return underlyingSocket.getReuseAddress(); } public synchronized int getSendBufferSize() throws SocketException { return underlyingSocket.getSendBufferSize(); } public int getSoLinger() throws SocketException { return underlyingSocket.getSoLinger(); } public synchronized int getSoTimeout() throws SocketException { return underlyingSocket.getSoTimeout(); } public boolean getTcpNoDelay() throws SocketException { return underlyingSocket.getTcpNoDelay(); } public int getTrafficClass() throws SocketException { return underlyingSocket.getTrafficClass(); } public boolean isBound() { return underlyingSocket.isBound(); } public boolean isClosed() { return underlyingSocket.isClosed(); } public boolean isConnected() { return underlyingSocket.isConnected(); } public boolean isInputShutdown() { return underlyingSocket.isInputShutdown(); } public boolean isOutputShutdown() { return underlyingSocket.isOutputShutdown(); } public void sendUrgentData(int data) throws IOException { underlyingSocket.sendUrgentData(data); } public void setKeepAlive(boolean on) throws SocketException { underlyingSocket.setKeepAlive(on); } public void setOOBInline(boolean on) throws SocketException { underlyingSocket.setOOBInline(on); } public synchronized void setReceiveBufferSize(int size) throws SocketException { underlyingSocket.setReceiveBufferSize(size); } public void setReuseAddress(boolean on) throws SocketException { underlyingSocket.setReuseAddress(on); } public synchronized void setSendBufferSize(int size) throws SocketException { underlyingSocket.setSendBufferSize(size); } public void setSoLinger(boolean on, int linger) throws SocketException { underlyingSocket.setSoLinger(on, linger); } public synchronized void setSoTimeout(int timeout) throws SocketException { underlyingSocket.setSoTimeout(timeout); } public void setTcpNoDelay(boolean on) throws SocketException { underlyingSocket.setTcpNoDelay(on); } public void setTrafficClass(int tc) throws SocketException { underlyingSocket.setTrafficClass(tc); } public void shutdownInput() throws IOException { underlyingSocket.shutdownInput(); } public void shutdownOutput() throws IOException { underlyingSocket.shutdownOutput(); } public String toString() { return underlyingSocket.toString(); } final Socket underlyingSocket; final Properties props; final String aliasedHostname; HangingSocket(Socket realSocket, Properties props, String aliasedHostname) { underlyingSocket = realSocket; this.props = props; this.aliasedHostname = aliasedHostname; } } static class HangingInputStream extends InputStream { final InputStream underlyingInputStream; final Properties props; final String aliasedHostname; HangingInputStream(InputStream realInputStream, Properties props, String aliasedHostname) { underlyingInputStream = realInputStream; this.props = props; this.aliasedHostname = aliasedHostname; } public int available() throws IOException { return underlyingInputStream.available(); } public void close() throws IOException { underlyingInputStream.close(); } public synchronized void mark(int readlimit) { underlyingInputStream.mark(readlimit); } public boolean markSupported() { return underlyingInputStream.markSupported(); } public int read(byte[] b, int off, int len) throws IOException { failIfRequired(); return underlyingInputStream.read(b, off, len); } public int read(byte[] b) throws IOException { failIfRequired(); return underlyingInputStream.read(b); } public synchronized void reset() throws IOException { underlyingInputStream.reset(); } public long skip(long n) throws IOException { failIfRequired(); return underlyingInputStream.skip(n); } private void failIfRequired() throws SocketTimeoutException { if (HUNG_READ_HOSTS.contains(aliasedHostname) || IMMEDIATELY_DOWNED_HOSTS.contains(aliasedHostname)) { sleepMillisForProperty(props, "socketTimeout"); throw new SocketTimeoutException(); } } public int read() throws IOException { failIfRequired(); return underlyingInputStream.read(); } } static class HangingOutputStream extends OutputStream { final Properties props; final String aliasedHostname; final OutputStream underlyingOutputStream; HangingOutputStream(OutputStream realOutputStream, Properties props, String aliasedHostname) { underlyingOutputStream = realOutputStream; this.props = props; this.aliasedHostname = aliasedHostname; } public void close() throws IOException { failIfRequired(); underlyingOutputStream.close(); } public void flush() throws IOException { underlyingOutputStream.flush(); } public void write(byte[] b, int off, int len) throws IOException { failIfRequired(); underlyingOutputStream.write(b, off, len); } public void write(byte[] b) throws IOException { failIfRequired(); underlyingOutputStream.write(b); } public void write(int b) throws IOException { failIfRequired(); underlyingOutputStream.write(b); } private void failIfRequired() throws SocketTimeoutException { if (HUNG_WRITE_HOSTS.contains(aliasedHostname) || IMMEDIATELY_DOWNED_HOSTS.contains(aliasedHostname)) { sleepMillisForProperty(props, "socketTimeout"); throw new SocketTimeoutException(); } } } }