/* * Copyright 2014 NAVER Corp. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.navercorp.pinpoint.rpc.client; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import org.jboss.netty.channel.Channel; import org.jboss.netty.util.Timeout; import org.jboss.netty.util.Timer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.navercorp.pinpoint.rpc.ChannelWriteFailListenableFuture; import com.navercorp.pinpoint.rpc.DefaultFuture; import com.navercorp.pinpoint.rpc.FailureEventHandler; import com.navercorp.pinpoint.rpc.PinpointSocketException; import com.navercorp.pinpoint.rpc.ResponseMessage; import com.navercorp.pinpoint.rpc.packet.RequestPacket; import com.navercorp.pinpoint.rpc.packet.ResponsePacket; import com.navercorp.pinpoint.rpc.server.PinpointServer; /** * @author emeroad */ public class RequestManager { private final Logger logger = LoggerFactory.getLogger(this.getClass()); private final AtomicInteger requestId = new AtomicInteger(1); private final ConcurrentMap<Integer, DefaultFuture<ResponseMessage>> requestMap = new ConcurrentHashMap<Integer, DefaultFuture<ResponseMessage>>(); // Have to move Timer into factory? private final Timer timer; private final long defaultTimeoutMillis; public RequestManager(Timer timer, long defaultTimeoutMillis) { if (timer == null) { throw new NullPointerException("timer must not be null"); } if (defaultTimeoutMillis <= 0) { throw new IllegalArgumentException("defaultTimeoutMillis must greater than zero."); } this.timer = timer; this.defaultTimeoutMillis = defaultTimeoutMillis; } private FailureEventHandler createFailureEventHandler(final int requestId) { FailureEventHandler failureEventHandler = new FailureEventHandler() { @Override public boolean fireFailure() { DefaultFuture<ResponseMessage> future = removeMessageFuture(requestId); if (future != null) { // removed perfectly. return true; } return false; } }; return failureEventHandler; } private void addTimeoutTask(long timeoutMillis, DefaultFuture future) { if (future == null) { throw new NullPointerException("future"); } try { Timeout timeout = timer.newTimeout(future, timeoutMillis, TimeUnit.MILLISECONDS); future.setTimeout(timeout); } catch (IllegalStateException e) { // this case is that timer has been shutdown. That maybe just means that socket has been closed. future.setFailure(new PinpointSocketException("socket closed")) ; } } private int getNextRequestId() { return this.requestId.getAndIncrement(); } public void messageReceived(ResponsePacket responsePacket, String objectUniqName) { final int requestId = responsePacket.getRequestId(); final DefaultFuture<ResponseMessage> future = removeMessageFuture(requestId); if (future == null) { logger.warn("future not found:{}, objectUniqName:{}", responsePacket, objectUniqName); return; } else { logger.debug("responsePacket arrived packet:{}, objectUniqName:{}", responsePacket, objectUniqName); } ResponseMessage response = new ResponseMessage(); response.setMessage(responsePacket.getPayload()); future.setResult(response); } public void messageReceived(ResponsePacket responsePacket, PinpointServer pinpointServer) { final int requestId = responsePacket.getRequestId(); final DefaultFuture<ResponseMessage> future = removeMessageFuture(requestId); if (future == null) { logger.warn("future not found:{}, pinpointServer:{}", responsePacket, pinpointServer); return; } else { logger.debug("responsePacket arrived packet:{}, pinpointServer:{}", responsePacket, pinpointServer); } ResponseMessage response = new ResponseMessage(); response.setMessage(responsePacket.getPayload()); future.setResult(response); } public DefaultFuture<ResponseMessage> removeMessageFuture(int requestId) { return this.requestMap.remove(requestId); } public void messageReceived(RequestPacket requestPacket, Channel channel) { logger.error("unexpectedMessage received:{} address:{}", requestPacket, channel.getRemoteAddress()); } public ChannelWriteFailListenableFuture<ResponseMessage> register(RequestPacket requestPacket) { return register(requestPacket, defaultTimeoutMillis); } public ChannelWriteFailListenableFuture<ResponseMessage> register(RequestPacket requestPacket, long timeoutMillis) { // shutdown check final int requestId = getNextRequestId(); requestPacket.setRequestId(requestId); final ChannelWriteFailListenableFuture<ResponseMessage> future = new ChannelWriteFailListenableFuture<ResponseMessage>(timeoutMillis); final DefaultFuture old = this.requestMap.put(requestId, future); if (old != null) { throw new PinpointSocketException("unexpected error. old future exist:" + old + " id:" + requestId); } // when future fails, put a handle in order to remove a failed future in the requestMap. FailureEventHandler removeTable = createFailureEventHandler(requestId); future.setFailureEventHandler(removeTable); addTimeoutTask(timeoutMillis, future); return future; } public void close() { logger.debug("close()"); final PinpointSocketException closed = new PinpointSocketException("socket closed"); // Could you handle race conditions of "close" more precisely? // final Timer timer = this.timer; // if (timer != null) { // Set<Timeout> stop = timer.stop(); // for (Timeout timeout : stop) { // DefaultFuture future = (DefaultFuture)timeout.getTask(); // future.setFailure(closed); // } // } int requestFailCount = 0; for (Map.Entry<Integer, DefaultFuture<ResponseMessage>> entry : requestMap.entrySet()) { if(entry.getValue().setFailure(closed)) { requestFailCount++; } } this.requestMap.clear(); if (requestFailCount > 0) { logger.info("requestManager failCount:{}", requestFailCount); } } }