/*
* Copyright 2015 NAVER Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.navercorp.pinpoint.web.websocket;
import com.navercorp.pinpoint.common.util.PinpointThreadFactory;
import com.navercorp.pinpoint.rpc.util.ClassUtils;
import com.navercorp.pinpoint.rpc.util.MapUtils;
import com.navercorp.pinpoint.web.security.ServerMapDataFilter;
import com.navercorp.pinpoint.web.service.AgentService;
import com.navercorp.pinpoint.web.util.SimpleOrderedThreadPool;
import com.navercorp.pinpoint.web.websocket.message.PinpointWebSocketMessage;
import com.navercorp.pinpoint.web.websocket.message.PinpointWebSocketMessageConverter;
import com.navercorp.pinpoint.web.websocket.message.PinpointWebSocketMessageType;
import com.navercorp.pinpoint.web.websocket.message.PongMessage;
import com.navercorp.pinpoint.web.websocket.message.RequestMessage;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicBoolean;
/**
 * WebSocket handler that pushes active-thread-count data for an application to subscribed sessions.
 * <p>
 * A session subscribes by sending a {@code REQUEST} message whose command is {@code activeThreadCount}
 * and whose {@code applicationName} parameter names the application. All sessions of one application share
 * a single {@link PinpointWebSocketResponseAggregator}; the aggregator is stopped and removed once its last
 * session leaves.
 * <p>
 * While at least one session is connected two self-rescheduling timer chains run:
 * <ul>
 *     <li>a flush chain that calls {@code flush} on every aggregator every {@code flushDelay} ms, and</li>
 *     <li>a health-check chain that, every {@code healthCheckDelay} ms, closes sessions that have not answered
 *     the previous ping with a {@code PONG} message and then pings every open session again.</li>
 * </ul>
 *
 * @author Taejin Koo
 */
public class ActiveThreadCountHandler extends TextWebSocketHandler implements PinpointWebSocketHandler {

    public static final String APPLICATION_NAME_KEY = "applicationName";
    private static final String HEALTH_CHECK_WAIT_KEY = "pinpoint.healthCheck.wait";
    static final String API_ACTIVE_THREAD_COUNT = "activeThreadCount";

    private static final String DEFAULT_REQUEST_MAPPING = "/agent/activeThread";
    private static final long DEFAULT_FLUSH_DELAY = 1000;
    private static final long DEFAULT_HEALTH_CHECK_DELAY = 60 * 1000;

    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    // Guards session registration, aggregator (un)binding and switching the timer chains on.
    private final Object lock = new Object();

    private final AgentService agentService;
    private final List<WebSocketSession> sessionRepository = new CopyOnWriteArrayList<>();
    private final Map<String, PinpointWebSocketResponseAggregator> aggregatorRepository = new ConcurrentHashMap<>();
    private final PinpointWebSocketMessageConverter messageConverter = new PinpointWebSocketMessageConverter();
    private final String requestMapping;

    // true while the flush / health-check timer chains should keep rescheduling themselves.
    private final AtomicBoolean onTimerTask = new AtomicBoolean(false);
    // Incremented (only while holding `lock`) each time the timer chains are switched on. Every task remembers
    // the generation it was started with and stops once a newer generation exists, so a quick
    // "last session closed -> new session connected" sequence cannot leave two chains running side by side.
    private volatile int timerGeneration = 0;

    private SimpleOrderedThreadPool webSocketFlushExecutor;

    private java.util.Timer flushTimer;
    private final long flushDelay;

    private java.util.Timer healthCheckTimer;
    private final long healthCheckDelay;

    private java.util.Timer reactiveTimer;

    @Autowired(required=false)
    ServerMapDataFilter serverMapDataFilter;

    public ActiveThreadCountHandler(AgentService agentService) {
        this(DEFAULT_REQUEST_MAPPING, agentService);
    }

    public ActiveThreadCountHandler(String requestMapping, AgentService agentService) {
        this(requestMapping, agentService, DEFAULT_FLUSH_DELAY);
    }

    public ActiveThreadCountHandler(String requestMapping, AgentService agentService, long flushDelay) {
        this(requestMapping, agentService, flushDelay, DEFAULT_HEALTH_CHECK_DELAY);
    }

    /**
     * @param requestMapping   WebSocket path this handler is registered under
     * @param agentService     passed to each per-application response aggregator
     * @param flushDelay       period in ms between aggregator flushes; must be positive
     * @param healthCheckDelay period in ms between health checks (ping rounds); must be positive
     * @throws IllegalArgumentException if a delay is not positive
     */
    public ActiveThreadCountHandler(String requestMapping, AgentService agentService, long flushDelay, long healthCheckDelay) {
        // a non-positive flushDelay would make ActiveThreadTimerTask#getWaitTimeMillis loop forever,
        // and java.util.Timer rejects negative delays.
        if (flushDelay <= 0) {
            throw new IllegalArgumentException("flushDelay must be positive. flushDelay:" + flushDelay);
        }
        if (healthCheckDelay <= 0) {
            throw new IllegalArgumentException("healthCheckDelay must be positive. healthCheckDelay:" + healthCheckDelay);
        }
        this.requestMapping = requestMapping;
        this.agentService = agentService;
        this.flushDelay = flushDelay;
        this.healthCheckDelay = healthCheckDelay;
    }

    /**
     * Creates the flush executor and the daemon timers. Must be called before any session connects.
     */
    @Override
    public void start() {
        PinpointThreadFactory flushThreadFactory = new PinpointThreadFactory(ClassUtils.simpleClassName(this) + "-Flush-Thread", true);
        webSocketFlushExecutor = new SimpleOrderedThreadPool(Runtime.getRuntime().availableProcessors(), 65535, flushThreadFactory);

        flushTimer = new java.util.Timer(ClassUtils.simpleClassName(this) + "-Flush-Timer", true);
        healthCheckTimer = new java.util.Timer(ClassUtils.simpleClassName(this) + "-HealthCheck-Timer", true);
        reactiveTimer = new java.util.Timer(ClassUtils.simpleClassName(this) + "-Reactive-Timer", true);
    }

    /**
     * Stops every aggregator, cancels the timers and shuts down the flush executor.
     */
    @Override
    public void stop() {
        // Switch the timer chains off first so tasks that are currently running do not try to
        // reschedule themselves on the timers cancelled below.
        onTimerTask.set(false);

        for (PinpointWebSocketResponseAggregator aggregator : aggregatorRepository.values()) {
            if (aggregator != null) {
                aggregator.stop();
            }
        }
        aggregatorRepository.clear();

        if (flushTimer != null) {
            flushTimer.cancel();
        }
        if (healthCheckTimer != null) {
            healthCheckTimer.cancel();
        }
        if (reactiveTimer != null) {
            reactiveTimer.cancel();
        }
        if (webSocketFlushExecutor != null) {
            webSocketFlushExecutor.shutdown();
        }
    }

    @Override
    public String getRequestMapping() {
        return requestMapping;
    }

    /**
     * Registers the session and, if it is the first one, starts a new generation of timer chains.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession newSession) throws Exception {
        logger.info("ConnectionEstablished. session:{}", newSession);

        synchronized (lock) {
            newSession.getAttributes().put(HEALTH_CHECK_WAIT_KEY, new AtomicBoolean(false));
            sessionRepository.add(newSession);

            if (onTimerTask.compareAndSet(false, true)) {
                final int generation = ++timerGeneration;
                flushTimer.schedule(new ActiveThreadTimerTask(generation, flushDelay), flushDelay);
                // use the configured delay, the same one HealthCheckTimerTask reschedules itself with
                healthCheckTimer.schedule(new HealthCheckTimerTask(generation), healthCheckDelay);
            }
        }

        super.afterConnectionEstablished(newSession);
    }

    /**
     * Unbinds the session from its aggregator and, if it was the last one, lets the timer chains stop.
     */
    @Override
    public void afterConnectionClosed(WebSocketSession closeSession, CloseStatus status) throws Exception {
        logger.info("ConnectionClose. session:{}, caused:{}", closeSession, status);

        synchronized (lock) {
            unbindingResponseAggregator(closeSession);

            sessionRepository.remove(closeSession);
            if (sessionRepository.isEmpty()) {
                // the running chains notice this on their next run and stop rescheduling
                onTimerTask.set(false);
            }
        }

        super.afterConnectionClosed(closeSession, status);
    }

    @Override
    protected void handleTextMessage(WebSocketSession webSocketSession, TextMessage message) throws Exception {
        logger.info("handleTextMessage. session:{}, remote:{}, message:{}.", webSocketSession, webSocketSession.getRemoteAddress(), message.getPayload());

        PinpointWebSocketMessage webSocketMessage = messageConverter.getWebSocketMessage(message.getPayload());
        PinpointWebSocketMessageType webSocketMessageType = webSocketMessage.getType();
        switch (webSocketMessageType) {
            case REQUEST:
                handleRequestMessage0(webSocketSession, (RequestMessage) webSocketMessage);
                break;
            case PONG:
                handlePongMessage0(webSocketSession, (PongMessage) webSocketMessage);
                break;
            default:
                logger.warn("Unexpected WebSocketMessageType received. messageType:{}.", webSocketMessageType);
        }

        // this method will be checked socket status.
        super.handleTextMessage(webSocketSession, message);
    }

    /**
     * Handles a subscription request: (re)binds the session to the aggregator of the requested application.
     * Sessions rejected by the optional {@link ServerMapDataFilter} are closed with the filter's status.
     */
    private void handleRequestMessage0(WebSocketSession webSocketSession, RequestMessage requestMessage) {
        if (serverMapDataFilter != null && serverMapDataFilter.filter(webSocketSession, requestMessage)) {
            closeSession(webSocketSession, serverMapDataFilter.getCloseStatus(requestMessage));
            return;
        }

        String command = requestMessage.getCommand();
        if (API_ACTIVE_THREAD_COUNT.equals(command)) {
            String applicationName = MapUtils.getString(requestMessage.getParams(), APPLICATION_NAME_KEY);
            if (applicationName != null) {
                synchronized (lock) {
                    // already subscribed to this application: nothing to do
                    if (StringUtils.equals(applicationName, (String) webSocketSession.getAttributes().get(APPLICATION_NAME_KEY))) {
                        return;
                    }

                    unbindingResponseAggregator(webSocketSession);

                    if (webSocketSession.isOpen()) {
                        bindingResponseAggregator(webSocketSession, applicationName);
                    } else {
                        logger.warn("WebSocketSession is not opened. skip binding.");
                    }
                }
            }
        }
    }

    private void closeSession(WebSocketSession session, CloseStatus status) {
        try {
            session.close(status);
        } catch (Exception e) {
            logger.warn(e.getMessage(), e);
        }
    }

    /**
     * An application-level PONG answers the last ping: clear the session's "waiting for pong" flag.
     */
    private void handlePongMessage0(WebSocketSession webSocketSession, PongMessage pongMessage) {
        Object healthCheckWait = webSocketSession.getAttributes().get(HEALTH_CHECK_WAIT_KEY);
        if (healthCheckWait instanceof AtomicBoolean) {
            ((AtomicBoolean) healthCheckWait).compareAndSet(true, false);
        }
    }

    @Override
    protected void handlePongMessage(WebSocketSession webSocketSession, org.springframework.web.socket.PongMessage message) throws Exception {
        logger.info("handlePongMessage. session:{}, remote:{}, message:{}.", webSocketSession, webSocketSession.getRemoteAddress(), message.getPayload());

        super.handlePongMessage(webSocketSession, message);
    }

    /**
     * Records the application name on the session and adds the session to that application's aggregator,
     * creating and starting the aggregator if none exists yet. Callers hold {@code lock}.
     */
    private void bindingResponseAggregator(WebSocketSession webSocketSession, String applicationName) {
        logger.info("bindingResponseAggregator. session:{}, applicationName:{}.", webSocketSession, applicationName);

        webSocketSession.getAttributes().put(APPLICATION_NAME_KEY, applicationName);
        if (StringUtils.isEmpty(applicationName)) {
            return;
        }

        PinpointWebSocketResponseAggregator responseAggregator = aggregatorRepository.get(applicationName);
        if (responseAggregator == null) {
            responseAggregator = new ActiveThreadCountResponseAggregator(applicationName, agentService, reactiveTimer);
            responseAggregator.start();
            aggregatorRepository.put(applicationName, responseAggregator);
        }

        responseAggregator.addWebSocketSession(webSocketSession);
    }

    /**
     * Removes the session from its application's aggregator; stops and discards the aggregator once it
     * reports having no sessions left. Callers hold {@code lock}.
     */
    private void unbindingResponseAggregator(WebSocketSession webSocketSession) {
        String applicationName = (String) webSocketSession.getAttributes().get(APPLICATION_NAME_KEY);
        logger.info("unbindingResponseAggregator. session:{}, applicationName:{}.", webSocketSession, applicationName);
        if (StringUtils.isEmpty(applicationName)) {
            return;
        }

        PinpointWebSocketResponseAggregator responseAggregator = aggregatorRepository.get(applicationName);
        if (responseAggregator == null) {
            return;
        }

        boolean cleared = responseAggregator.removeWebSocketSessionAndGetIsCleared(webSocketSession);
        if (cleared) {
            aggregatorRepository.remove(applicationName);
            responseAggregator.stop();
        }
    }

    /**
     * @return true if timer tasks of the given generation should keep running
     */
    private boolean isActiveTimerGeneration(int generation) {
        return onTimerTask.get() && generation == timerGeneration;
    }

    /**
     * Schedules {@code task} on {@code timer}, tolerating a timer that {@link #stop()} has already cancelled.
     */
    private void rescheduleQuietly(java.util.Timer timer, java.util.TimerTask task, long delayMillis) {
        if (timer == null) {
            return;
        }
        try {
            timer.schedule(task, delayMillis);
        } catch (IllegalStateException e) {
            // java.util.Timer throws this once cancelled; happens when stop() races with a running task.
            logger.debug("skip rescheduling {}. timer already cancelled.", task.getClass().getSimpleName());
        }
    }

    /**
     * Flushes every aggregator, then reschedules itself on a fixed grid of {@code startTimeMillis + k * delay}.
     */
    private class ActiveThreadTimerTask extends java.util.TimerTask {

        private final int generation;
        private final long startTimeMillis;
        private final long delay;
        private int times;

        ActiveThreadTimerTask(int generation, long delay) {
            this(generation, System.currentTimeMillis(), delay, 0);
        }

        ActiveThreadTimerTask(int generation, long startTimeMillis, long delay, int times) {
            this.generation = generation;
            this.startTimeMillis = startTimeMillis;
            this.delay = delay;
            this.times = times;
        }

        @Override
        public void run() {
            if (!isActiveTimerGeneration(generation)) {
                // switched off, or superseded by a newer chain started in afterConnectionEstablished
                return;
            }
            try {
                logger.debug("ActiveThreadTimerTask started.");

                Collection<PinpointWebSocketResponseAggregator> values = aggregatorRepository.values();
                for (final PinpointWebSocketResponseAggregator aggregator : values) {
                    try {
                        aggregator.flush(webSocketFlushExecutor);
                    } catch (Exception e) {
                        logger.warn("failed while flushing ActiveThreadCount to aggregator. applicationName:{}, error:{}", aggregator.getApplicationName(), e.getMessage(), e);
                    }
                }
            } finally {
                long waitTimeMillis = getWaitTimeMillis();
                if (isActiveTimerGeneration(generation)) {
                    rescheduleQuietly(flushTimer, new ActiveThreadTimerTask(generation, startTimeMillis, delay, times), waitTimeMillis);
                }
            }
        }

        /**
         * Returns the time until the next grid slot {@code startTimeMillis + times * delay} that lies in the
         * future, advancing {@code times} past any slots that were missed so the cadence does not drift.
         */
        private long getWaitTimeMillis() {
            long waitTime = -1L;
            long currentTime = System.currentTimeMillis();

            while (waitTime <= 0) {
                waitTime = startTimeMillis + (delay * times) - currentTime;
                times++;
            }

            return waitTime;
        }
    }

    /**
     * Closes sessions that did not answer the previous ping, pings every open session and
     * reschedules itself after {@code healthCheckDelay} ms.
     */
    private class HealthCheckTimerTask extends java.util.TimerTask {

        private final int generation;

        HealthCheckTimerTask(int generation) {
            this.generation = generation;
        }

        @Override
        public void run() {
            if (!isActiveTimerGeneration(generation)) {
                // switched off, or superseded by a newer chain started in afterConnectionEstablished
                return;
            }
            try {
                logger.info("HealthCheckTimerTask started.");

                // check session state: a flag still set means the previous ping got no PONG.
                List<WebSocketSession> webSocketSessionList = new ArrayList<>(sessionRepository);
                for (WebSocketSession session : webSocketSessionList) {
                    if (!session.isOpen()) {
                        continue;
                    }

                    Object untilWait = session.getAttributes().get(HEALTH_CHECK_WAIT_KEY);
                    if (untilWait instanceof AtomicBoolean) {
                        if (((AtomicBoolean) untilWait).get()) {
                            closeSession(session, CloseStatus.SESSION_NOT_RELIABLE);
                        }
                    } else {
                        session.getAttributes().put(HEALTH_CHECK_WAIT_KEY, new AtomicBoolean(false));
                    }
                }

                // send healthCheck packet; re-read the repository since sessions may have been closed above
                String pingTextMessage = messageConverter.getPingTextMessage();
                TextMessage pingMessage = new TextMessage(pingTextMessage);

                webSocketSessionList = new ArrayList<>(sessionRepository);
                for (WebSocketSession session : webSocketSessionList) {
                    if (!session.isOpen()) {
                        continue;
                    }

                    Object untilWait = session.getAttributes().get(HEALTH_CHECK_WAIT_KEY);
                    if (untilWait instanceof AtomicBoolean) {
                        ((AtomicBoolean) untilWait).compareAndSet(false, true);
                    } else {
                        session.getAttributes().put(HEALTH_CHECK_WAIT_KEY, new AtomicBoolean(true));
                    }

                    sendPingMessage(session, pingMessage);
                }
            } finally {
                if (isActiveTimerGeneration(generation)) {
                    rescheduleQuietly(healthCheckTimer, new HealthCheckTimerTask(generation), healthCheckDelay);
                }
            }
        }

        private void sendPingMessage(WebSocketSession session, TextMessage pingMessage) {
            try {
                webSocketFlushExecutor.execute(new OrderedWebSocketFlushRunnable(session, pingMessage, true));
            } catch (RuntimeException e) {
                logger.warn("failed while to execute. error:{}.", e.getMessage(), e);
            }
        }
    }
}