/** * Copyright 2016 LinkedIn Corp. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ package com.github.ambry.network; import com.codahale.metrics.Counter; import com.codahale.metrics.MetricRegistry; import com.github.ambry.config.SSLConfig; import com.github.ambry.utils.Utils; import java.io.IOException; import java.net.InetSocketAddress; import java.net.Socket; import java.nio.channels.Channels; import java.util.ArrayList; import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; /** * A blocking channel that is used to communicate with a server using SSL */ public class SSLBlockingChannel extends BlockingChannel { private SSLSocket sslSocket = null; private final SSLSocketFactory sslSocketFactory; private final SSLConfig sslConfig; public final Counter sslClientHandshakeErrorCount; public final Counter sslClientHandshakeCount; public SSLBlockingChannel(String host, int port, MetricRegistry registry, int readBufferSize, int writeBufferSize, int readTimeoutMs, int connectTimeoutMs, SSLSocketFactory sslSocketFactory, SSLConfig sslConfig) { super(host, port, readBufferSize, writeBufferSize, readTimeoutMs, connectTimeoutMs); if (sslSocketFactory == null) { throw new IllegalArgumentException("sslSocketFactory is null when creating SSLBlockingChannel"); } this.sslSocketFactory = sslSocketFactory; this.sslConfig = sslConfig; sslClientHandshakeErrorCount = registry.counter(MetricRegistry.name(SSLBlockingChannel.class, "SslClientHandshakeErrorCount")); sslClientHandshakeCount = registry.counter(MetricRegistry.name(SSLBlockingChannel.class, "SslClientHandshakeCount")); } @Override public void connect() throws IOException { synchronized (lock) { if (!connected) { Socket socket = new Socket(); socket.setSoTimeout(readTimeoutMs); socket.setKeepAlive(true); socket.setTcpNoDelay(true); if (readBufferSize > 0) { socket.setReceiveBufferSize(readBufferSize); } if (writeBufferSize > 0) { socket.setSendBufferSize(writeBufferSize); } socket.connect(new InetSocketAddress(host, port), connectTimeoutMs); sslSocket = (SSLSocket) sslSocketFactory.createSocket(socket, host, port, true); ArrayList<String> protocolsList = Utils.splitString(sslConfig.sslEnabledProtocols, ","); if (protocolsList != null && protocolsList.size() > 0) { String[] enabledProtocols = protocolsList.toArray(new String[protocolsList.size()]); sslSocket.setEnabledProtocols(enabledProtocols); } ArrayList<String> cipherSuitesList = Utils.splitString(sslConfig.sslCipherSuites, ","); if (cipherSuitesList != null && cipherSuitesList.size() > 0 && !(cipherSuitesList.size() == 1 && cipherSuitesList.get(0).equals(""))) { String[] cipherSuites = cipherSuitesList.toArray(new String[cipherSuitesList.size()]); sslSocket.setEnabledCipherSuites(cipherSuites); } // handshake in a blocking way try { sslSocket.startHandshake(); sslClientHandshakeCount.inc(); } catch (IOException e) { sslClientHandshakeErrorCount.inc(); throw e; } writeChannel = Channels.newChannel(sslSocket.getOutputStream()); readChannel = sslSocket.getInputStream(); connected = true; logger.debug( "Created socket with SO_TIMEOUT = {} (requested {}), SO_RCVBUF = {} (requested {}), SO_SNDBUF = {} (requested {})", sslSocket.getSoTimeout(), readTimeoutMs, sslSocket.getReceiveBufferSize(), readBufferSize, sslSocket.getSendBufferSize(), writeBufferSize); } } } @Override public void disconnect() { synchronized (lock) { try { if (connected || sslSocket != null) { // closing the main socket channel *should* close the read channel // but let's do it to be sure. sslSocket.close(); if (readChannel != null) { readChannel.close(); readChannel = null; } if (writeChannel != null) { writeChannel.close(); writeChannel = null; } sslSocket = null; connected = false; } } catch (Exception e) { logger.error("error while disconnecting {}", e); } } } }