/** * Copyright (c) Codice Foundation * <p> * This is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser * General Public License as published by the Free Software Foundation, either version 3 of the * License, or any later version. * <p> * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without * even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. A copy of the GNU Lesser General Public License * is distributed along with this program and can be found at * <http://www.gnu.org/licenses/lgpl.html>. */ package org.codice.ddf.itests.common.cometd; import static java.time.format.DateTimeFormatter.ISO_DATE_TIME; import java.net.ConnectException; import java.net.URI; import java.security.KeyManagementException; import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.time.LocalDateTime; import java.time.LocalTime; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.SSLContext; import javax.net.ssl.TrustManager; import javax.net.ssl.X509TrustManager; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang.StringUtils; import org.cometd.bayeux.Channel; import org.cometd.bayeux.Message; import org.cometd.bayeux.client.ClientSessionChannel; import org.cometd.client.BayeuxClient; import org.cometd.client.transport.ClientTransport; import org.cometd.client.transport.LongPollingTransport; import org.eclipse.jetty.client.HttpClient; import org.eclipse.jetty.client.util.BasicAuthentication; import org.eclipse.jetty.util.ssl.SslContextFactory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.jayway.restassured.path.json.JsonPath; /** * CometD client used to listen for messages on CometD channels and interact with DDF's CometD * endpoint. * <p> * Below is an example on how to listen for messages on the notifications channel: * <p> * <pre> * // Creates a CometDClient that will connect to the CometD Server at the specified URL. * CometDClient cometDClient = new CometDClient(cometDEndpointUrl); * * // Starts the cometDClient and performs the initial handshake with the CometD server * cometDClient.start(); * * // Subscribes to the notifications channel * cometDClient.subscribe("/ddf/notifications/**")); * * // Retrieves messages on all subscribed channels (in this example, receives messages on * // on the notifications channel). * List<String> messages = cometDClient.getAllMessages(); * * // Shutdown the cometD Client and un-subscribes from all channels * cometDClient.shutdown(); * </pre> */ public class CometDClient { private static final Logger LOGGER = LoggerFactory.getLogger(CometDClient.class); private static final long TIMEOUT = TimeUnit.SECONDS.toMillis(60); private static final long MAX_NETWORK_DELAY = 60000; private static final String QUERY_PUBLISH_CHANNEL = "/service/query"; private final BayeuxClient bayeuxClient; private final HttpClient httpClient; private final List<MessageListener> messageListeners = new ArrayList<>(); private static final String ACTIVITIES_CHANNEL = "/ddf/activities/**"; private static final String DATA_MESSAGE = "data.message"; private static final String DATA_ID = "data.id"; private static final String DOWNLOAD_CANCELLED = "Resource retrieval cancelled"; private static final String DOWNLOAD_COMPLETED = "Resource retrieval completed"; private static final String RETRIEVAL_FAILURE = "Unable to retrieve"; private Set<String> downloadIds = new HashSet<>(); /** * Creates a CometD client without authentication. * * @param url CometD endpoint * @throws Exception thrown if client setup fails */ public CometDClient(String url) throws Exception { SslContextFactory sslContextFactory = new SslContextFactory(true); httpClient = new HttpClient(sslContextFactory); doTrustAllCertificates(); ClientTransport transport = new LongPollingTransport(new HashMap<>(), httpClient); transport.setOption(ClientTransport.MAX_NETWORK_DELAY_OPTION, MAX_NETWORK_DELAY); bayeuxClient = new BayeuxClient(url, transport); } /** * Creates a CometD client with authentication. * * @param url CometD endpoint * @param realm security realm * @param username user name * @param password password * @throws Exception thrown if client setup fails */ public CometDClient(String url, String realm, String username, String password) throws Exception { this(url); URI uri = new URI(url); httpClient.getAuthenticationStore() .addAuthentication(new BasicAuthentication(uri, realm, username, password)); } /** * Starts the client. * * @throws Exception thrown if the client fails to start */ public void start() throws Exception { httpClient.start(); LOGGER.debug("HTTP client started: {}", httpClient.isStarted()); MessageListener handshakeListener = new MessageListener(Channel.META_HANDSHAKE); bayeuxClient.getChannel(Channel.META_HANDSHAKE) .addListener(handshakeListener); bayeuxClient.handshake(); boolean connected = bayeuxClient.waitFor(TIMEOUT, BayeuxClient.State.CONNECTED); if (!connected) { shutdownHttpClient(); String message = String.format("%s failed to connect to the server at %s", this.getClass() .getName(), bayeuxClient.getURL()); LOGGER.error(message); throw new ConnectException(message); } } /** * Publishes a message * * @param channel channel to publish message to * @param message message to publish */ public void publish(String channel, Map<String, Object> message) { LOGGER.debug("Publishing message {} to channel {}", message, channel); bayeuxClient.getChannel(channel) .publish(message, (responseChannel, responseMessage) -> LOGGER.debug( "Response {} received for message {} on channel {}", responseMessage.getJSON(), responseChannel, message)); } /** * Subscribes to a channel. Subscribing to the same channel multiple times has no effect. * * @param channel channel name */ public void subscribe(String channel) { verifyConnected(); if (!alreadySubscribed(channel)) { ClientSessionChannel clientSessionChannel = bayeuxClient.getChannel(channel); MessageListener messageListener = new MessageListener(channel); clientSessionChannel.subscribe(messageListener); messageListeners.add(messageListener); } else { LOGGER.debug("Already subscribed to channel {}", channel); } } /** * Gets the first message that matches the search criterion * * @param searchCriterion a string that will be searched for in the messages * @return the desired message if found */ public Optional<String> searchMessages(String searchCriterion) { List<String> messages = getAllMessages(); return messages.stream() .filter(query -> query.contains(searchCriterion)) .findFirst(); } /** * Gets the list of messages received on a given channel. * * @param channel channel name * @return list of message received since the client was started */ public List<String> getMessages(String channel) { verifyConnected(); verifySubscribed(); return messageListeners.stream() .filter(l -> l.getChannel() .equals(channel)) .flatMap(l -> l.getMessages() .stream()) .collect(Collectors.toList()); } /** * Gets the list of messages received on a given channel in time ascending order, i.e., from * oldest to most recent. * * @param channel channel name * @return list of message received since the client was started */ public List<String> getMessagesInAscOrder(String channel) { List<String> messages = getMessages(channel); Collections.sort(messages, new AscendingTimestampComparator()); return messages; } /** * Gets the CometD client ID. * * @return CometD client ID */ public String getClientId() { verifyConnected(); return bayeuxClient.getId(); } /** * Gets the list of messages received on all channels. * * @return list of message received since the client was started */ public List<String> getAllMessages() { verifyConnected(); verifySubscribed(); return messageListeners.stream() .flatMap(l -> l.getMessages() .stream()) .collect(Collectors.toList()); } /** * Gets the list of messages received on all channels in time ascending order, i.e., from * oldest to most recent. * * @return list of message received since the client was started */ public List<String> getAllMessagesInAscOrder() { List<String> messages = getAllMessages(); Collections.sort(messages, new AscendingTimestampComparator()); return messages; } /** * Publishes a search message for a specific metacard ID. * * @param responseChannel ID of the channel where the response should be sent * @param source source to query * @param metacardId ID of the metacard to retrieve */ public void searchByMetacardId(String responseChannel, String source, String metacardId) { Map<String, Object> data = new HashMap<>(); data.put("cql", String.format("(\"anyText\" ILIKE '%s')", metacardId)); data.put("id", responseChannel); data.put("federation", "enterprise"); data.put("src", source); data.put("radiusUnits", "meters"); data.put("count", 250L); data.put("start", 1L); data.put("format", "geojson"); data.put("scheduleUnits", "minutes"); data.put("timeType", "modified"); data.put("locationType", "latlon"); data.put("sort", "modified:desc"); data.put("q", metacardId); data.put("sortOrder", "desc"); data.put("sortField", "modified"); data.put("radius", "0"); publish(QUERY_PUBLISH_CHANNEL, data); } /** * Cancels a resource download. * * @param downloadId ID of the download to cancel */ public void cancelDownload(String downloadId) { Map<String, Object> jsonMap = new HashMap<>(); List<Map<String, Object>> data = new ArrayList<>(); Map<String, Object> dataMap = new HashMap<>(); dataMap.put("id", downloadId); dataMap.put("action", "cancel"); data.add(dataMap); jsonMap.put("data", data); publish("/service/action", jsonMap); } /** * Cancels all resource downloads. */ public void cancelAllDownloads() { downloadIds.stream() .forEach(this::cancelDownload); } /** * Un-subscribes from a channel. Un-subscribing from the same channel multiple has no effect. * * @param channel channel name */ public void unsubscribe(String channel) { verifyConnected(); Optional<MessageListener> optionalMessageListener = messageListeners.stream() .filter(l -> l.getChannel() .equals(channel)) .findFirst(); optionalMessageListener.ifPresent((messageListener) -> { bayeuxClient.getChannel(channel) .unsubscribe(messageListener); messageListeners.remove(messageListener); }); } /** * Un-subscribes from all channels. */ public void unsubscribeFromAllChannels() { verifyConnected(); messageListeners.forEach(l -> bayeuxClient.getChannel(l.getChannel()) .unsubscribe(l)); messageListeners.clear(); } /** * Shuts down the client. * * @throws Exception thrown if the shutdown fails */ public void shutdown() throws Exception { verifyConnected(); LOGGER.debug("{} is shutting down!", this.getClass() .getName()); unsubscribeFromAllChannels(); httpClient.stop(); bayeuxClient.disconnect(); bayeuxClient.waitFor(TIMEOUT, BayeuxClient.State.DISCONNECTED); } private void verifyConnected() { if (!bayeuxClient.isConnected()) { String message = String.format("%s has not connected to the server at %s", this.getClass() .getName(), bayeuxClient.getURL()); LOGGER.error(message); throw new IllegalStateException(message); } } private void verifySubscribed() { if (CollectionUtils.isEmpty(messageListeners)) { String message = String.format("%s is not subscribed to any channels", this.getClass() .getName()); throw new IllegalStateException(message); } } private void shutdownHttpClient() throws Exception { if (!httpClient.isStopped()) { LOGGER.debug("Stopping http client."); httpClient.stop(); } } private void doTrustAllCertificates() throws NoSuchAlgorithmException, KeyManagementException { TrustManager[] trustAllCerts = new TrustManager[] {new X509TrustManager() { @Override public void checkClientTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { } @Override public void checkServerTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { } @Override public X509Certificate[] getAcceptedIssuers() { return null; } }}; SSLContext sslContext = SSLContext.getInstance("SSL"); sslContext.init(null, trustAllCerts, new SecureRandom()); HttpsURLConnection.setDefaultSSLSocketFactory(sslContext.getSocketFactory()); HostnameVerifier hostnameVerifier = (s, sslSession) -> s.equalsIgnoreCase(sslSession.getPeerHost()); HttpsURLConnection.setDefaultHostnameVerifier(hostnameVerifier); } private boolean alreadySubscribed(String channel) { return messageListeners.stream() .filter(l -> l.getChannel() .equals(channel)) .count() == 1; } public Set<String> getDownloadIds() { return Collections.unmodifiableSet(downloadIds); } private class MessageListener implements org.cometd.bayeux.client.ClientSessionChannel.MessageListener { private final List<String> messages; private final String channel; MessageListener(String channel) { messages = Collections.synchronizedList(new ArrayList<>()); this.channel = channel; } @Override public void onMessage(ClientSessionChannel channel, Message message) { LOGGER.debug("On channel {} received message {}", channel, message.getJSON()); LOGGER.debug("timestamp of message: {}", message.getDataAsMap() .get("timestamp")); JsonPath jsonPath = JsonPath.from(message.getJSON() .toString()); String dataMessage = jsonPath.getString(DATA_MESSAGE); if (channel.getChannelId() .toString() .equals(ACTIVITIES_CHANNEL)) { String downloadId = jsonPath.getString(DATA_ID); if (StringUtils.isNotEmpty(downloadId)) { if (dataMessage.contains(DOWNLOAD_CANCELLED) || dataMessage.contains( DOWNLOAD_COMPLETED) || dataMessage.contains(RETRIEVAL_FAILURE)) { downloadIds.remove(downloadId); } else { downloadIds.add(downloadId); } } } messages.add(message.getJSON()); } private String getChannel() { return channel; } private List<String> getMessages() { return messages; } } private class AscendingTimestampComparator implements Comparator<String> { private static final String TIMESTAMP_PATH = "data.timestamp"; @Override public int compare(String jsonMessage1, String jsonMessage2) { LocalTime time1 = getLocalTime(jsonMessage1); LocalTime time2 = getLocalTime(jsonMessage2); return time1.compareTo(time2); } private LocalTime getLocalTime(String jsonMessage) { JsonPath jsonPath = JsonPath.from(jsonMessage); String timestamp = jsonPath.getString(TIMESTAMP_PATH); return LocalDateTime.parse(timestamp, ISO_DATE_TIME) .toLocalTime(); } } }