/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.zeppelin.notebook.repo.zeppelinhub.websocket;
import java.io.IOException;
import java.net.URI;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.Timer;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import org.apache.commons.lang3.StringUtils;
import org.apache.zeppelin.conf.ZeppelinConfiguration;
import org.apache.zeppelin.notebook.NotebookAuthorization;
import org.apache.zeppelin.notebook.repo.zeppelinhub.model.UserTokenContainer;
import org.apache.zeppelin.notebook.repo.zeppelinhub.security.Authentication;
import org.apache.zeppelin.notebook.repo.zeppelinhub.websocket.listener.WatcherWebsocket;
import org.apache.zeppelin.notebook.repo.zeppelinhub.websocket.listener.ZeppelinWebsocket;
import org.apache.zeppelin.notebook.repo.zeppelinhub.websocket.protocol.ZeppelinhubMessage;
import org.apache.zeppelin.notebook.repo.zeppelinhub.websocket.scheduler.SchedulerService;
import org.apache.zeppelin.notebook.repo.zeppelinhub.websocket.scheduler.ZeppelinHeartbeat;
import org.apache.zeppelin.notebook.socket.Message;
import org.apache.zeppelin.notebook.socket.Message.OP;
import org.apache.zeppelin.util.WatcherSecurityKey;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.gson.Gson;
import com.google.gson.JsonSyntaxException;
/**
* Zeppelin websocket client.
*
*/
public class ZeppelinClient {
private static final Logger LOG = LoggerFactory.getLogger(ZeppelinClient.class);
private final URI zeppelinWebsocketUrl;
private final WebSocketClient wsClient;
private static Gson gson;
// Keep track of current open connection per notebook.
private ConcurrentHashMap<String, Session> notesConnection;
// Listen to every note actions.
private static Session watcherSession;
private static ZeppelinClient instance = null;
private SchedulerService schedulerService;
private Authentication authModule;
private static final int MIN = 60;
private static final String ORIGIN = "Origin";
private static final Set<String> actionable = new HashSet<String>(Arrays.asList(
// running events
"ANGULAR_OBJECT_UPDATE",
"PROGRESS",
"NOTE",
"PARAGRAPH",
"PARAGRAPH_UPDATE_OUTPUT",
"PARAGRAPH_APPEND_OUTPUT",
"PARAGRAPH_CLEAR_OUTPUT",
"PARAGRAPH_REMOVE",
// run or stop events
"RUN_PARAGRAPH",
"CANCEL_PARAGRAPH"));
public static ZeppelinClient initialize(String zeppelinUrl, String token,
ZeppelinConfiguration conf) {
if (instance == null) {
instance = new ZeppelinClient(zeppelinUrl, token, conf);
}
return instance;
}
public static ZeppelinClient getInstance() {
return instance;
}
private ZeppelinClient(String zeppelinUrl, String token, ZeppelinConfiguration conf) {
zeppelinWebsocketUrl = URI.create(zeppelinUrl);
wsClient = createNewWebsocketClient();
gson = new Gson();
notesConnection = new ConcurrentHashMap<>();
schedulerService = SchedulerService.getInstance();
authModule = Authentication.initialize(token, conf);
if (authModule != null) {
SchedulerService.getInstance().addOnce(authModule, 10);
}
LOG.info("Initialized Zeppelin websocket client on {}", zeppelinWebsocketUrl);
}
private WebSocketClient createNewWebsocketClient() {
SslContextFactory sslContextFactory = new SslContextFactory();
WebSocketClient client = new WebSocketClient(sslContextFactory);
client.setMaxIdleTimeout(5 * MIN * 1000);
client.setMaxTextMessageBufferSize(Client.getMaxNoteSize());
client.getPolicy().setMaxTextMessageSize(Client.getMaxNoteSize());
//TODO(khalid): other client settings
return client;
}
public void start() {
try {
if (wsClient != null) {
wsClient.start();
addRoutines();
} else {
LOG.warn("Cannot start zeppelin websocket client - isn't initialized");
}
} catch (Exception e) {
LOG.error("Cannot start Zeppelin websocket client", e);
}
}
private void addRoutines() {
schedulerService.add(ZeppelinHeartbeat.newInstance(this), 10, 1 * MIN);
new Timer().schedule(new java.util.TimerTask() {
@Override
public void run() {
int time = 0;
while (time < 5 * MIN) {
watcherSession = openWatcherSession();
if (watcherSession == null) {
try {
Thread.sleep(5000);
time += 5;
} catch (InterruptedException e) {
//continue
}
} else {
break;
}
}
}
}, 5000);
}
public void stop() {
try {
if (wsClient != null) {
removeAllConnections();
wsClient.stop();
} else {
LOG.warn("Cannot stop zeppelin websocket client - isn't initialized");
}
if (watcherSession != null) {
watcherSession.close();
}
} catch (Exception e) {
LOG.error("Cannot stop Zeppelin websocket client", e);
}
}
public String serialize(Message zeppelinMsg) {
if (credentialsAvailable()) {
zeppelinMsg.principal = authModule.getPrincipal();
zeppelinMsg.ticket = authModule.getTicket();
zeppelinMsg.roles = authModule.getRoles();
}
String msg = gson.toJson(zeppelinMsg);
return msg;
}
private boolean credentialsAvailable() {
return Authentication.getInstance() != null && Authentication.getInstance().isAuthenticated();
}
public Message deserialize(String zeppelinMessage) {
if (StringUtils.isBlank(zeppelinMessage)) {
return null;
}
Message msg;
try {
msg = gson.fromJson(zeppelinMessage, Message.class);
} catch (JsonSyntaxException ex) {
LOG.error("Cannot deserialize zeppelin message", ex);
msg = null;
}
return msg;
}
private Session openWatcherSession() {
ClientUpgradeRequest request = new ClientUpgradeRequest();
request.setHeader(WatcherSecurityKey.HTTP_HEADER, WatcherSecurityKey.getKey());
request.setHeader(ORIGIN, "*");
WatcherWebsocket socket = WatcherWebsocket.createInstace();
Future<Session> future = null;
Session session = null;
try {
future = wsClient.connect(socket, zeppelinWebsocketUrl, request);
session = future.get();
} catch (IOException | InterruptedException | ExecutionException e) {
LOG.error("Couldn't establish websocket connection to Zeppelin ", e);
return session;
}
return session;
}
public void send(Message msg, String noteId) {
Session noteSession = getZeppelinConnection(noteId, msg.principal, msg.ticket);
if (!isSessionOpen(noteSession)) {
LOG.error("Cannot open websocket connection to Zeppelin note {}", noteId);
return;
}
noteSession.getRemote().sendStringByFuture(serialize(msg));
}
public Session getZeppelinConnection(String noteId, String principal, String ticket) {
if (StringUtils.isBlank(noteId)) {
LOG.warn("Cannot get Websocket session with blanck noteId");
return null;
}
return getNoteSession(noteId, principal, ticket);
}
/*
private Message zeppelinGetNoteMsg(String noteId) {
Message getNoteMsg = new Message(Message.OP.GET_NOTE);
HashMap<String, Object> data = new HashMap<>();
data.put("id", noteId);
getNoteMsg.data = data;
return getNoteMsg;
}
*/
private Session getNoteSession(String noteId, String principal, String ticket) {
LOG.info("Getting Note websocket connection for note {}", noteId);
Session session = notesConnection.get(noteId);
if (!isSessionOpen(session)) {
LOG.info("No open connection for note {}, opening one", noteId);
notesConnection.remove(noteId);
session = openNoteSession(noteId, principal, ticket);
}
return session;
}
private Session openNoteSession(String noteId, String principal, String ticket) {
ClientUpgradeRequest request = new ClientUpgradeRequest();
request.setHeader(ORIGIN, "*");
ZeppelinWebsocket socket = new ZeppelinWebsocket(noteId);
Future<Session> future = null;
Session session = null;
try {
future = wsClient.connect(socket, zeppelinWebsocketUrl, request);
session = future.get();
} catch (IOException | InterruptedException | ExecutionException e) {
LOG.error("Couldn't establish websocket connection to Zeppelin ", e);
return session;
}
if (notesConnection.containsKey(noteId)) {
session.close();
session = notesConnection.get(noteId);
} else {
String getNote = serialize(zeppelinGetNoteMsg(noteId, principal, ticket));
session.getRemote().sendStringByFuture(getNote);
notesConnection.put(noteId, session);
}
return session;
}
private boolean isSessionOpen(Session session) {
return (session != null) && (session.isOpen());
}
private Message zeppelinGetNoteMsg(String noteId, String principal, String ticket) {
Message getNoteMsg = new Message(Message.OP.GET_NOTE);
HashMap<String, Object> data = new HashMap<String, Object>();
data.put("id", noteId);
getNoteMsg.data = data;
getNoteMsg.principal = principal;
getNoteMsg.ticket = ticket;
return getNoteMsg;
}
public void handleMsgFromZeppelin(String message, String noteId) {
Map<String, String> meta = new HashMap<>();
//TODO(khalid): don't use zeppelinhubToken in this class, decouple
meta.put("noteId", noteId);
Message zeppelinMsg = deserialize(message);
if (zeppelinMsg == null) {
return;
}
String token;
if (!isActionable(zeppelinMsg.op)) {
return;
}
token = UserTokenContainer.getInstance().getUserToken(zeppelinMsg.principal);
Client client = Client.getInstance();
if (client == null) {
LOG.warn("Client isn't initialized yet");
return;
}
ZeppelinhubMessage hubMsg = ZeppelinhubMessage.newMessage(zeppelinMsg, meta);
if (StringUtils.isEmpty(token)) {
relayToAllZeppelinHub(hubMsg, noteId);
} else {
client.relayToZeppelinHub(hubMsg.serialize(), token);
}
}
private void relayToAllZeppelinHub(ZeppelinhubMessage hubMsg, String noteId) {
if (StringUtils.isBlank(noteId)) {
return;
}
NotebookAuthorization noteAuth = NotebookAuthorization.getInstance();
Map<String, String> userTokens = UserTokenContainer.getInstance().getAllUserTokens();
Client client = Client.getInstance();
Set<String> userAndRoles;
String token;
for (String user: userTokens.keySet()) {
userAndRoles = noteAuth.getRoles(user);
userAndRoles.add(user);
if (noteAuth.isReader(noteId, userAndRoles)) {
token = userTokens.get(user);
hubMsg.meta.put("token", token);
client.relayToZeppelinHub(hubMsg.serialize(), token);
}
}
}
private boolean isActionable(OP action) {
if (action == null) {
return false;
}
return actionable.contains(action.name());
}
public void removeNoteConnection(String noteId) {
if (StringUtils.isBlank(noteId)) {
LOG.error("Cannot remove session for empty noteId");
return;
}
if (notesConnection.containsKey(noteId)) {
Session connection = notesConnection.get(noteId);
if (connection.isOpen()) {
connection.close();
}
notesConnection.remove(noteId);
}
LOG.info("Removed note websocket connection for note {}", noteId);
}
private void removeAllConnections() {
if (watcherSession != null && watcherSession.isOpen()) {
watcherSession.close();
}
Session noteSession = null;
for (Map.Entry<String, Session> note: notesConnection.entrySet()) {
noteSession = note.getValue();
if (isSessionOpen(noteSession)) {
noteSession.close();
}
}
notesConnection.clear();
}
public void ping() {
if (watcherSession == null) {
LOG.debug("Cannot send PING event, no watcher found");
return;
}
watcherSession.getRemote().sendStringByFuture(serialize(new Message(OP.PING)));
}
/**
* Only used in test.
*/
public int countConnectedNotes() {
return notesConnection.size();
}
}