/*
* Copyright 2016 ThoughtWorks, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.thoughtworks.go.agent;
import com.thoughtworks.go.util.SystemEnvironment;
import com.thoughtworks.go.websocket.Message;
import com.thoughtworks.go.websocket.MessageCallback;
import com.thoughtworks.go.websocket.MessageEncoding;
import org.eclipse.jetty.websocket.api.Session;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static com.thoughtworks.go.util.ExceptionUtils.bomb;
/**
 * Wraps the agent's current websocket {@link Session} and provides
 * send-with-acknowledgement semantics on top of it.
 *
 * <p>Outgoing messages are registered in {@link #callbacks} under their
 * acknowledgement id; when the matching acknowledgement arrives,
 * {@link #acknowledge(Message)} removes and invokes the callback.
 *
 * <p>NOTE(review): {@link #stop()} and the running checks are synchronized, but
 * {@link #setSession(Session)}, {@link #getSessionName()} and the send path read
 * or write {@code session} without holding the lock — confirm that callers
 * serialize those accesses.
 */
public class WebSocketSessionHandler {
    private static final Logger LOG = LoggerFactory.getLogger(WebSocketSessionHandler.class);

    /** Name reported by {@link #getSessionName()} while no session is attached. */
    private static final String NO_SESSION_NAME = "[No Session]";

    /** Pause between two consecutive send attempts, in milliseconds. */
    private static final long SEND_RETRY_DELAY_MILLIS = 2000L;

    // This is a session aware socket
    private Session session;
    private String sessionName = NO_SESSION_NAME;

    /** Callbacks awaiting an acknowledgement, keyed by acknowledgement id. */
    private final Map<String, MessageCallback> callbacks = new ConcurrentHashMap<>();

    private final SystemEnvironment systemEnvironment;

    /**
     * @param systemEnvironment source of the send retry count and the
     *                          acknowledgement timeout
     */
    @Autowired
    public WebSocketSessionHandler(SystemEnvironment systemEnvironment) {
        this.systemEnvironment = systemEnvironment;
    }

    /**
     * Closes the current session if it is open and resets the handler to the
     * "no session" state. Does nothing when no open session is attached.
     */
    synchronized void stop() {
        if (isRunning()) {
            LOG.debug("close {}", sessionName());
            session.close();
            session = null;
            sessionName = NO_SESSION_NAME;
        }
    }

    /** @return {@code true} when a session is attached and reports itself open */
    private synchronized boolean isRunning() {
        return session != null && session.isOpen();
    }

    /** @return {@code true} when there is no session or the session is not open */
    synchronized boolean isNotRunning() {
        return !isRunning();
    }

    /**
     * Sends the encoded message over the session, retrying up to
     * {@code getWebsocketSendRetryCount()} times with a fixed pause between
     * attempts.
     *
     * <p>When the final attempt fails, the failure is passed to {@code bomb}.
     * When the thread is interrupted while pausing between attempts, the
     * interrupt status is restored and the send is abandoned via {@code bomb}
     * instead of silently continuing to retry.
     *
     * @param message the message to encode and send
     */
    private void send(Message message) {
        for (int retries = 1; retries <= systemEnvironment.getWebsocketSendRetryCount(); retries++) {
            try {
                LOG.debug("{} attempt {} to send message: {}", sessionName(), retries, message);
                session.getRemote().sendBytesByFuture(ByteBuffer.wrap(MessageEncoding.encodeMessage(message)));
                return;
            } catch (Throwable e) {
                LOG.debug("{} attempt {} failed to send message: {}.", sessionName(), retries, message);
                if (retries == systemEnvironment.getWebsocketSendRetryCount()) {
                    bomb(e);
                }
                try {
                    Thread.sleep(SEND_RETRY_DELAY_MILLIS);
                } catch (InterruptedException interrupted) {
                    // Preserve the interrupt for code further up the stack and stop
                    // retrying; keep the send failure as the primary cause.
                    Thread.currentThread().interrupt();
                    e.addSuppressed(interrupted);
                    bomb(e);
                }
            }
        }
    }

    /**
     * Sends the message and blocks until its acknowledgement arrives or
     * {@code getWebsocketAckMessageTimeout()} milliseconds have elapsed.
     *
     * <p>A timeout is only logged; the callback stays registered until it is
     * acknowledged or {@link #clearCallBacks()} is called. If the waiting thread
     * is interrupted, its interrupt status is restored before the exception is
     * passed to {@code bomb}.
     *
     * @param message the message to send; its acknowledgement id keys the callback
     */
    void sendAndWaitForAcknowledgement(Message message) {
        final CountDownLatch acknowledged = new CountDownLatch(1);
        sendWithCallback(message, new MessageCallback() {
            @Override
            public void call() {
                acknowledged.countDown();
            }
        });
        try {
            boolean receivedInTime =
                    acknowledged.await(systemEnvironment.getWebsocketAckMessageTimeout(), TimeUnit.MILLISECONDS);
            if (!receivedInTime) {
                LOG.debug("{} timed out waiting for acknowledgement of message: {}", sessionName(), message);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            bomb(e);
        }
    }

    /**
     * Registers the callback under the message's acknowledgement id and then
     * sends the message. Registration happens first so that an acknowledgement
     * arriving immediately after the send can find its callback.
     */
    private void sendWithCallback(Message message, MessageCallback callback) {
        callbacks.put(message.getAcknowledgementId(), callback);
        send(message);
    }

    /** @return a label for log lines, derived from the live session's remote address */
    private String sessionName() {
        return session == null ? "[No session initialized]" : "Session[" + session.getRemoteAddress() + "]";
    }

    /**
     * Attaches a session and derives the reported session name from its remote
     * address.
     *
     * @param session the session to use for subsequent sends; must not be null
     */
    void setSession(Session session) {
        this.session = session;
        this.sessionName = "[" + session.getRemoteAddress() + "]";
    }

    /** @return the name of the attached session, or {@code "[No Session]"} when none is attached */
    String getSessionName() {
        return sessionName;
    }

    /**
     * Handles an acknowledgement message: decodes the acknowledgement id from
     * the message data, then removes and invokes the matching callback.
     *
     * <p>An acknowledgement with no registered callback (for example after
     * {@link #clearCallBacks()}, or a duplicate acknowledgement) is logged and
     * ignored rather than failing with a {@link NullPointerException}.
     *
     * @param message the acknowledgement message whose data holds the id
     */
    void acknowledge(Message message) {
        String acknowledgementId = MessageEncoding.decodeData(message.getData(), String.class);
        LOG.debug("Acknowledging {}", acknowledgementId);
        // ConcurrentHashMap rejects null keys, so guard before the lookup.
        MessageCallback callback = acknowledgementId == null ? null : callbacks.remove(acknowledgementId);
        if (callback == null) {
            LOG.warn("No callback registered for acknowledgement {}; ignoring", acknowledgementId);
            return;
        }
        callback.call();
    }

    /** Discards every callback that is still waiting for an acknowledgement. */
    void clearCallBacks() {
        LOG.debug("Clearing {} ignored messages", callbacks.size());
        callbacks.clear();
    }
}