/* * JBoss, Home of Professional Open Source. * Copyright 2011, Red Hat, Inc., and individual contributors * as indicated by the @author tags. See the copyright.txt file in the * distribution for a full listing of individual contributors. * * This is free software; you can redistribute it and/or modify it * under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2.1 of * the License, or (at your option) any later version. * * This software is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this software; if not, write to the Free * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA * 02110-1301 USA, or see the FSF site: http://www.fsf.org. */ package org.jboss.as.protocol.mgmt; import java.io.DataInput; import java.io.IOException; import java.io.OutputStream; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.ReentrantLock; import org.jboss.as.protocol.StreamUtils; import org.jboss.as.protocol.logging.ProtocolLogger; import org.jboss.as.protocol.mgmt.support.ManagementChannelShutdownHandle; import org.jboss.remoting3.Channel; import org.jboss.remoting3.CloseHandler; import org.jboss.remoting3.MessageOutputStream; import org.jboss.threads.AsyncFuture; /** * Base class for {@link ManagementMessageHandler} implementations. * * @author Emanuel Muckenhuber */ public abstract class AbstractMessageHandler implements ManagementMessageHandler, ManagementChannelShutdownHandle, CloseHandler<Channel> { private static final ActiveOperation.CompletedCallback<?> NO_OP_CALLBACK = new ActiveOperation.CompletedCallback<Object>() { @Override public void completed(Object result) { // } @Override public void failed(Exception e) { // } @Override public void cancelled() { // } }; static <T> ActiveOperation.CompletedCallback<T> getDefaultCallback() { //noinspection unchecked return (ActiveOperation.CompletedCallback<T>) NO_OP_CALLBACK; } static <T> ActiveOperation.CompletedCallback<T> getCheckedCallback(final ActiveOperation.CompletedCallback<T> callback) { if(callback == null) { return getDefaultCallback(); } return callback; } private final ConcurrentMap<Integer, ActiveOperationImpl<?, ?>> activeRequests = new ConcurrentHashMap<> (16, 0.75f, Runtime.getRuntime().availableProcessors()); private final ManagementBatchIdManager operationIdManager = new ManagementBatchIdManager.DefaultManagementBatchIdManager(); private final ReentrantLock lock = new ReentrantLock(); private final Condition condition = lock.newCondition(); private final ExecutorService executorService; private final AtomicInteger requestID = new AtomicInteger(); private final Map<Integer, ActiveRequest<?, ?>> requests = new ConcurrentHashMap<Integer, ActiveRequest<?, ?>>(16, 0.75f, Runtime.getRuntime().availableProcessors()); // mutable variables, have to be guarded by the lock private int activeCount = 0; private volatile boolean shutdown = false; protected AbstractMessageHandler(final ExecutorService executorService) { if(executorService == null) { throw ProtocolLogger.ROOT_LOGGER.nullExecutor(); } this.executorService = executorService; } /** * Receive a notification that the channel was closed. * * This is used for the {@link ManagementClientChannelStrategy.Establishing} since it might use multiple channels. * * @param closed the closed resource * @param e the exception which occurred during close, if any */ public void handleChannelClosed(final Channel closed, final IOException e) { for(final ActiveOperationImpl<?, ?> activeOperation : activeRequests.values()) { if (activeOperation.getChannel() == closed) { // Only call cancel, to also interrupt still active threads activeOperation.getResultHandler().cancel(); } } } /** * Is shutdown. * * @return {@code true} if the shutdown method was called, {@code false} otherwise */ protected boolean isShutdown() { return shutdown; } /** * Prevent new active operations get registered. */ @Override public void shutdown() { lock.lock(); try { shutdown = true; } finally { lock.unlock(); } } /** * {@inheritDoc} */ @Override public void shutdownNow() { shutdown(); cancelAllActiveOperations(); } /** * Await the completion of all currently active operations. * * @param timeout the timeout * @param unit the time unit * @return {@code } false if the timeout was reached and there were still active operations * @throws InterruptedException */ @Override public boolean awaitCompletion(long timeout, TimeUnit unit) throws InterruptedException { long deadline = unit.toMillis(timeout) + System.currentTimeMillis(); lock.lock(); try { assert shutdown; while(activeCount != 0) { long remaining = deadline - System.currentTimeMillis(); if (remaining <= 0) { break; } condition.await(remaining, TimeUnit.MILLISECONDS); } boolean allComplete = activeCount == 0; if (!allComplete) { ProtocolLogger.ROOT_LOGGER.debugf("ActiveOperation(s) %s have not completed within %d %s", activeRequests.keySet(), timeout, unit); } return allComplete; } finally { lock.unlock(); } } /** * Get the executor * * @return the executor */ protected ExecutorService getExecutor() { return executorService; } /** * Get the request handler. * * @param header the request header * @return the request handler */ protected ManagementRequestHandler<?, ?> getRequestHandler(final ManagementRequestHeader header) { return getFallbackHandler(header); } /** * Validate whether the request can be handled. * * @param header the protocol header * @return the management request header * @throws IOException */ protected ManagementRequestHeader validateRequest(final ManagementProtocolHeader header) throws IOException { return (ManagementRequestHeader) header; } /** * Handle a message. * * @param channel the channel * @param input the message * @param header the management protocol header * @throws IOException */ @Override public void handleMessage(final Channel channel, final DataInput input, final ManagementProtocolHeader header) throws IOException { final byte type = header.getType(); if(type == ManagementProtocol.TYPE_RESPONSE) { // Handle response to local requests final ManagementResponseHeader response = (ManagementResponseHeader) header; final ActiveRequest<?, ?> request = requests.remove(response.getResponseId()); if(request == null) { ProtocolLogger.CONNECTION_LOGGER.noSuchRequest(response.getResponseId(), channel); safeWriteErrorResponse(channel, header, ProtocolLogger.ROOT_LOGGER.responseHandlerNotFound(response.getResponseId())); } else if(response.getError() != null) { request.handleFailed(response); } else { handleRequest(channel, input, header, request); } } else { // Handle requests (or other messages) try { final ManagementRequestHeader requestHeader = validateRequest(header); final ManagementRequestHandler<?, ?> handler = getRequestHandler(requestHeader); if(handler == null) { safeWriteErrorResponse(channel, header, ProtocolLogger.ROOT_LOGGER.responseHandlerNotFound(requestHeader.getBatchId())); } else { handleMessage(channel, input, requestHeader, handler); } } catch (Exception e) { safeWriteErrorResponse(channel, header, e); } } } /** * Execute a request. * * @param request the request * @param channel the channel * @param support the request support * @return the future result */ protected <T, A> AsyncFuture<T> executeRequest(final ManagementRequest<T, A> request, final Channel channel, final ActiveOperation<T, A> support) { assert support != null; updateChannelRef(support, channel); final Integer requestId = this.requestID.incrementAndGet(); final ActiveRequest<T, A> ar = new ActiveRequest<T, A>(support, request); requests.put(requestId, ar); final ManagementRequestHeader header = new ManagementRequestHeader(ManagementProtocol.VERSION, requestId, support.getOperationId(), request.getOperationType()); final ActiveOperation.ResultHandler<T> resultHandler = support.getResultHandler(); try { request.sendRequest(resultHandler, new ManagementRequestContextImpl<T, A>(support, channel, header, getExecutor())); } catch (Exception e) { resultHandler.failed(e); requests.remove(requestId); } return support.getResult(); } /** * Handle a message. * * @param channel the channel * @param message the message * @param header the protocol header * @param activeRequest the active request */ protected <T, A> void handleRequest(final Channel channel, final DataInput message, final ManagementProtocolHeader header, ActiveRequest<T, A> activeRequest) { handleMessage(channel, message, header, activeRequest.context, activeRequest.handler); } /** * Handle a message. * * @param channel the channel * @param message the message * @param header the protocol header * @param handler the request handler * @throws IOException */ protected <T, A> void handleMessage(final Channel channel, final DataInput message, final ManagementRequestHeader header, ManagementRequestHandler<T, A> handler) throws IOException { final ActiveOperation<T, A> support = getActiveOperation(header); if(support == null) { throw ProtocolLogger.ROOT_LOGGER.responseHandlerNotFound(header.getBatchId()); } handleMessage(channel, message, header, support, handler); } /** * Handle a message. * * @param channel the channel * @param message the message * @param header the protocol header * @param support the request support * @param handler the request handler */ protected <T, A> void handleMessage(final Channel channel, final DataInput message, final ManagementProtocolHeader header, final ActiveOperation<T, A> support, final ManagementRequestHandler<T, A> handler) { assert support != null; updateChannelRef(support, channel); final ActiveOperation.ResultHandler<T> resultHandler = support.getResultHandler(); try { handler.handleRequest(message, resultHandler, new ManagementRequestContextImpl<T, A>(support, channel, header, getExecutor())); } catch (Exception e) { resultHandler.failed(e); safeWriteErrorResponse(channel, header, e); } } @Override public void handleClose(final Channel closed, final IOException exception) { handleChannelClosed(closed, exception); } /** * Register an active operation. The operation-id will be generated. * * @param attachment the shared attachment * @return the active operation */ protected <T, A> ActiveOperation<T, A> registerActiveOperation(A attachment) { final ActiveOperation.CompletedCallback<T> callback = getDefaultCallback(); return registerActiveOperation(attachment, callback); } /** * Register an active operation. The operation-id will be generated. * * @param attachment the shared attachment * @param callback the completed callback * @return the active operation */ protected <T, A> ActiveOperation<T, A> registerActiveOperation(A attachment, ActiveOperation.CompletedCallback<T> callback) { return registerActiveOperation(null, attachment, callback); } /** * Register an active operation with a specific operation id. * * @param id the operation id * @param attachment the shared attachment * @return the created active operation * * @throws java.lang.IllegalStateException if an operation with the same id is already registered */ protected <T, A> ActiveOperation<T, A> registerActiveOperation(final Integer id, A attachment) { final ActiveOperation.CompletedCallback<T> callback = getDefaultCallback(); return registerActiveOperation(id, attachment, callback); } /** * Register an active operation with a specific operation id. * * @param id the operation id * @param attachment the shared attachment * @param callback the completed callback * @return the created active operation * * @throws java.lang.IllegalStateException if an operation with the same id is already registered */ protected <T, A> ActiveOperation<T, A> registerActiveOperation(final Integer id, A attachment, ActiveOperation.CompletedCallback<T> callback) { lock.lock(); try { // Check that we still allow registration // TODO WFCORE-199 distinguish client uses from server uses and limit this check to server uses // Using id==null may be one way to do this, but we need to consider ops that involve multiple requests // TODO WFCORE-845 consider using an IllegalStateException for this //assert ! shutdown; final Integer operationId; if(id == null) { // If we did not get an operationId, create a new one operationId = operationIdManager.createBatchId(); } else { // Check that the operationId is not already taken if(! operationIdManager.lockBatchId(id)) { throw ProtocolLogger.ROOT_LOGGER.operationIdAlreadyExists(id); } operationId = id; } final ActiveOperationImpl<T, A> request = new ActiveOperationImpl<T, A>(operationId, attachment, getCheckedCallback(callback), this); final ActiveOperation<?, ?> existing = activeRequests.putIfAbsent(operationId, request); if(existing != null) { throw ProtocolLogger.ROOT_LOGGER.operationIdAlreadyExists(operationId); } ProtocolLogger.ROOT_LOGGER.tracef("Registered active operation %d", operationId); activeCount++; // condition.signalAll(); return request; } finally { lock.unlock(); } } /** * Get an active operation. * * @param header the request header * @return the active operation, {@code null} if if there is no registered operation */ protected <T, A> ActiveOperation<T, A> getActiveOperation(final ManagementRequestHeader header) { return getActiveOperation(header.getBatchId()); } /** * Get the active operation. * * @param id the active operation id * @return the active operation, {@code null} if if there is no registered operation */ protected <T, A> ActiveOperation<T, A> getActiveOperation(final Integer id) { //noinspection unchecked return (ActiveOperation<T, A>) activeRequests.get(id); } /** * Cancel all currently active operations. * * @return a list of cancelled operations */ protected List<Integer> cancelAllActiveOperations() { final List<Integer> operations = new ArrayList<Integer>(); for(final ActiveOperationImpl<?, ?> activeOperation : activeRequests.values()) { activeOperation.asyncCancel(false); operations.add(activeOperation.getOperationId()); } return operations; } /** * Remove an active operation. * * @param id the operation id * @return the removed active operation, {@code null} if there was no registered operation */ protected <T, A> ActiveOperation<T, A> removeActiveOperation(Integer id) { final ActiveOperation<T, A> removed = removeUnderLock(id); if(removed != null) { for(final Map.Entry<Integer, ActiveRequest<?, ?>> requestEntry : requests.entrySet()) { final ActiveRequest<?, ?> request = requestEntry.getValue(); if(request.context == removed) { requests.remove(requestEntry.getKey()); } } } return removed; } private <T, A> ActiveOperation<T, A> removeUnderLock(final Integer id) { lock.lock(); try { final ActiveOperation<?, ?> removed = activeRequests.remove(id); if(removed != null) { ProtocolLogger.ROOT_LOGGER.tracef("Deregistered active operation %d", id); activeCount--; operationIdManager.freeBatchId(id); condition.signalAll(); } //noinspection unchecked return (ActiveOperation<T, A>) removed; } finally { lock.unlock(); } } /** * Safe write error response. * * @param channel the channel * @param header the request header * @param error the exception */ protected static void safeWriteErrorResponse(final Channel channel, final ManagementProtocolHeader header, final Throwable error) { if(header.getType() == ManagementProtocol.TYPE_REQUEST) { try { writeErrorResponse(channel, (ManagementRequestHeader) header, error); } catch(IOException ioe) { ProtocolLogger.ROOT_LOGGER.tracef(ioe, "failed to write error response for %s on channel: %s", header, channel); } } } /** * Write an error response. * * @param channel the channel * @param header the request * @param error the error * @throws IOException */ protected static void writeErrorResponse(final Channel channel, final ManagementRequestHeader header, final Throwable error) throws IOException { final ManagementResponseHeader response = ManagementResponseHeader.create(header, error); final MessageOutputStream output = channel.writeMessage(); try { writeHeader(response, output); output.close(); } finally { StreamUtils.safeClose(output); } } /** * Write the management protocol header. * * @param header the mgmt protocol header * @param os the output stream * @throws IOException */ protected static FlushableDataOutput writeHeader(final ManagementProtocolHeader header, final OutputStream os) throws IOException { final FlushableDataOutput output = FlushableDataOutputImpl.create(os); header.write(output); return output; } /** * Get a fallback handler. * * @param header the protocol header * @return the fallback handler */ protected static <T, A> ManagementRequestHandler<T, A> getFallbackHandler(final ManagementRequestHeader header) { return new ManagementRequestHandler<T, A>() { @Override public void handleRequest(final DataInput input, ActiveOperation.ResultHandler<T> resultHandler, ManagementRequestContext<A> context) throws IOException { final Exception error = ProtocolLogger.ROOT_LOGGER.noSuchResponseHandler(Integer.toHexString(header.getRequestId())); if(resultHandler.failed(error)) { safeWriteErrorResponse(context.getChannel(), context.getRequestHeader(), error); } } }; } private static void updateChannelRef(final ActiveOperation<?, ?> operation, Channel channel) { if (operation instanceof ActiveOperationImpl) { final ActiveOperationImpl<?, ?> a = (ActiveOperationImpl) operation; a.updateChannelRef(channel); } } private static class ActiveRequest<T, A> { private final ActiveOperation<T, A> context; private final ManagementResponseHandler<T, A> handler; ActiveRequest(ActiveOperation<T, A> context, ManagementResponseHandler<T, A> handler) { this.context = context; this.handler = handler; } protected void handleFailed(final ManagementResponseHeader header) { handler.handleFailed(header, context.getResultHandler()); } } }