/*
* Copyright (c) 2008-2012, Hazel Bilisim Ltd. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.hazelcast.nio.ssl;
import com.hazelcast.nio.DefaultSocketChannelWrapper;
import javax.net.ssl.*;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
/**
 * A {@link DefaultSocketChannelWrapper} that encrypts outgoing and decrypts incoming traffic
 * with an {@link SSLEngine}.
 * <p>
 * The SSL/TLS handshake is driven to completion inside the constructor by polling the wrapped
 * channel (sleeping 50 ms whenever no bytes are available), so construction blocks until the
 * handshake finishes or fails. NOTE(review): the polling suggests the channel is in non-blocking
 * mode at this point — confirm with the callers.
 * <p>
 * Buffer conventions: {@code in} and {@code sTOc} are kept in "read mode" between calls (the bytes
 * between position and limit are pending), while {@code cTOs} is kept in "write mode" (bytes before
 * the position still have to be sent). The class has no synchronization, and
 * {@code sslEngineResult} is shared by the read and write paths.
 */
public class SSLSocketChannelWrapper extends DefaultSocketChannelWrapper {

    /** Decrypted bytes that have not yet been handed to a caller of {@link #read(ByteBuffer)}. */
    private final ByteBuffer in;
    /** Zero-length source for wraps that must only emit handshake/close records, never application data. */
    private final ByteBuffer emptyBuffer = ByteBuffer.allocate(0);
    private final ByteBuffer cTOs; // "reliable" write transport: encrypted bytes waiting to be sent
    private final ByteBuffer sTOc; // "reliable" read transport: encrypted bytes received but not yet unwrapped
    private final SSLEngine sslEngine;
    /** Result of the most recent wrap/unwrap call; drives the handshake loop and {@link #read(ByteBuffer)}. */
    private SSLEngineResult sslEngineResult;

    /**
     * Creates the engine for {@code sc} and performs the complete handshake before returning.
     *
     * @param sslContext context used to create the {@link SSLEngine}
     * @param sc         the already-connected channel to tunnel through
     * @param client     {@code true} to act as the SSL client, {@code false} for the server side
     * @throws Exception if the handshake fails, the peer closes the channel during the handshake,
     *                   or the thread is interrupted while waiting for handshake data
     */
    public SSLSocketChannelWrapper(SSLContext sslContext, SocketChannel sc, boolean client) throws Exception {
        super(sc);
        sslEngine = sslContext.createSSLEngine();
        sslEngine.setUseClientMode(client);
        sslEngine.setEnableSessionCreation(true);
        SSLSession session = sslEngine.getSession();
        in = ByteBuffer.allocate(64 * 1024);
        int netBufferMax = session.getPacketBufferSize();
        cTOs = ByteBuffer.allocate(netBufferMax);
        sTOc = ByteBuffer.allocate(netBufferMax);
        // Put the receive buffer into empty "read mode" so the handshake can compact/append to it.
        sTOc.flip();
        doHandshake();
        // No plaintext is pending after the handshake; any application records that arrived
        // together with the last handshake record are still buffered in sTOc for read().
        in.clear();
        in.flip();
    }

    /**
     * Runs the handshake state machine until the engine reports {@code FINISHED}.
     * The first wrap implicitly starts the handshake (see the {@link SSLEngine} documentation).
     */
    private void doHandshake() throws IOException, InterruptedException {
        wrapAndSend(emptyBuffer);
        SSLEngineResult.HandshakeStatus status = sslEngineResult.getHandshakeStatus();
        while (status != SSLEngineResult.HandshakeStatus.FINISHED) {
            switch (status) {
                case NEED_UNWRAP:
                    receiveHandshakeData();
                    if (sslEngineResult.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.FINISHED) {
                        // Let the engine emit whatever the unwrapped records made necessary.
                        wrapAndSend(emptyBuffer);
                    }
                    break;
                case NEED_WRAP:
                    wrapAndSend(emptyBuffer);
                    break;
                case NEED_TASK:
                    runDelegatedTasks();
                    // A wrap refreshes sslEngineResult with the engine's post-task status.
                    wrapAndSend(emptyBuffer);
                    break;
                default:
                    // NOT_HANDSHAKING: the engine has no handshake in progress, so waiting could
                    // never change the state (the previous code slept here forever).
                    return;
            }
            status = sslEngineResult.getHandshakeStatus();
        }
    }

    /**
     * Feeds handshake records to the engine. Bytes already buffered in {@code sTOc} are tried
     * first; the channel is polled only when nothing is buffered or only a partial record is left.
     *
     * @throws IOException if the peer closes the channel before the handshake completes
     */
    private void receiveHandshakeData() throws IOException, InterruptedException {
        boolean needMoreData = !sTOc.hasRemaining();
        if (!needMoreData) {
            unwrap(sTOc);
            needMoreData = sslEngineResult.getStatus() == SSLEngineResult.Status.BUFFER_UNDERFLOW;
        }
        if (needMoreData) {
            // compact() keeps a partially received record instead of discarding it.
            sTOc.compact();
            int bytesRead;
            while ((bytesRead = socketChannel.read(sTOc)) < 1) {
                if (bytesRead == -1) {
                    throw new IOException("Connection closed by peer during SSL handshake");
                }
                Thread.sleep(50);
            }
            sTOc.flip();
            unwrap(sTOc);
        }
    }

    /** Runs, on the calling thread, every task the engine has delegated. */
    private void runDelegatedTasks() {
        Runnable task;
        while ((task = sslEngine.getDelegatedTask()) != null) {
            task.run();
        }
    }

    /**
     * Unwraps records from {@code b} into {@code in}, which is cleared first and left in write mode.
     * Stops when {@code b} is exhausted, the handshake finishes, a partial record is reached
     * (BUFFER_UNDERFLOW), or the engine makes no progress (e.g. it first needs a wrap, or the
     * inbound side is closed); unconsumed bytes stay in {@code b}.
     *
     * @param b encrypted input in read mode
     * @return the {@code in} buffer
     */
    private ByteBuffer unwrap(ByteBuffer b) throws SSLException {
        in.clear();
        while (b.hasRemaining()) {
            sslEngineResult = sslEngine.unwrap(b, in);
            if (sslEngineResult.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) {
                runDelegatedTasks();
            } else if (sslEngineResult.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED
                    || sslEngineResult.getStatus() == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
                return in;
            } else if (sslEngineResult.bytesConsumed() == 0 && sslEngineResult.bytesProduced() == 0) {
                // Calling unwrap again would return the same result and spin forever.
                return in;
            }
        }
        return in;
    }

    /**
     * Encrypts as much of {@code input} as the engine accepts in one wrap call and writes the
     * pending encrypted bytes (including any left over from an earlier partial write) to the channel.
     * NOTE(review): a BUFFER_OVERFLOW result from wrap is not handled here.
     *
     * @param input plaintext to send
     * @return the number of <em>encrypted</em> bytes written to the channel
     */
    private int wrapAndSend(ByteBuffer input) throws IOException {
        sslEngineResult = sslEngine.wrap(input, cTOs);
        cTOs.flip();
        int written = socketChannel.write(cTOs);
        if (cTOs.hasRemaining()) {
            cTOs.compact();
        } else {
            cTOs.clear();
        }
        return written;
    }

    /**
     * Encrypts {@code input} and writes it to the underlying channel.
     *
     * @param input plaintext to send
     * @return the number of encrypted bytes written to the channel in this call.
     *         NOTE(review): this is not the number of bytes consumed from {@code input}, which is
     *         what {@code WritableByteChannel.write} specifies; callers should check
     *         {@code input.hasRemaining()} — confirm before changing this value.
     */
    public int write(ByteBuffer input) throws IOException {
        return wrapAndSend(input);
    }

    /**
     * Copies as many pending bytes from {@code in} into {@code output} as fit.
     *
     * @return the number of bytes copied
     */
    private int drainPlaintext(ByteBuffer output) {
        int count = Math.min(in.remaining(), output.remaining());
        if (count > 0) {
            ByteBuffer chunk = in.duplicate();
            chunk.limit(chunk.position() + count);
            output.put(chunk);
            in.position(in.position() + count);
        }
        return count;
    }

    /**
     * Reads decrypted bytes into {@code output}. Leftover plaintext is served first, then buffered
     * ciphertext is decrypted, and only then is the channel read (at most once per call).
     *
     * @param output destination for plaintext
     * @return the number of plaintext bytes copied (possibly 0), or -1 if the channel reached
     *         end-of-stream and nothing was copied in this call
     */
    public int read(ByteBuffer output) throws IOException {
        if (in.hasRemaining()) {
            return drainPlaintext(output);
        }
        int readBytesCount = 0;
        if (sTOc.hasRemaining()) {
            unwrap(sTOc);
            in.flip();
            readBytesCount = drainPlaintext(output);
            if (in.hasRemaining()) {
                // output is full: keep the rest for the next call. Falling through would let the
                // next unwrap() clear 'in' and lose these decrypted bytes.
                return readBytesCount;
            }
            if (sslEngineResult.getStatus() != SSLEngineResult.Status.BUFFER_UNDERFLOW) {
                sTOc.clear();
                sTOc.flip();
                return readBytesCount;
            }
        }
        // Move any partial record to the front (or just reset an empty buffer) and append from the channel.
        sTOc.compact();
        if (socketChannel.read(sTOc) == -1) {
            sTOc.clear();
            sTOc.flip();
            // Don't hide bytes already placed in output behind an EOF; report EOF on the next call.
            return readBytesCount > 0 ? readBytesCount : -1;
        }
        sTOc.flip();
        unwrap(sTOc);
        in.flip();
        return readBytesCount + drainPlaintext(output);
    }

    /**
     * Closes the outbound side of the engine, sends the resulting close record on a best-effort
     * basis, and closes the underlying channel.
     */
    public void close() throws IOException {
        sslEngine.closeOutbound();
        try {
            wrapAndSend(emptyBuffer);
        } catch (Exception ignored) {
            // Best effort: the channel is closed below whether or not the close record went out.
        }
        socketChannel.close();
    }

    @Override
    public long read(ByteBuffer[] byteBuffers, int i, int i1) throws IOException {
        throw new UnsupportedOperationException();
    }

    @Override
    public long write(ByteBuffer[] byteBuffers, int i, int i1) throws IOException {
        throw new UnsupportedOperationException();
    }
}