/**
* VMware Continuent Tungsten Replicator
* Copyright (C) 2015 VMware, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Initial developer(s): Robert Hodges
* Contributor(s):
*/
package com.continuent.tungsten.common.sockets;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketException;
import java.nio.channels.SocketChannel;
import javax.net.SocketFactory;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import org.apache.log4j.Logger;
import com.continuent.tungsten.common.config.cluster.ConfigurationException;
import com.continuent.tungsten.common.security.SecurityHelper;
/**
 * Provides a wrapper for client connections via sockets. This class
 * encapsulates logic for timeouts, SSL vs. non-SSL operation, and closing the
 * connection. This class assumes properties required for SSL operation have
 * been previously set before SSL sockets are allocated.
 * <p>
 * {@link #close()} may be called from another thread to abort a blocked
 * {@link #connect()}; in that case connect() reports the failure as a
 * SocketTerminationException rather than a plain IOException.
 *
 * @author <a href="mailto:robert.hodges@continuent.com">Robert Hodges</a>
 */
public class ClientSocketWrapper extends SocketWrapper
{
    private static final Logger logger = Logger
            .getLogger(ClientSocketWrapper.class);

    // Properties
    InetSocketAddress address;
    private boolean useSSL;
    private int connectTimeout;
    private int readTimeout;
    String[] enabledProtocols;
    String[] enabledCiphers;

    /**
     * Returns the enabledProtocols value.
     *
     * @return Returns the enabledProtocols (the stored array itself, not a
     *         copy).
     */
    public String[] getEnabledProtocols()
    {
        return enabledProtocols;
    }

    /**
     * Sets the enabledProtocols value.
     * <p>
     * NOTE(review): {@link #connect()} currently configures SSL sockets with
     * SecurityHelper.getProtocols(), not with this value — confirm whether
     * that is intended.
     *
     * @param enabledProtocols The enabledProtocols to set.
     */
    public void setEnabledProtocols(String[] enabledProtocols)
    {
        this.enabledProtocols = enabledProtocols;
    }

    // Socket factory for new SSL connections.
    private SocketFactory sslFactory;

    // Flag to signal service has been shut down. Set by close() and never
    // reset; read by connect() to tell an aborted connect from a real error.
    private volatile boolean done = false;

    /** Creates a new wrapper for client connections. */
    public ClientSocketWrapper()
    {
        super(null);
    }

    /** Returns the address to which we connect. */
    public InetSocketAddress getAddress()
    {
        return address;
    }

    /** Sets the address to which we should connect. */
    public void setAddress(InetSocketAddress address)
    {
        this.address = address;
    }

    /** Returns true if connect() will allocate an SSL socket. */
    public boolean isUseSSL()
    {
        return useSSL;
    }

    /** If set to true, use an SSL socket, otherwise use plain TCP/IP. */
    public void setUseSSL(boolean useSSL)
    {
        this.useSSL = useSSL;
    }

    /** Returns the connect timeout in milliseconds. */
    public long getConnectTimeout()
    {
        return connectTimeout;
    }

    /** Time in milliseconds before timeout when connecting to a server. */
    public void setConnectTimeout(int connectTimeout)
    {
        this.connectTimeout = connectTimeout;
    }

    /**
     * Sets SO_TIMEOUT on the current socket. Must be called after
     * {@link #connect()} and before {@link #close()}; otherwise the socket is
     * null and a NullPointerException is thrown.
     *
     * @param timeout read timeout in milliseconds
     * @throws SocketException if the option cannot be set
     */
    public void setSoTimeout(int timeout) throws SocketException
    {
        this.socket.setSoTimeout(timeout);
    }

    /** Returns the read timeout in milliseconds applied after connecting. */
    public long getReadTimeout()
    {
        return readTimeout;
    }

    /**
     * Time in milliseconds before timeout when waiting for responses after
     * connection.
     */
    public void setReadTimeout(int readTimeout)
    {
        this.readTimeout = readTimeout;
    }

    /**
     * Sets TCP_NODELAY on the current socket. Must be called after
     * {@link #connect()} and before {@link #close()}; otherwise the socket is
     * null and a NullPointerException is thrown.
     *
     * @param value true to disable Nagle's algorithm
     * @throws SocketException if the option cannot be set
     */
    public void setTcpNoDelay(boolean value) throws SocketException
    {
        this.socket.setTcpNoDelay(value);
    }

    /**
     * Connect to the server. On success the socket has TCP_NODELAY and
     * SO_KEEPALIVE enabled and SO_TIMEOUT set to the read timeout. On any
     * failure the newly allocated socket is closed before the exception
     * propagates.
     *
     * @return the connected socket
     * @throws SocketTerminationException if the connection attempt failed
     *             because {@link #close()} was called
     * @throws IOException if the socket cannot be created or connected
     * @throws ConfigurationException if SSL ciphers/protocols cannot be
     *             configured
     */
    public Socket connect() throws IOException, ConfigurationException
    {
        // Work on a local reference: close() may null out the shared field
        // from another thread while we are blocked in connect().
        Socket newSocket = createUnconnectedSocket();

        // Store the socket in the super class so close() can abort us.
        socket = newSocket;
        setSocket(newSocket);

        boolean connected = false;
        try
        {
            // Try to connect using the connect timeout.
            newSocket.connect(address, connectTimeout);
            // Disable Nagle's algorithm
            newSocket.setTcpNoDelay(true);
            // Enable TCP keepalive
            newSocket.setKeepAlive(true);
            // Set the socket timeout for reads.
            newSocket.setSoTimeout(readTimeout);
            connected = true;
            return newSocket;
        }
        catch (IOException e)
        {
            if (done)
            {
                throw new SocketTerminationException(
                        "Socket has been terminated", e);
            }
            throw e;
        }
        finally
        {
            // Do not leak the descriptor if anything above failed.
            if (!connected)
            {
                closeQuietly(newSocket);
            }
        }
    }

    /**
     * Allocates an unconnected SSL or plain socket according to useSSL. An
     * SSL socket is closed again if cipher/protocol setup fails.
     */
    private Socket createUnconnectedSocket()
            throws IOException, ConfigurationException
    {
        if (!useSSL)
        {
            SocketChannel channel = SocketChannel.open();
            return channel.socket();
        }

        // Create an SSL socket.
        sslFactory = SSLSocketFactory.getDefault();
        SSLSocket sslSocket = (SSLSocket) sslFactory.createSocket();
        boolean configured = false;
        try
        {
            // Check that at least one configured cipher and protocol match
            // with those supported by the socket
            SecurityHelper.setCiphersAndProtocolsToSSLSocket(sslSocket,
                    SecurityHelper.getCiphers(), SecurityHelper.getProtocols());
            configured = true;
            return sslSocket;
        }
        finally
        {
            if (!configured)
            {
                closeQuietly(sslSocket);
            }
        }
    }

    /** Closes a socket during failure cleanup, logging rather than throwing. */
    private static void closeQuietly(Socket s)
    {
        try
        {
            s.close();
        }
        catch (IOException e)
        {
            logger.debug("Unable to close socket after failed connect", e);
        }
    }

    /** Returns the socket; null before connect() or after close(). */
    public Socket getSocket()
    {
        return this.socket;
    }

    /**
     * {@inheritDoc}
     *
     * @see com.continuent.tungsten.common.sockets.SocketWrapper#getInputStream()
     */
    @Override
    public InputStream getInputStream() throws IOException
    {
        return socket.getInputStream();
    }

    /**
     * {@inheritDoc}
     *
     * @see com.continuent.tungsten.common.sockets.SocketWrapper#getOutputStream()
     */
    @Override
    public OutputStream getOutputStream() throws IOException
    {
        return socket.getOutputStream();
    }

    /**
     * Close socket. This is synchronized to prevent accidental double calls.
     * Marks the wrapper as done so a concurrent connect() reports a
     * SocketTerminationException, and clears the socket reference.
     */
    public synchronized void close()
    {
        done = true;
        if (socket != null)
        {
            try
            {
                socket.close();
            }
            catch (IOException e)
            {
                // Keep the stack trace; the message alone is often unhelpful.
                logger.warn(e.getMessage(), e);
            }
            finally
            {
                socket = null;
            }
        }
    }
}