/* * JBoss, Home of Professional Open Source. * * Copyright 2012 Red Hat, Inc. and/or its affiliates, and individual * contributors as indicated by the @author tags. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.xnio.ssl.mock; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.Map; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLEngineResult.HandshakeStatus; import javax.net.ssl.SSLEngineResult.Status; import javax.net.ssl.SSLException; import javax.net.ssl.SSLSession; import org.jboss.logging.Logger; import org.jmock.Expectations; import org.jmock.Mockery; import org.xnio.Buffers; /** * Mocks an SSLEngine for test purposes.<p> * The handshaking behavior of the mock is defined by a sequence of {@link HandshakeAction}s. On every request to wrap * or unwrap, this mock will take one of the predetermined actions, following the order those actions are {@link * #setHandshakeActions(HandshakeAction...) defined}. If an action cannot be performed at current wrap/unwrap request * (for example, take the case of a {@link HandshakeAction#NEED_WRAP NEED_WRAP} on an unwrap request), the resulting * {@code SSLEngineResult} will point that out and the action will be postponed to the next wrap/unwrap invocation, * until it can be executed. Only then, this mock will move on to the next handshake action. 
* <p>
 * Once this mock reaches the end of the handshake action list, it will no longer perform any handshaking, and will
 * always fulfill any wrap/unwrap request without failures. If this mock ever unwraps a {@link #CLOSE_MSG}, it will
 * close itself and indicate in future {@code SSLEngineResult}s that it needs to wrap a {@link #CLOSE_MSG}, until it is
 * requested to wrap.
 * <p>
 * Invoking {@link #closeInbound} or {@link #closeOutbound} will also have the effect of closing this engine for both
 * wrap and unwrap. Notice that after {@code closeOutbound} is invoked, this engine will request to unwrap until it
 * unwraps a {@link #CLOSE_MSG}, unless the engine has already done so.
 * <p>
 * Once this engine is closed, no further handshake action is taken, regardless of whether it has executed all
 * handshake actions or not. The only exception to this rule is that the engine can still request to wrap, or to
 * unwrap, a {@link #CLOSE_MSG} when that happens.
 * <p>
 * To mimic wrap and unwrap, this mock uses a wrap/unwrap register. This register works like a map, containing the
 * wrapped equivalent of an unwrapped message and vice-versa. To use this register, you just have to call {@link
 * #addWrapEntry(String, String)} for every unwrapped/wrapped pair of messages you want to add. That way, when requested
 * to wrap a message this mock will search for the unwrapped message in its register. If the register contains the
 * message to be wrapped, this mock will use the wrapped version of it to write that message; if not, the engine will
 * simply copy the message. Likewise, if a message needs to be unwrapped, this mock will check if this message has a
 * corresponding unwrapped version in the register. If it does, it will use the unwrapped message. If not, it will
 * simply copy the message when unwrapping. Each wrap entry can be defined by calling
 * {@link #addWrapEntry(String, String)}.
*
 * @author <a href="mailto:flavia.rainone@jboss.com">Flavia Rainone</a>
 */
public class SSLEngineMock extends SSLEngine {

    private static final Logger log = Logger.getLogger("TEST");

    // every wrapped handshake message generated by this engine is the result of wrapping HANDSHAKE_MSG
    public static final String HANDSHAKE_MSG = "[handshake data]";
    // every wrapped close message generated by this engine is the result of wrapping CLOSE_MSG
    public static final String CLOSE_MSG = "[close]";

    // mockery, to generate task mocks
    protected final Mockery mockery;
    // marks the index of current action
    private int actionIndex;
    // the sequence of handshake actions
    private HandshakeAction[] actions;
    // mapped wrapper, the unwrap/wrap register, performs wrap and unwrap
    private final MappedWrapper wrapper;
    // indicates if this engine is closed
    private boolean closed = false;
    // indicates if this engine has sent a wrapped CLOSE_MSG
    private boolean closeMessageWrapped = false;
    // indicates if this engine has unwrapped a CLOSE_MSG
    private boolean closeMessageUnwrapped = false;
    // supported cipher suites
    private String[] supportedCipherSuites = new String[0];
    // supported protocols
    private String[] supportedProtocols = new String[0];
    // enabled cipher suites
    private String[] enabledCipherSuites = null;
    // enabled protocols
    private String[] enabledProtocols = null;
    // indicates whether session creation is enabled
    private boolean enableSessionCreation = false;
    // client mode enabled
    private boolean useClientMode = false;
    // need client auth
    private boolean needClientAuth = false;
    // want client auth
    private boolean wantClientAuth = false;
    // count used to differentiate delegated tasks (a requirement of JMock)
    private int taskCount = 0;
    // count used to give every session mock a distinct name
    private int sessionCount = 0;

    /**
     * Determines the handshake action this mock should take when requested to wrap/unwrap a message.
     */
    public enum HandshakeAction {
        /**
         * Will not proceed until client requests this engine to wrap data.
         * The generated wrapped data will be {@link SSLEngineMock#HANDSHAKE_MSG}.
         */
        NEED_WRAP,
        /**
         * Will not proceed until client requests this engine to unwrap non-empty data.
         * This engine expects to read {@link SSLEngineMock#HANDSHAKE_MSG}.
         */
        NEED_UNWRAP,
        /**
         * Will not perform any related handshake action when requested to wrap/unwrap data.
         */
        PERFORM_REQUESTED_ACTION,
        /**
         * A task is needed to be executed so handshake can proceed.
         */
        NEED_TASK,
        /**
         * A task is needed to be executed so handshake can proceed. This mock will provide a faulty task,
         * one that will throw an exception when requested to execute.
         * TODO this action is not supported right now
         */
        NEED_FAULTY_TASK,
        /**
         * Engine will finish the handshake.
         */
        FINISH
    }

    /**
     * Constructor. The new engine starts with an empty handshake action list.
     *
     * @param mockery mockery to create mocks for internal use
     */
    public SSLEngineMock(Mockery mockery) {
        this.mockery = mockery;
        this.wrapper = new MappedWrapper();
        this.actions = new HandshakeAction[0];
        this.actionIndex = 0;
    }

    /**
     * Sets the handshake actions this engine will take when requested to wrap or unwrap a message.
     * <p>
     * Once all requested actions have been successfully performed, this mock will simply perform any requested action
     * without any handshaking, which is equivalent to executing a
     * {@link HandshakeAction#PERFORM_REQUESTED_ACTION PERFORM_REQUESTED_ACTION}.
     * <p>
     * The index of the current action is not reset by this method.
     *
     * @param actions the actions that define the handshake behavior of this mock when it receives a request to unwrap
     *                or wrap new data.
     */
    public void setHandshakeActions(HandshakeAction... actions) {
        this.actions = actions;
    }

    /**
     * A wrap entry is a pair of texts representing an unwrapped data and its wrapped equivalent.
     * Once this wrap entry is added to this mock, this mock will start writing {@code wrappedData} whenever it is
     * requested to wrap {@code unwrappedData}, and it will write {@code unwrappedData} whenever requested to unwrap
     * {@code wrappedData}.
     *
     * @param unwrappedData unwrapped data
     * @param wrappedData   the wrapped equivalent of {@code unwrappedData}
     */
    public void addWrapEntry(String unwrappedData, String wrappedData) {
        wrapper.put(unwrappedData, wrappedData);
    }

    /**
     * Does nothing.
     */
    @Override
    public void beginHandshake() throws SSLException {
        // do nothing
    }

    /**
     * Marks this engine as closed.
     */
    @Override
    public synchronized void closeInbound() throws SSLException {
        closed = true;
    }

    /**
     * Marks this engine as closed.
     */
    @Override
    public synchronized void closeOutbound() {
        closed = true;
    }

    /**
     * Returns:<ul>
     * <li>{@code null} if this engine is marked as closed, or if current action is not one of
     * {@link HandshakeAction#NEED_TASK NEED_TASK}, {@link HandshakeAction#NEED_FAULTY_TASK NEED_FAULTY_TASK}.
     * <li> a task mock if current action is {@link HandshakeAction#NEED_TASK NEED_TASK}. The mock expects to have its
     * method {@code run()} invoked exactly once by client. Returning this task moves the engine to the next action.
     * <li> a task mock whose {@code run()} method will throw an exception, if current action is {@link
     * HandshakeAction#NEED_FAULTY_TASK NEED_FAULTY_TASK}. The mock expects to have its method {@code run()} invoked
     * exactly once.
     * </ul>
     */
    @SuppressWarnings("incomplete-switch")
    @Override
    public synchronized Runnable getDelegatedTask() {
        if (!closed && actionIndex < actions.length) {
            switch (actions[actionIndex]) {
                case NEED_TASK: {
                    actionIndex++;
                    synchronized (mockery) {
                        final Runnable task = mockery.mock(Runnable.class, "RunnableMock" + taskCount++);
                        mockery.checking(new Expectations() {{
                            oneOf(task).run();
                        }});
                        return task;
                    }
                }
                case NEED_FAULTY_TASK:
                    // NOTE(review): unlike NEED_TASK, this branch does not advance actionIndex; the action is
                    // documented as unsupported, so this is left as is.
                    synchronized (mockery) {
                        final Runnable task = mockery.mock(Runnable.class, "RunnableFaultyMock" + taskCount++);
                        mockery.checking(new Expectations() {{
                            oneOf(task).run();
                            will(throwException(new RuntimeException()));
                        }});
                        return task;
                    }
            }
        }
        return null;
    }

    /**
     * Creates a new session mock on every invocation. The first mock created expects exactly one call to
     * {@code getPacketBufferSize()} and one to {@code getApplicationBufferSize()}; mocks created afterwards allow
     * any number of calls to those methods. The reported sizes are 16916 and 16921, respectively.
     *
     * @return a newly created {@code SSLSession} mock
     */
    @Override
    public SSLSession getSession() {
        synchronized (mockery) {
            final SSLSession sessionMock = mockery.mock(SSLSession.class, "Session" + sessionCount++);
            if (sessionCount == 1) {
                mockery.checking(new Expectations() {{
                    oneOf(sessionMock).getPacketBufferSize();
                    will(returnValue(16916));
                    oneOf(sessionMock).getApplicationBufferSize();
                    will(returnValue(16921));
                }});
            } else {
                mockery.checking(new Expectations() {{
                    allowing(sessionMock).getPacketBufferSize();
                    will(returnValue(16916));
                    allowing(sessionMock).getApplicationBufferSize();
                    will(returnValue(16921));
                }});
            }
            return sessionMock;
        }
    }

    /**
     * Moves to the next handshake action, but only if the current action is still the one at {@code index} and is
     * equal to {@code action}. This avoids duplicate {@code actionIndex} increments for the same action.
     *
     * @param action the action that has just been performed
     * @param index  the value of {@code actionIndex} observed when the action started
     */
    private void actionAccountedFor(HandshakeAction action, int index) {
        synchronized (this) {
            if (actionIndex >= actions.length || closed) {
                return;
            }
            if (actionIndex == index && actions[actionIndex] == action) {
                actionIndex++;
            }
        }
    }

    /**
     * Returns the handshake status that corresponds to the current state of this mock.
     * <p>
     * A closed engine reports {@code NEED_WRAP} until it has wrapped a {@link #CLOSE_MSG}, then {@code NEED_UNWRAP}
     * until it has unwrapped one, and {@code NOT_HANDSHAKING} afterwards. An engine that is not closed reports the
     * status mapped from the current handshake action, or {@code NOT_HANDSHAKING} if there are no actions left.
     *
     * @return the current handshake status
     * @throws IllegalStateException if the current action is not a known {@link HandshakeAction}
     */
    @Override
    public synchronized HandshakeStatus getHandshakeStatus() {
        if (closed) {
            if (!closeMessageWrapped) {
                return HandshakeStatus.NEED_WRAP;
            }
            if (!closeMessageUnwrapped) {
                return HandshakeStatus.NEED_UNWRAP;
            }
            return HandshakeStatus.NOT_HANDSHAKING;
        }
        if (actionIndex >= actions.length) {
            return HandshakeStatus.NOT_HANDSHAKING;
        }
        final HandshakeAction currentAction = actions[actionIndex];
        log.debug("Current engine mock action: " + currentAction);
        switch (currentAction) {
            case NEED_TASK:
            case NEED_FAULTY_TASK:
                return HandshakeStatus.NEED_TASK;
            case NEED_WRAP:
                return HandshakeStatus.NEED_WRAP;
            case NEED_UNWRAP:
                return HandshakeStatus.NEED_UNWRAP;
            case PERFORM_REQUESTED_ACTION:
                return HandshakeStatus.NOT_HANDSHAKING;
            case FINISH:
                return HandshakeStatus.FINISHED;
            default:
                throw new IllegalStateException("Unexpected handshake action: " + currentAction);
        }
    }

    // logs the result of a wrap/unwrap operation at debug level and returns it unchanged
    private SSLEngineResult logOperation(String operation, SSLEngineResult result) {
        log.debugf("SSLEngineMock.%s returned [%s, %s, consumed:%d, produced:%d]", operation, result.getStatus(),
                result.getHandshakeStatus(), result.bytesConsumed(), result.bytesProduced());
        return result;
    }

    private SSLEngineResult logUnwrap(SSLEngineResult result) {
        return logOperation("unwrap", result);
    }

    private SSLEngineResult logWrap(SSLEngineResult result) {
        return logOperation("wrap", result);
    }

    /**
     * Tells whether at least one of the buffers in the range {@code [offset, offset + length)} has remaining
     * bytes/space.
     */
    private static boolean anyHasRemaining(ByteBuffer[] buffers, int offset, int length) {
        final int end = offset + length;
        for (int i = offset; i < end; i++) {
            if (buffers[i].hasRemaining()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Unwraps {@code src} into {@code dsts}, taking into account the current handshake status.
     * <p>
     * If {@code length} is positive and none of the destination buffers in {@code [offset, offset + length)} has
     * space left, the result is {@code BUFFER_OVERFLOW}. If this engine is closed and has already unwrapped a
     * {@link #CLOSE_MSG}, the result is {@code CLOSED}. Otherwise, data is only read from {@code src} when the
     * handshake status is {@code NEED_UNWRAP} or {@code NOT_HANDSHAKING}.
     *
     * @param src    the buffer containing wrapped data
     * @param dsts   the buffers that will receive the unwrapped data
     * @param offset index of the first buffer of {@code dsts} that can be used
     * @param length the number of buffers of {@code dsts} that can be used, starting at {@code offset}
     * @return the result of the operation
     * @throws IllegalStateException if the handshake status is not one of the statuses handled by this mock
     */
    @Override
    public synchronized SSLEngineResult unwrap(ByteBuffer src, ByteBuffer[] dsts, int offset, int length)
            throws SSLException {
        // the accessible range is [offset, offset + length), as defined by SSLEngine
        if (length > 0 && !anyHasRemaining(dsts, offset, length)) {
            return logUnwrap(new SSLEngineResult(Status.BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0));
        }
        if (closed && closeMessageUnwrapped) {
            return logUnwrap(new SSLEngineResult(Status.CLOSED, HandshakeStatus.NOT_HANDSHAKING, 0, 0));
        }
        final HandshakeStatus status = getHandshakeStatus();
        final int currentActionIndex = actionIndex;
        final SSLEngineResult result;
        switch (status) {
            case FINISHED:
                actionAccountedFor(HandshakeAction.FINISH, currentActionIndex);
                // NOTE(review): 30 bytes are reported as consumed although src is not read here, while wrap
                // reports 0 for the same status; kept as is -- confirm whether this is intentional
                result = new SSLEngineResult(Status.OK, HandshakeStatus.FINISHED, 30, 0);
                break;
            case NEED_TASK:
                result = new SSLEngineResult(Status.OK, HandshakeStatus.NEED_TASK, 0, 0);
                break;
            case NEED_UNWRAP:
                result = wrapper.unwrap(dsts, offset, length, src, true, currentActionIndex);
                break;
            case NEED_WRAP:
                result = new SSLEngineResult(closed ? Status.CLOSED : Status.OK, HandshakeStatus.NEED_WRAP, 0, 0);
                break;
            case NOT_HANDSHAKING:
                actionAccountedFor(HandshakeAction.PERFORM_REQUESTED_ACTION, currentActionIndex);
                result = wrapper.unwrap(dsts, offset, length, src, false, currentActionIndex);
                break;
            default:
                throw new IllegalStateException("Unexpected handshake status: " + status);
        }
        return logUnwrap(result);
    }

    /**
     * Wraps {@code srcs} into {@code dst}, taking into account the current handshake status.
     * <p>
     * If {@code dst} already contains data (its position is greater than zero), the result is
     * {@code BUFFER_OVERFLOW}. Otherwise, data is only written to {@code dst} when the handshake status is
     * {@code NEED_WRAP} (a handshake or close message is written) or {@code NOT_HANDSHAKING} with the engine not
     * closed (the contents of {@code srcs} are wrapped).
     *
     * @param srcs   the buffers containing the data to be wrapped
     * @param offset index of the first buffer of {@code srcs} that can be used
     * @param length the number of buffers of {@code srcs} that can be used, starting at {@code offset}
     * @param dst    the buffer that will receive the wrapped data
     * @return the result of the operation
     * @throws IllegalStateException if the handshake status is not one of the statuses handled by this mock
     */
    @Override
    public synchronized SSLEngineResult wrap(ByteBuffer[] srcs, int offset, int length, ByteBuffer dst)
            throws SSLException {
        if (dst.position() > 0) {
            return logWrap(new SSLEngineResult(Status.BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0));
        }
        final HandshakeStatus status = getHandshakeStatus();
        final int currentActionIndex = actionIndex;
        final SSLEngineResult result;
        switch (status) {
            case FINISHED:
                actionAccountedFor(HandshakeAction.FINISH, currentActionIndex);
                result = new SSLEngineResult(Status.OK, HandshakeStatus.FINISHED, 0, 0);
                break;
            case NEED_TASK:
                result = new SSLEngineResult(Status.OK, HandshakeStatus.NEED_TASK, 0, 0);
                break;
            case NEED_UNWRAP:
                result = new SSLEngineResult(closed ? Status.CLOSED : Status.OK, HandshakeStatus.NEED_UNWRAP, 0, 0);
                break;
            case NEED_WRAP:
                result = wrapper.wrap(dst, srcs, offset, length, true, currentActionIndex);
                break;
            case NOT_HANDSHAKING:
                if (closed) {
                    // logged like every other result
                    return logWrap(new SSLEngineResult(Status.CLOSED, HandshakeStatus.NOT_HANDSHAKING, 0, 0));
                }
                actionAccountedFor(HandshakeAction.PERFORM_REQUESTED_ACTION, currentActionIndex);
                result = wrapper.wrap(dst, srcs, offset, length, false, currentActionIndex);
                break;
            default:
                throw new IllegalStateException("Unexpected handshake status: " + status);
        }
        return logWrap(result);
    }

    /**
     * The wrap/unwrap register. Keeps the mapping between unwrapped and wrapped messages and performs the actual
     * reading and writing of buffers for the enclosing engine.
     */
    private final class MappedWrapper {

        // unwrapped message -> wrapped message
        private final Map<String, String> wrapMap = new HashMap<String, String>();
        // wrapped message -> unwrapped message
        private final Map<String, String> unwrapMap = new HashMap<String, String>();

        /**
         * Registers a pair of unwrapped/wrapped messages.
         *
         * @param unwrapped the unwrapped message
         * @param wrapped   the wrapped equivalent of {@code unwrapped}
         */
        public void put(String unwrapped, String wrapped) {
            wrapMap.put(unwrapped, wrapped);
            // avoid use of wrapped text as if it were unwrapped on write
            if (!wrapMap.containsKey(wrapped)) {
                wrapMap.put(wrapped, "");
            }
            unwrapMap.put(wrapped, unwrapped);
            // avoid use of unwrapped text as if it were wrapped on read
            if (!unwrapMap.containsKey(unwrapped)) {
                unwrapMap.put(unwrapped, "");
            }
        }

        /**
         * Unwraps as many messages as possible from {@code src} into {@code dsts}.
         *
         * @param dsts        the buffers that will receive the unwrapped data
         * @param offset      index of the first buffer of {@code dsts} that can be used
         * @param length      the number of buffers of {@code dsts} that can be used
         * @param src         the buffer containing wrapped data
         * @param needUnwrap  {@code true} if the engine is expecting to unwrap a handshake message
         * @param actionIndex the index of the handshake action that was current when the operation started
         * @return the result of the operation
         */
        public SSLEngineResult unwrap(ByteBuffer[] dsts, int offset, int length, ByteBuffer src, boolean needUnwrap,
                int actionIndex) {
            if (!src.hasRemaining()) {
                return new SSLEngineResult(closed && closeMessageUnwrapped ? Status.CLOSED : Status.BUFFER_UNDERFLOW,
                        needUnwrap ? HandshakeStatus.NEED_UNWRAP : HandshakeStatus.NOT_HANDSHAKING, 0, 0);
            }
            Status okStatus = closed && closeMessageUnwrapped ? Status.CLOSED : Status.OK;
            // amount of bytes available at src
            int initialSrcRemaining = src.remaining();
            int bytesProduced = 0;
            while (src.hasRemaining()) {
                String unwrapped = unwrapBytes(dsts, offset, length, src, needUnwrap);
                int bytesConsumed = initialSrcRemaining - src.remaining();
                // null means that the message read from src could not be unwrapped now
                if (unwrapped == null) {
                    if (bytesProduced == 0) {
                        return new SSLEngineResult(Status.BUFFER_OVERFLOW, HandshakeStatus.NEED_UNWRAP, 0, 0);
                    }
                    return new SSLEngineResult(okStatus, HandshakeStatus.NEED_UNWRAP, bytesConsumed, bytesProduced);
                }
                // it means handshake message was found when it is not expected, ignore it and return everything
                // that was read up until now
                if (unwrapped.length() == 0) {
                    if (bytesConsumed > 0) {
                        actionAccountedFor(HandshakeAction.PERFORM_REQUESTED_ACTION, actionIndex);
                    }
                    return new SSLEngineResult(okStatus, getHandshakeStatus(), bytesConsumed, bytesProduced);
                }
                // it means we are on a needUnwrap and we found the handshake message we are seeking
                if (unwrapped.equals(HANDSHAKE_MSG)) {
                    // move to next action, NEED_UNWRAP action is finally accomplished
                    actionAccountedFor(HandshakeAction.NEED_UNWRAP, actionIndex);
                    return new SSLEngineResult(okStatus, getHandshakeStatus(), bytesConsumed, bytesProduced);
                }
                if (unwrapped.equals(CLOSE_MSG)) {
                    closed = true;
                    closeMessageUnwrapped = true;
                    return new SSLEngineResult(Status.CLOSED, getHandshakeStatus(), bytesConsumed, bytesProduced);
                }
                bytesProduced += unwrapped.length();
            }
            int bytesConsumed = initialSrcRemaining - src.remaining();
            if (bytesConsumed > 0 && !needUnwrap) {
                actionAccountedFor(HandshakeAction.PERFORM_REQUESTED_ACTION, actionIndex);
            }
            return new SSLEngineResult(okStatus, getHandshakeStatus(), bytesConsumed, bytesProduced);
        }

        /**
         * Reads a single message from {@code src} and, if it is neither {@link SSLEngineMock#HANDSHAKE_MSG} nor
         * {@link SSLEngineMock#CLOSE_MSG}, copies its unwrapped form to {@code dsts}.
         *
         * @return the unwrapped message; an empty string if a handshake message was found while
         *         {@code readHandshakeMsg} is {@code false}; {@code null} if the message cannot be unwrapped now
         *         (it is a prefix of a handshake/close message, or there is not enough space in {@code dsts})
         */
        private String unwrapBytes(ByteBuffer[] dsts, int offset, int length, ByteBuffer src,
                boolean readHandshakeMsg) {
            // define unwrapped data to be written to dsts
            String wrapped = new String(Buffers.take(src), StandardCharsets.ISO_8859_1);
            int wrappedEndIndex = wrapped.length();
            int wrappedLeftOver = -1;
            // look for the longest prefix of the data read that is a registered wrapped message
            while (wrappedEndIndex > 0 && !unwrapMap.containsKey(wrapped.substring(0, wrappedEndIndex))) {
                wrappedLeftOver = wrappedEndIndex--;
            }
            // undo the reading of data that won't be used now
            if (wrappedLeftOver != -1 && wrappedEndIndex > 0) {
                src.position(src.position() - (wrapped.length() - wrappedEndIndex));
                wrapped = wrapped.substring(0, wrappedEndIndex);
            } else {
                // no registered prefix with leftover: split the data at the first handshake/close message, if any
                int msgIndex;
                if ((msgIndex = wrapped.indexOf(HANDSHAKE_MSG)) != -1) {
                    if (msgIndex == 0) {
                        src.position(src.position() - (wrapped.length() - HANDSHAKE_MSG.length()));
                        wrapped = wrapped.substring(0, HANDSHAKE_MSG.length());
                    } else {
                        src.position(src.position() - (wrapped.length() - msgIndex));
                        wrapped = wrapped.substring(0, msgIndex);
                    }
                }
                if ((msgIndex = wrapped.indexOf(CLOSE_MSG)) != -1) {
                    if (msgIndex == 0) {
                        src.position(src.position() - (wrapped.length() - CLOSE_MSG.length()));
                        wrapped = wrapped.substring(0, CLOSE_MSG.length());
                    } else {
                        src.position(src.position() - (wrapped.length() - msgIndex));
                        wrapped = wrapped.substring(0, msgIndex);
                    }
                }
            }
            String unwrapped = unwrapMap.containsKey(wrapped) ? unwrapMap.get(wrapped) : wrapped;
            if (unwrapped.equals(HANDSHAKE_MSG) && !readHandshakeMsg) {
                // handshake message is not expected now, leave it at src
                src.position(src.position() - wrapped.length());
                return "";
            }
            if (!unwrapped.equals(CLOSE_MSG) && !unwrapped.equals(HANDSHAKE_MSG)) {
                if (CLOSE_MSG.startsWith(unwrapped) || HANDSHAKE_MSG.startsWith(unwrapped)) {
                    // NOTE(review): this rewinds src to position 0 instead of the position it had when this
                    // method was invoked, unlike the other branches; presumably callers always pass a buffer
                    // that starts at 0 -- confirm before changing
                    src.position(0);
                    return null;
                }
                // check if there is enough space to write unwrapped data, if not, do not write
                if (Buffers.remaining(dsts, offset, length) < unwrapped.length()) {
                    src.position(src.position() - wrapped.length());
                    return null;
                }
                // copy data to dsts
                Buffers.copy(dsts, offset, length,
                        ByteBuffer.wrap(unwrapped.getBytes(StandardCharsets.ISO_8859_1)));
            }
            return unwrapped;
        }

        /**
         * Wraps a handshake/close message (if {@code needWrap} is {@code true}) or the contents of {@code srcs}
         * into {@code dst}.
         *
         * @param dst         the buffer that will receive the wrapped data
         * @param srcs        the buffers containing the data to be wrapped
         * @param offset      index of the first buffer of {@code srcs} that can be used
         * @param length      the number of buffers of {@code srcs} that can be used, starting at {@code offset}
         * @param needWrap    {@code true} if the engine must wrap a handshake or close message
         * @param actionIndex the index of the handshake action that was current when the operation started
         * @return the result of the operation
         */
        public SSLEngineResult wrap(ByteBuffer dst, ByteBuffer[] srcs, int offset, int length, boolean needWrap,
                int actionIndex) {
            if (needWrap) {
                // needWrap indicates that we must wrap a handshake message, or a close message if closed
                actionAccountedFor(HandshakeAction.NEED_WRAP, actionIndex);
                if (closed) {
                    synchronized (SSLEngineMock.this) {
                        closeMessageWrapped = true;
                    }
                    return wrapMessage(dst, CLOSE_MSG, Status.CLOSED);
                }
                return wrapMessage(dst, HANDSHAKE_MSG, Status.OK);
            }
            int dstInitialRemaining = dst.remaining();
            int bytesConsumed = wrapBytes(dst, srcs, offset, length);
            int bytesProduced = dstInitialRemaining - dst.remaining();
            if (bytesConsumed > 0) {
                actionAccountedFor(HandshakeAction.PERFORM_REQUESTED_ACTION, actionIndex);
            }
            if (closed) {
                return new SSLEngineResult(Status.CLOSED, HandshakeStatus.NOT_HANDSHAKING, bytesConsumed,
                        bytesProduced);
            }
            // nothing was consumed although there is data to wrap in [offset, offset + length): dst is too small
            if (bytesConsumed == 0 && anyHasRemaining(srcs, offset, length)) {
                return new SSLEngineResult(Status.BUFFER_OVERFLOW, HandshakeStatus.NOT_HANDSHAKING, bytesConsumed,
                        bytesProduced);
            }
            return new SSLEngineResult(Status.OK, HandshakeStatus.NOT_HANDSHAKING, bytesConsumed, bytesProduced);
        }

        /**
         * Writes the wrapped version of {@code msg} to {@code dst}, or {@code msg} itself if it is not registered.
         *
         * @return a {@code BUFFER_OVERFLOW} result if {@code dst} cannot hold the message; otherwise, a result with
         *         status {@code okayStatus}, zero bytes consumed, and the length of the written message as the
         *         number of bytes produced
         */
        private SSLEngineResult wrapMessage(ByteBuffer dst, String msg, Status okayStatus) {
            String wrappedMessage = wrapMap.containsKey(msg) ? wrapMap.get(msg) : msg;
            if (dst.remaining() < wrappedMessage.length()) {
                return new SSLEngineResult(Status.BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0);
            }
            Buffers.putModifiedUtf8(dst, wrappedMessage);
            return new SSLEngineResult(okayStatus, getHandshakeStatus(), 0, wrappedMessage.length());
        }

        /**
         * Wraps the contents of the buffers of {@code srcs} in the range {@code [offset, offset + length)} into
         * {@code dst}, one line per source buffer, stopping as soon as {@code dst} is full or a message does not fit.
         *
         * @return the sum of the lengths of the unwrapped messages that were written
         */
        public int wrapBytes(ByteBuffer dst, ByteBuffer[] srcs, int offset, int length) {
            int totalLength = 0;
            int srcsLength = offset + length;
            for (int i = offset; i < srcsLength && dst.hasRemaining(); i++) {
                StringBuilder unwrappedBuilder = new StringBuilder();
                Buffers.readModifiedUtf8Line(srcs[i], unwrappedBuilder);
                int wrappedLength = wrapBytes(dst, unwrappedBuilder.toString());
                if (wrappedLength == 0 && unwrappedBuilder.length() > 0) {
                    // message does not fit at dst, undo the reading
                    srcs[i].position(srcs[i].position() - unwrappedBuilder.length());
                    break;
                }
                totalLength += wrappedLength;
            }
            return totalLength;
        }

        /**
         * Writes the wrapped version of {@code src} to {@code dst}, or {@code src} itself if it is not registered.
         *
         * @return the length of {@code src}, or {@code 0} if the message does not fit at {@code dst}
         */
        public int wrapBytes(ByteBuffer dst, String src) {
            String wrapped = wrapMap.containsKey(src) ? wrapMap.get(src) : src;
            if (dst.remaining() < wrapped.length()) {
                return 0;
            }
            Buffers.putModifiedUtf8(dst, wrapped);
            return src.length();
        }
    }

    @Override
    public boolean getEnableSessionCreation() {
        return enableSessionCreation;
    }

    @Override
    public String[] getEnabledCipherSuites() {
        return enabledCipherSuites;
    }

    @Override
    public String[] getEnabledProtocols() {
        return enabledProtocols;
    }

    @Override
    public boolean getNeedClientAuth() {
        return needClientAuth;
    }

    @Override
    public String[] getSupportedCipherSuites() {
        return supportedCipherSuites;
    }

    /**
     * Defines the cipher suites returned by {@link #getSupportedCipherSuites()}.
     *
     * @param supportedCipherSuites the supported cipher suites
     */
    public void setSupportedCipherSuites(String... supportedCipherSuites) {
        this.supportedCipherSuites = supportedCipherSuites;
    }

    @Override
    public String[] getSupportedProtocols() {
        return supportedProtocols;
    }

    /**
     * Defines the protocols returned by {@link #getSupportedProtocols()}.
     *
     * @param supportedProtocols the supported protocols
     */
    public void setSupportedProtocols(String... supportedProtocols) {
        this.supportedProtocols = supportedProtocols;
    }

    @Override
    public boolean getUseClientMode() {
        return useClientMode;
    }

    @Override
    public boolean getWantClientAuth() {
        return wantClientAuth;
    }

    /**
     * Always returns {@code false}.
     */
    @Override
    public boolean isInboundDone() {
        return false;
    }

    /**
     * Returns {@code true} once this engine has wrapped a {@link #CLOSE_MSG}.
     */
    @Override
    public boolean isOutboundDone() {
        return closeMessageWrapped;
    }

    @Override
    public void setEnableSessionCreation(boolean flag) {
        enableSessionCreation = flag;
    }

    @Override
    public void setEnabledCipherSuites(String[] suites) {
        enabledCipherSuites = suites;
    }

    @Override
    public void setEnabledProtocols(String[] protocols) {
        enabledProtocols = protocols;
    }

    @Override
    public void setNeedClientAuth(boolean need) {
        needClientAuth = need;
    }

    @Override
    public void setUseClientMode(boolean mode) {
        useClientMode = mode;
    }

    @Override
    public void setWantClientAuth(boolean want) {
        wantClientAuth = want;
    }
}