package org.webpieces.nio.impl.ssl;
import java.io.InputStream;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.security.KeyStore;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import org.webpieces.util.logging.Logger;
import org.webpieces.util.logging.LoggerFactory;
import org.webpieces.data.api.BufferPool;
import org.webpieces.nio.api.SSLEngineFactory;
import org.webpieces.nio.api.SSLEngineFactoryWithHost;
import org.webpieces.nio.api.channels.Channel;
import org.webpieces.nio.api.channels.TCPChannel;
import org.webpieces.nio.api.exceptions.NioClosedChannelException;
import org.webpieces.nio.api.handlers.ConnectionListener;
import org.webpieces.nio.api.handlers.DataListener;
import org.webpieces.ssl.api.AsyncSSLEngine;
import org.webpieces.ssl.api.AsyncSSLFactory;
import org.webpieces.ssl.api.SslListener;
/**
 * A {@link TCPChannel} that wraps another TCPChannel and runs all traffic through an
 * {@link AsyncSSLEngine}.
 *
 * <p>Two construction modes are visible in this class:
 * <ul>
 *   <li>the engine is created immediately from a supplied function, and {@link #connect} starts
 *       the handshake once the underlying channel is connected;</li>
 *   <li>the engine is created lazily from an {@link SSLEngineFactory} when the first encrypted
 *       bytes arrive, and the {@link ConnectionListener} is asked for a {@link DataListener}
 *       once the encrypted link is established.</li>
 * </ul>
 */
public class SslTCPChannel extends SslChannel implements TCPChannel {

    private static final Logger log = LoggerFactory.getLogger(SslTCPChannel.class);

    /** Null until created; in the lazy mode it is created on the first incoming data. */
    private AsyncSSLEngine sslEngine;
    /** Receives decrypted data; null until connect() is called or the encrypted link is established. */
    private SslTryCatchListener clientDataListener;
    private final TCPChannel realChannel;
    private final SocketDataListener socketDataListener = new SocketDataListener();
    private ConnectionListener connectionListener;
    /** Non-null only after beginHandshake() was triggered through connect(). */
    private CompletableFuture<Channel> sslConnectFuture;
    private CompletableFuture<Channel> closeFuture;
    private final SslListener sslListener = new OurSslListener();
    private SSLEngineFactory sslFactory;
    private BufferPool pool;
    private ClientHelloParser parser;

    /**
     * Creates a channel whose SSL engine is built immediately.
     *
     * @param function builds the engine, given the listener this channel uses for engine callbacks
     * @param realChannel the underlying channel that carries the encrypted bytes
     */
    public SslTCPChannel(Function<SslListener, AsyncSSLEngine> function, TCPChannel realChannel) {
        super(realChannel);
        sslEngine = function.apply(sslListener);
        this.realChannel = realChannel;
    }

    /**
     * Creates a channel whose SSL engine is built lazily when the first data arrives on the
     * underlying channel.
     *
     * @param pool buffer pool handed to the client-hello parser and to the SSL engine
     * @param realChannel2 the underlying channel that carries the encrypted bytes
     * @param connectionListener notified (and asked for a DataListener) once the link is encrypted
     * @param sslFactory creates the SSLEngine; if it is an {@link SSLEngineFactoryWithHost} the
     *        SNI host name from the client hello is passed to it
     */
    public SslTCPChannel(BufferPool pool, TCPChannel realChannel2, ConnectionListener connectionListener, SSLEngineFactory sslFactory) {
        super(realChannel2);
        this.pool = pool;
        parser = new ClientHelloParser(pool);
        this.realChannel = realChannel2;
        this.connectionListener = connectionListener;
        this.sslFactory = sslFactory;
    }

    /**
     * Connects the underlying channel and then begins the SSL handshake.
     *
     * @param addr address to connect to
     * @param listener receives the decrypted data
     * @return a future completed with this channel when the encrypted link is established
     */
    @Override
    public CompletableFuture<Channel> connect(SocketAddress addr, DataListener listener) {
        clientDataListener = new SslTryCatchListener(listener);
        CompletableFuture<Channel> future = realChannel.connect(addr, socketDataListener);
        return future.thenCompose(c -> beginHandshake());
    }

    /** @return the listener that must be fed the encrypted bytes read from the underlying channel */
    public SocketDataListener getSocketDataListener() {
        return socketDataListener;
    }

    /** Starts the handshake; the returned future is completed from encryptedLinkEstablished(). */
    private CompletableFuture<Channel> beginHandshake() {
        sslConnectFuture = new CompletableFuture<Channel>();
        sslEngine.beginHandshake();
        return sslConnectFuture;
    }

    /**
     * Feeds plain data to the SSL engine for encryption.
     *
     * @param b buffer that must contain at least one byte
     * @return a future that yields this channel once the engine's future for the packet completes
     * @throws IllegalArgumentException if the buffer has nothing remaining
     */
    @Override
    public CompletableFuture<Channel> write(ByteBuffer b) {
        if(b.remaining() == 0)
            throw new IllegalArgumentException("You must pass in bytebuffers that contain data.  b.remaining==0 in this buffer");
        return sslEngine.feedPlainPacket(b).thenApply(v -> this);
    }

    /**
     * Closes the SSL engine and, once the engine reports the client-initiated close,
     * closes the underlying channel. If no engine exists yet the underlying channel is closed directly.
     */
    @Override
    public CompletableFuture<Channel> close() {
        closeFuture = new CompletableFuture<>();
        if(sslEngine == null) {
            //this happens in the case where encryption link was not yet established(or even started for that matter)
            //ie. HttpFrontend does a timeout on incoming client connections to the server so if someone connects to ssl, it
            //times out and closes it
            return realChannel.close();
        }

        sslEngine.close();
        return closeFuture.thenApply(channel -> actuallyCloseSocket(channel, realChannel));
    }

    /** Closes the underlying channel (its future is not awaited) and returns the ssl channel. */
    private Channel actuallyCloseSocket(Channel sslChannel, Channel realChannel) {
        realChannel.close();
        return sslChannel;
    }

    /** Receives the callbacks of the SSL engine and routes them to the socket or to the client. */
    private class OurSslListener implements SslListener {
        @Override
        public void encryptedLinkEstablished() {
            if(sslConnectFuture != null) {
                //handshake was started through connect(), so complete the caller's future
                sslConnectFuture.complete(SslTCPChannel.this);
                return;
            }

            CompletableFuture<DataListener> future = connectionListener.connected(SslTCPChannel.this, true);
            if(!future.isDone()) {
                connectionListener.failed(SslTCPChannel.this, new IllegalArgumentException("Client did not return a datalistener"));
                //do not fall through to future.get(): it would block this thread until the
                //listener shows up (possibly forever) right after we reported the failure.
                //NOTE: the channel is left without a data listener in this case.
                return;
            }

            try {
                clientDataListener = new SslTryCatchListener(future.get());
            } catch (Exception e) {
                if(e instanceof InterruptedException)
                    Thread.currentThread().interrupt(); //preserve the interrupt status for callers
                throw new RuntimeException(e);
            }
        }

        @Override
        public CompletableFuture<Void> packetEncrypted(ByteBuffer engineToSocketData) {
            return realChannel.write(engineToSocketData).thenApply(c -> empty());
        }

        public Void empty() {
            return null;
        }

        @Override
        public void sendEncryptedHandshakeData(ByteBuffer engineToSocketData) {
            try {
                realChannel.write(engineToSocketData);
            } catch(NioClosedChannelException e) {
                log.info("Remote end closed before handshake was finished.  (nothing we can do about that)");
            }
            //we don't care about future as we won't write anything out anyways until we get
            //data back and we have not fired connected to client so he should also not be writing yet too
        }

        @Override
        public void packetUnencrypted(ByteBuffer out) {
            clientDataListener.incomingData(SslTCPChannel.this, out);
        }

        @Override
        public void runTask(Runnable r) {
            //we are multithreaded underneath anyways using SessionExecutor so
            //we might as well run this on same thread.
            r.run();
        }

        @Override
        public void closed(boolean clientInitiated) {
            if(!clientInitiated) {
                //the listener does not exist yet if the far end closes before the link is established
                if(clientDataListener != null)
                    clientDataListener.farEndClosed(SslTCPChannel.this);
            } else if(closeFuture == null)
                throw new RuntimeException("bug, this should not be possible");
            else
                closeFuture.complete(SslTCPChannel.this);
        }
    }

    /**
     * Listens on the underlying channel: feeds encrypted bytes to the SSL engine (creating the
     * engine first when needed) and forwards the remaining events to the client's listener,
     * substituting this ssl channel for the underlying one.
     */
    private class SocketDataListener implements DataListener {
        @Override
        public void incomingData(Channel channel, ByteBuffer b) {
            if(sslEngine == null) {
                b = setupSSLEngine(channel, b);
                if(b == null)
                    return; //not fully setup yet
            }

            sslEngine.feedEncryptedPacket(b);
        }

        /** Same as setupSSLEngineImpl but converts the checked SSLException into a RuntimeException. */
        private ByteBuffer setupSSLEngine(Channel channel, ByteBuffer b) {
            try {
                return setupSSLEngineImpl(channel, b);
            } catch (SSLException e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Creates the SSL engine, using the SNI host name of the client hello when the factory
         * supports it.
         *
         * @return the buffer to feed to the new engine, or null when no engine was created yet
         */
        private ByteBuffer setupSSLEngineImpl(Channel channel, ByteBuffer b) throws SSLException {
            if(!(sslFactory instanceof SSLEngineFactoryWithHost)) {
                sslEngine = wrapEngine(sslFactory.createSslEngine());
                return b;
            }

            SSLEngineFactoryWithHost sslFactoryWithHost = (SSLEngineFactoryWithHost) sslFactory;
            ParseResult result = parser.fetchServerNamesIfEntirePacketAvailable(b);

            // the full accumulated packet(which may just be the buffer passed in above)
            ByteBuffer fullPacket = result.getBuffer();
            if(fullPacket == null) {
                //NOTE(review): presumably the parser hands back no buffer while the client hello is
                //still incomplete (incomingData treats null as "not fully setup yet") - confirm.
                //Do not create an engine yet so the next data goes through the parser again.
                return null;
            }

            List<String> sniServerNames = result.getNames();
            SSLEngine engine;
            if(sniServerNames == null || sniServerNames.isEmpty()) {
                log.error("SNI servernames missing from client.  channel="+channel.getRemoteAddress());
                //there is no host to look up, so use the factory's host-less engine instead of
                //failing with an IndexOutOfBoundsException on get(0)
                engine = sslFactory.createSslEngine();
            } else {
                if(sniServerNames.size() > 1)
                    log.error("SNI servernames are too many. names="+sniServerNames+" channel="+channel.getRemoteAddress());
                engine = sslFactoryWithHost.createSslEngine(sniServerNames.get(0));
            }

            sslEngine = wrapEngine(engine);
            return fullPacket;
        }

        /** Wraps a JDK SSLEngine into the async engine that reports to this channel's ssl listener. */
        private AsyncSSLEngine wrapEngine(SSLEngine engine) throws SSLException {
            return AsyncSSLFactory.create(realChannel+"", engine, pool, sslListener);
        }

        @Override
        public void farEndClosed(Channel channel) {
            if(clientDataListener != null)
                clientDataListener.farEndClosed(SslTCPChannel.this);
        }

        @Override
        public void failure(Channel channel, ByteBuffer data, Exception e) {
            if(clientDataListener == null) {
                //no client listener exists before the encrypted link is established
                log.error("Failure on channel="+channel+" before a data listener was available.  exception="+e);
                return;
            }
            clientDataListener.failure(SslTCPChannel.this, data, e);
        }

        @Override
        public void applyBackPressure(Channel channel) {
            if(clientDataListener != null)
                clientDataListener.applyBackPressure(SslTCPChannel.this);
        }

        @Override
        public void releaseBackPressure(Channel channel) {
            if(clientDataListener != null)
                clientDataListener.releaseBackPressure(SslTCPChannel.this);
        }
    }

    /**
     * Creates a server-mode SSLEngine from the "selfsigned.jks" keystore found on the classpath.
     *
     * @return a new engine with client mode turned off
     * @throws RuntimeException wrapping any failure, including a missing keystore resource
     */
    public SSLEngine createSslEngine() {
        char[] passphrase = "password".toCharArray();
        // try-with-resources so the keystore stream is always closed
        try (InputStream in = getClass().getClassLoader().getResourceAsStream("selfsigned.jks")) {
            if(in == null) {
                //KeyStore.load(null, ..) would silently create an EMPTY keystore
                throw new IllegalStateException("selfsigned.jks was not found on the classpath");
            }

            // First initialize the key and trust material.
            KeyStore ks = KeyStore.getInstance("JKS");
            ks.load(in, passphrase);

            // Create/initialize the SSLContext with key material
            SSLContext sslContext = SSLContext.getInstance("TLS");

            //****************Server side specific*********************
            // KeyManager's decide which key material to use.
            KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
            kmf.init(ks, passphrase);
            sslContext.init(kmf.getKeyManagers(), null, null);
            //****************Server side specific*********************

            SSLEngine engine = sslContext.createSSLEngine();
            engine.setUseClientMode(false);
            return engine;
        } catch(Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Delegates to the underlying channel. */
    @Override
    public boolean getKeepAlive() {
        return realChannel.getKeepAlive();
    }

    /** Delegates to the underlying channel. */
    @Override
    public void setKeepAlive(boolean b) {
        realChannel.setKeepAlive(b);
    }

    /** @return the same listener as {@link #getSocketDataListener()} */
    public DataListener getDataListener() {
        return socketDataListener;
    }

    @Override
    public boolean isSslChannel() {
        return true;
    }
}