// Copyright 2012 Citrix Systems, Inc. Licensed under the
// Apache License, Version 2.0 (the "License"); you may not use this
// file except in compliance with the License. Citrix Systems, Inc.
// reserves all rights not expressly granted by the License.
// You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Automatically generated by addcopyright.py at 04/03/2012
package com.cloud.utils.nio;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.nio.channels.WritableByteChannel;
import java.security.KeyStore;
import java.util.concurrent.ConcurrentLinkedQueue;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLSession;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import org.apache.log4j.Logger;
import com.cloud.utils.PropertiesUtil;
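/**
 * One connection endpoint handled by a {@link NioConnection}. A Link queues
 * outgoing messages and reads and writes SSL-wrapped packets on a socket
 * channel. Every packet on the wire is preceded by a 4-byte header that
 * carries the packet length and a flag telling whether more packets of the
 * same message follow.
 */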
public class Link {
private static final Logger s_logger = Logger.getLogger(Link.class);
// Socket address this link was created with.
private final InetSocketAddress _addr;
// Connection that owns this link; used to request writes, close keys and schedule tasks.
private final NioConnection _connection;
// Selection key of this link; null until setKey() is called and again after terminated().
private SelectionKey _key;
// Outgoing messages. Each entry is { length prefix, payload buffers... };
// an empty array is a marker asking write() to report that the link should be closed.
private final ConcurrentLinkedQueue<ByteBuffer[]> _writeQueue;
// Receives the 4-byte packet header first and then the encrypted packet body.
private ByteBuffer _readBuffer;
// Accumulates the decrypted bytes of one message across all of its packets.
private ByteBuffer _plaintextBuffer;
// Arbitrary object attached by the user of this link.
private Object _attach;
// True while read() is waiting for a packet header, false while it is reading a packet body.
private boolean _readHeader;
// True when the most recently read header had HEADER_FLAG_FOLLOWING set.
private boolean _gotFollowingPacket;
// Engine used to wrap outgoing and unwrap incoming data; set through setSSLEngine().
private SSLEngine _sslEngine;
/**
 * Creates a link for the given address that is served by the given connection.
 * The link starts without a selection key, attachment or SSL engine, with an
 * empty write queue, and expects a packet header as the first thing to read.
 *
 * @param addr socket address associated with this link.
 * @param connection connection that owns this link.
 */
public Link(InetSocketAddress addr, NioConnection connection) {
    this._addr = addr;
    this._connection = connection;
    this._key = null;
    this._attach = null;
    this._writeQueue = new ConcurrentLinkedQueue<ByteBuffer[]>();
    this._readBuffer = ByteBuffer.allocate(2048);
    this._gotFollowingPacket = false;
    this._readHeader = true;
}
/**
 * Creates a fresh link that shares the address and connection of another link.
 * No other state (key, attachment, SSL engine, buffers) is copied.
 *
 * @param link link whose address and connection are reused.
 */
public Link(Link link) {
    this(link._addr, link._connection);
}
/**
 * @return the object last stored with {@link #attach(Object)}, or
 *         {@code null} if nothing has been attached.
 */
public Object attachment() {
    return this._attach;
}
/**
 * Stores an arbitrary object on this link, replacing any previous attachment.
 *
 * @param attach object to remember; may be {@code null}.
 */
public void attach(Object attach) {
    this._attach = attach;
}
/**
 * Sets the selection key used by this link. While the key is {@code null}
 * the link is treated as closed by send() and schedule().
 *
 * @param key selection key for this link.
 */
public void setKey(SelectionKey key) {
    this._key = key;
}
/**
 * Sets the SSL engine used to encrypt outgoing and decrypt incoming data.
 * It must be set before read() or write() is called on this link.
 *
 * @param sslEngine engine to use for this link.
 */
public void setSSLEngine(SSLEngine sslEngine) {
    this._sslEngine = sslEngine;
}
/*
 * This method currently has no callers, so it is commented out.
 *
 * Static method for reading from a channel, in case
 * you need to add a client that doesn't use nio.
* @param ch channel to read from.
* @param bytebuffer to use.
* @return bytes read
* @throws IOException if not read to completion.
public static byte[] read(SocketChannel ch, ByteBuffer buff) throws IOException {
synchronized(buff) {
buff.clear();
buff.limit(4);
while (buff.hasRemaining()) {
if (ch.read(buff) == -1) {
throw new IOException("Connection closed with -1 on reading size.");
}
}
buff.flip();
int length = buff.getInt();
ByteArrayOutputStream output = new ByteArrayOutputStream(length);
WritableByteChannel outCh = Channels.newChannel(output);
int count = 0;
while (count < length) {
buff.clear();
int read = ch.read(buff);
if (read < 0) {
throw new IOException("Connection closed with -1 on reading data.");
}
count += read;
buff.flip();
outCh.write(buff);
}
return output.toByteArray();
}
}
*/
/**
 * Encrypts the given buffers and writes them to the channel as one or more
 * framed packets. Each packet is a 4-byte header followed by the SSL record
 * data; the header holds the record length, OR-ed with
 * HEADER_FLAG_FOLLOWING when more packets of the same message follow.
 * The method keeps writing until everything has been sent.
 *
 * @param ch channel to write to.
 * @param buffers plaintext buffers to send; they are consumed by this call.
 * @param sslEngine engine used to wrap the data.
 * @throws IOException if the channel fails, or if the SSL engine reports a
 *         non-OK status or makes no progress.
 */
private static void doWrite(SocketChannel ch, ByteBuffer[] buffers, SSLEngine sslEngine) throws IOException {
    SSLSession sslSession = sslEngine.getSession();
    ByteBuffer pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
    ByteBuffer headBuf = ByteBuffer.allocate(4);

    // Progress is measured as sum(position) against sum(limit) over all buffers.
    int totalLen = 0;
    for (ByteBuffer buffer : buffers) {
        totalLen += buffer.limit();
    }

    int processedLen = 0;
    while (processedLen < totalLen) {
        headBuf.clear();
        pkgBuf.clear();

        SSLEngineResult engResult = sslEngine.wrap(buffers, pkgBuf);
        // Any non-OK status (CLOSED, BUFFER_OVERFLOW, ...) means the data was not
        // wrapped. Continuing would loop forever because processedLen never advances.
        if (engResult.getStatus() != SSLEngineResult.Status.OK) {
            throw new IOException("SSL: SSLEngine return bad result! " + engResult);
        }
        // Same reasoning for an OK result that neither consumed nor produced anything.
        if (engResult.bytesConsumed() == 0 && engResult.bytesProduced() == 0) {
            throw new IOException("SSL: SSLEngine return bad result! " + engResult);
        }

        processedLen = 0;
        for (ByteBuffer buffer : buffers) {
            processedLen += buffer.position();
        }

        int header = pkgBuf.position();
        pkgBuf.flip();
        if (processedLen < totalLen) {
            // More plaintext is pending, so the receiver must wait for further packets.
            header = header | HEADER_FLAG_FOLLOWING;
        }
        headBuf.putInt(header);
        headBuf.flip();

        while (headBuf.hasRemaining()) {
            if (s_logger.isTraceEnabled()) {
                s_logger.trace("Writing Header " + headBuf.remaining());
            }
            ch.write(headBuf);
        }
        while (pkgBuf.hasRemaining()) {
            if (s_logger.isTraceEnabled()) {
                s_logger.trace("Writing Data " + pkgBuf.remaining());
            }
            ch.write(pkgBuf);
        }
    }
}
/**
 * Writes the buffers to a socket using this class's packet protocol. Unlike
 * a regular nio write, this call does not return until all of the data has
 * been written. Writes are serialized on the channel object.
 *
 * @param ch channel to write to.
 * @param buffers buffers to write.
 * @param sslEngine engine used to encrypt the data.
 * @throws IOException if unable to write to completion.
 */
public static void write(SocketChannel ch, ByteBuffer[] buffers, SSLEngine sslEngine) throws IOException {
    synchronized (ch) {
        Link.doWrite(ch, buffers, sslEngine);
    }
}
/*
 * SSL has a limit of 16k per record, so larger messages may need to be split
 * into several packets. 18000 is 16k plus some room for the extra SSL
 * information. read() rejects any packet whose header announces more than this.
 */
protected static final int MAX_SIZE_PER_PACKET = 18000;
/* Header bit set by doWrite() on every packet that is followed by another packet of the same message. */
protected static final int HEADER_FLAG_FOLLOWING = 0x10000;
/**
 * Reads as much of the current packet as the channel has available, and
 * decrypts the packet once it is complete. A message may span several
 * packets; the decrypted data is accumulated until a packet arrives whose
 * header does not have HEADER_FLAG_FOLLOWING set.
 *
 * @param ch channel to read from.
 * @return the decrypted bytes of a complete message, or {@code null} when
 *         more data or more packets are needed.
 * @throws IOException if the channel reports end-of-stream, the header
 *         announces a negative or oversized packet, or the SSL engine cannot
 *         unwrap the received data.
 */
public byte[] read(SocketChannel ch) throws IOException {
    if (_readHeader) { // Start of a packet
        if (_readBuffer.position() == 0) {
            _readBuffer.limit(4);
        }
        if (ch.read(_readBuffer) == -1) {
            throw new IOException("Connection closed with -1 on reading size.");
        }
        if (_readBuffer.hasRemaining()) {
            if (s_logger.isTraceEnabled()) {
                s_logger.trace("Need to read the rest of the packet length");
            }
            return null;
        }
        _readBuffer.flip();
        int header = _readBuffer.getInt();
        // The low 16 bits carry the size; the bits above carry flags.
        int readSize = (short)header;
        if (s_logger.isTraceEnabled()) {
            s_logger.trace("Packet length is " + readSize);
        }
        // A negative size would otherwise reach ByteBuffer.limit() and fail with an
        // unchecked IllegalArgumentException; the header comes from the peer.
        if (readSize < 0 || readSize > MAX_SIZE_PER_PACKET) {
            throw new IOException("Wrong packet size: " + readSize);
        }

        if (!_gotFollowingPacket) {
            // First packet of a new message: start with a fresh plaintext buffer.
            _plaintextBuffer = ByteBuffer.allocate(2000);
        }
        _gotFollowingPacket = (header & HEADER_FLAG_FOLLOWING) != 0;

        _readBuffer.clear();
        _readHeader = false;

        if (_readBuffer.capacity() < readSize) {
            if (s_logger.isTraceEnabled()) {
                s_logger.trace("Resizing the byte buffer from " + _readBuffer.capacity());
            }
            _readBuffer = ByteBuffer.allocate(readSize);
        }
        _readBuffer.limit(readSize);
    }

    if (ch.read(_readBuffer) == -1) {
        throw new IOException("Connection closed with -1 on read.");
    }

    if (_readBuffer.hasRemaining()) { // We're not done yet.
        if (s_logger.isTraceEnabled()) {
            s_logger.trace("Still has " + _readBuffer.remaining());
        }
        return null;
    }

    _readBuffer.flip();

    SSLSession sslSession = _sslEngine.getSession();
    while (_readBuffer.hasRemaining()) {
        int remaining = _readBuffer.remaining();
        ByteBuffer appBuf = ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40);
        SSLEngineResult engResult = _sslEngine.unwrap(_readBuffer, appBuf);
        // NOTE(review): this only throws when the engine is mid-handshake AND the
        // status is not OK. A non-OK status outside a handshake is caught only by
        // the "no bytes consumed" check below. Left as is to keep the handling of
        // a CLOSED result unchanged - confirm whether that is intended.
        if (engResult.getHandshakeStatus() != HandshakeStatus.FINISHED &&
                engResult.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING &&
                engResult.getStatus() != SSLEngineResult.Status.OK) {
            throw new IOException("SSL: SSLEngine return bad result! " + engResult);
        }
        if (remaining == _readBuffer.remaining()) {
            throw new IOException("SSL: Unable to unwrap received data! still remaining " + remaining + "bytes!");
        }

        appBuf.flip();
        if (_plaintextBuffer.remaining() < appBuf.limit()) {
            // We need to expand _plaintextBuffer for more data
            ByteBuffer newBuffer = ByteBuffer.allocate(_plaintextBuffer.capacity() + appBuf.limit() * 5);
            _plaintextBuffer.flip();
            newBuffer.put(_plaintextBuffer);
            _plaintextBuffer = newBuffer;
        }
        _plaintextBuffer.put(appBuf);
        if (s_logger.isTraceEnabled()) {
            s_logger.trace("Done with packet: " + appBuf.limit());
        }
    }

    // Packet fully consumed; the next bytes on the channel are a header again.
    _readBuffer.clear();
    _readHeader = true;

    if (_gotFollowingPacket) {
        if (s_logger.isTraceEnabled()) {
            s_logger.trace("Waiting for more packets");
        }
        return null;
    }

    _plaintextBuffer.flip();
    byte[] result = new byte[_plaintextBuffer.limit()];
    _plaintextBuffer.get(result);
    return result;
}
/**
 * Queues the bytes for sending and keeps the link open afterwards.
 *
 * @param data bytes to send.
 * @throws ClosedChannelException if this link no longer has a selection key.
 */
public void send(byte[] data) throws ClosedChannelException {
    this.send(data, false);
}
/**
 * Queues the bytes for sending, optionally followed by a close request.
 *
 * @param data bytes to send.
 * @param close whether a close marker is queued after the data.
 * @throws ClosedChannelException if this link no longer has a selection key.
 */
public void send(byte[] data, boolean close) throws ClosedChannelException {
    ByteBuffer payload = ByteBuffer.wrap(data);
    send(new ByteBuffer[] { payload }, close);
}
/**
 * Queues the buffers as one message and asks the connection to register
 * write interest for this link. The queued entry starts with a 4-byte buffer
 * holding the total number of remaining bytes, followed by the given buffers.
 * When {@code close} is set, an empty array is queued after the message as a
 * close marker.
 *
 * @param data buffers that make up the message.
 * @param close whether a close marker is queued after the message.
 * @throws ClosedChannelException if this link no longer has a selection key;
 *         the message has already been queued at that point.
 */
public void send(ByteBuffer[] data, boolean close) throws ClosedChannelException {
    ByteBuffer[] queued = new ByteBuffer[data.length + 1];
    int total = 0;
    int slot = 1;
    for (ByteBuffer part : data) {
        total += part.remaining();
        queued[slot++] = part;
    }

    ByteBuffer lengthPrefix = ByteBuffer.allocate(4);
    lengthPrefix.putInt(total);
    lengthPrefix.flip();
    queued[0] = lengthPrefix;

    if (s_logger.isTraceEnabled()) {
        s_logger.trace("Sending packet of length " + total);
    }

    _writeQueue.add(queued);
    if (close) {
        // A zero-length entry tells write() that closing was requested.
        _writeQueue.add(new ByteBuffer[0]);
    }

    synchronized (this) {
        if (_key == null) {
            throw new ClosedChannelException();
        }
        _connection.change(SelectionKey.OP_WRITE, _key, null);
    }
}
/**
 * Queues the buffers as one message and keeps the link open afterwards.
 *
 * @param data buffers that make up the message.
 * @throws ClosedChannelException if this link no longer has a selection key.
 */
public void send(ByteBuffer[] data) throws ClosedChannelException {
    this.send(data, false);
}
/**
 * Asks the owning connection to close this link's selection key.
 * Does nothing when the link has no key.
 */
public synchronized void close() {
    if (_key == null) {
        return;
    }
    _connection.close(_key);
}
/**
 * Drains the write queue onto the channel. For every queued message the
 * leading length-prefix buffer is dropped and the remaining buffers are
 * written through doWrite(). Draining stops early at a close marker
 * (an empty array).
 *
 * @param ch channel to write to.
 * @return {@code true} if a close marker was found, {@code false} if the
 *         queue was drained without one.
 * @throws IOException if writing to the channel fails.
 */
public boolean write(SocketChannel ch) throws IOException {
    for (ByteBuffer[] queued = _writeQueue.poll(); queued != null; queued = _writeQueue.poll()) {
        if (queued.length == 0) {
            if (s_logger.isTraceEnabled()) {
                s_logger.trace("Closing connection requested");
            }
            return true;
        }

        // Skip element 0, the length prefix added by send().
        int payloadCount = queued.length - 1;
        ByteBuffer[] payload = new ByteBuffer[payloadCount];
        System.arraycopy(queued, 1, payload, 0, payloadCount);

        doWrite(ch, payload, _sslEngine);
    }
    return false;
}
/**
 * @return the socket address this link was created with.
 */
public InetSocketAddress getSocketAddress() {
    return this._addr;
}
/**
 * @return the result of {@code toString()} on the InetAddress of this
 *         link's socket address.
 */
public String getIpAddress() {
    return this._addr.getAddress().toString();
}
/**
 * Marks this link as terminated by dropping its selection key. Later calls
 * to send() and schedule() then fail with ClosedChannelException.
 */
public synchronized void terminated() {
    this._key = null;
}
/**
 * Hands a task to the owning connection for scheduling.
 *
 * @param task task to schedule.
 * @throws ClosedChannelException if this link no longer has a selection key.
 */
public synchronized void schedule(Task task) throws ClosedChannelException {
    if (_key != null) {
        _connection.scheduleTask(task);
        return;
    }
    throw new ClosedChannelException();
}
/**
 * Builds the TLS context used by links.
 * <p>
 * In server mode ({@code isClient == false}) the key and trust material is
 * loaded from {@code cloud.keystore} in the directory that holds
 * {@code db.properties}; if that file does not exist, the
 * {@code /cloud.keystore} classpath resource is used instead. In client mode
 * an empty key store is used and every server certificate is accepted through
 * TrustAllManager.
 *
 * @param isClient whether the context is for the client side.
 * @return an initialized TLS context.
 * @throws Exception if the key store cannot be found or loaded, or the
 *         context cannot be initialized.
 */
public static SSLContext initSSLContext(boolean isClient) throws Exception {
    KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
    TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
    KeyStore ks = KeyStore.getInstance("JKS");
    TrustManager[] tms;

    if (!isClient) {
        // NOTE(review): the key store passphrase is hard-coded here.
        char[] passphrase = "vmops.com".toCharArray();
        File confFile = PropertiesUtil.findConfigFile("db.properties");
        /* This line may throw a NPE, but that is caused by failing to find db.properties, which means there is a bug somewhere else */
        String confPath = confFile.getParent();
        String keystorePath = confPath + "/cloud.keystore";

        InputStream stream;
        if (new File(keystorePath).exists()) {
            stream = new FileInputStream(keystorePath);
        } else {
            s_logger.warn("SSL: Fail to find the generated keystore. Loading fail-safe one to continue.");
            stream = NioConnection.class.getResourceAsStream("/cloud.keystore");
            if (stream == null) {
                // getResourceAsStream() returns null for a missing resource; fail with a
                // clear message instead of loading an empty key store and hitting a NPE.
                throw new IOException("SSL: Unable to find cloud.keystore in " + confPath + " or on the classpath");
            }
        }

        // Always release the stream, even when loading the key store fails.
        try {
            ks.load(stream, passphrase);
        } finally {
            stream.close();
        }

        kmf.init(ks, passphrase);
        tmf.init(ks);
        tms = tmf.getTrustManagers();
    } else {
        ks.load(null, null);
        kmf.init(ks, null);
        // NOTE(review): the client side accepts any server certificate.
        tms = new TrustManager[1];
        tms[0] = new TrustAllManager();
    }

    SSLContext sslContext = SSLContext.getInstance("TLS");
    sslContext.init(kmf.getKeyManagers(), tms, null);
    if (s_logger.isTraceEnabled()) {
        s_logger.trace("SSL: SSLcontext has been initialized");
    }

    return sslContext;
}
/**
 * Drives the SSL handshake on the channel until the engine reports FINISHED.
 * The client side starts by wrapping, the server side by unwrapping.
 * Handshake data is exchanged directly on the channel, without the 4-byte
 * packet header used for application data.
 *
 * @param ch channel to handshake on.
 * @param sslEngine engine performing the handshake.
 * @param isClient whether this side is the client.
 * @throws IOException if the channel is closed, the engine reports a non-OK
 *         status, the engine is not handshaking, or a handshake message is
 *         still incomplete after more than 10 additional reads.
 */
public static void doHandshake(SocketChannel ch, SSLEngine sslEngine,
        boolean isClient) throws IOException {
    if (s_logger.isTraceEnabled()) {
        s_logger.trace("SSL: begin Handshake, isClient: " + isClient);
    }

    SSLEngineResult engResult;
    SSLSession sslSession = sslEngine.getSession();
    HandshakeStatus hsStatus;
    ByteBuffer in_pkgBuf =
        ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
    ByteBuffer in_appBuf =
        ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40);
    ByteBuffer out_pkgBuf =
        ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40);
    ByteBuffer out_appBuf =
        ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40);
    int count;

    // in_pkgBuf is kept in "ready to be consumed" mode between iterations.
    // Start it out empty (position == limit == 0) so the first unwrap reads.
    in_pkgBuf.flip();

    if (isClient) {
        hsStatus = SSLEngineResult.HandshakeStatus.NEED_WRAP;
    } else {
        hsStatus = SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
    }

    while (hsStatus != SSLEngineResult.HandshakeStatus.FINISHED) {
        if (s_logger.isTraceEnabled()) {
            s_logger.trace("SSL: Handshake status " + hsStatus);
        }
        engResult = null;
        if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP) {
            out_pkgBuf.clear();
            out_appBuf.clear();
            // Placeholder application data; presumably not consumed by the engine
            // while it is still handshaking - TODO confirm.
            out_appBuf.put("Hello".getBytes());
            engResult = sslEngine.wrap(out_appBuf, out_pkgBuf);
            out_pkgBuf.flip();
            int remain = out_pkgBuf.limit();
            while (remain != 0) {
                remain -= ch.write(out_pkgBuf);
                if (remain < 0) {
                    throw new IOException("Too much bytes sent?");
                }
            }
        } else if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
            in_appBuf.clear();
            // One read may contain several handshake messages, so only read
            // from the channel once everything buffered has been consumed.
            if (!in_pkgBuf.hasRemaining()) {
                in_pkgBuf.clear();
                count = ch.read(in_pkgBuf);
                if (count == -1) {
                    throw new IOException("Connection closed with -1 on reading size.");
                }
                in_pkgBuf.flip();
            }
            engResult = sslEngine.unwrap(in_pkgBuf, in_appBuf);
            int loop_count = 0;
            while (engResult.getStatus() == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
                // The client is too slow? Cut it and let it reconnect
                if (loop_count > 10) {
                    throw new IOException("Too many times in SSL BUFFER_UNDERFLOW, disconnect guest.");
                }
                // We need more packets to complete this operation
                if (s_logger.isTraceEnabled()) {
                    s_logger.trace("SSL: Buffer underflowed, getting more packets");
                }
                // Move the unconsumed bytes to the front and append new data behind
                // them. Reading straight into the buffer can never push its limit
                // past its capacity, unlike appending the content of a second,
                // equally sized buffer.
                in_pkgBuf.compact();
                count = ch.read(in_pkgBuf);
                if (count == -1) {
                    throw new IOException("Connection closed with -1 on reading size.");
                }
                in_pkgBuf.flip();

                in_appBuf.clear();
                engResult = sslEngine.unwrap(in_pkgBuf, in_appBuf);
                loop_count++;
            }
        } else if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK) {
            Runnable run;
            while ((run = sslEngine.getDelegatedTask()) != null) {
                if (s_logger.isTraceEnabled()) {
                    s_logger.trace("SSL: Running delegated task!");
                }
                run.run();
            }
        } else if (hsStatus == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
            throw new IOException("NOT a handshaking!");
        }

        if (engResult != null && engResult.getStatus() != SSLEngineResult.Status.OK) {
            throw new IOException("Fail to handshake! " + engResult.getStatus());
        }

        // After running delegated tasks there is no result; ask the engine directly.
        if (engResult != null) {
            hsStatus = engResult.getHandshakeStatus();
        } else {
            hsStatus = sslEngine.getHandshakeStatus();
        }
    }
}
}