package io.lumify.analystsNotebook.aggregateClassification; import io.lumify.core.config.Configuration; import io.lumify.core.exception.LumifyException; import io.lumify.core.model.properties.LumifyProperties; import io.lumify.core.util.LumifyLogger; import io.lumify.core.util.LumifyLoggerFactory; import io.lumify.web.clientapi.model.VisibilityJson; import org.apache.commons.io.IOUtils; import org.securegraph.Vertex; import javax.net.ssl.*; import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.net.HttpURLConnection; import java.net.MalformedURLException; import java.net.URL; import java.security.GeneralSecurityException; import java.security.KeyStore; import java.security.KeyStoreException; import java.security.NoSuchAlgorithmException; import java.security.cert.CertificateException; import java.util.HashSet; import java.util.Set; public class AggregateClassificationClient { private static final LumifyLogger LOGGER = LumifyLoggerFactory.getLogger(AggregateClassificationClient.class); private AggregateClassificationConfiguration aggregateClassificationConfiguration; public AggregateClassificationClient(Configuration configuration) { aggregateClassificationConfiguration = new AggregateClassificationConfiguration(); configuration.setConfigurables(aggregateClassificationConfiguration, AggregateClassificationConfiguration.CONFIGURATION_PREFIX); } public String getClassificationBanner(Iterable<Vertex> vertices) { if (!aggregateClassificationConfiguration.isServiceConfigured()) { return null; } String[] visibilitySources = getUniqueVisibilitySources(vertices); try { URL url = getURL(visibilitySources); LOGGER.debug("aggregate classification request url is: %s", url); HttpURLConnection httpConnection; if (url.getProtocol().equalsIgnoreCase("https")) { HttpsURLConnection httpsConnection = (HttpsURLConnection) url.openConnection(); if (aggregateClassificationConfiguration.isTrustStoreConfigured()) { LOGGER.debug("configuring SSLSocketFactory with custom TrustManager for https connection"); httpsConnection.setSSLSocketFactory(getSSLSocketFactory()); } if (aggregateClassificationConfiguration.isDisableHostnameVerification()) { LOGGER.debug("disabling host name verification for https connection"); httpsConnection.setHostnameVerifier(getHostnameVerifier()); } httpsConnection.connect(); httpConnection = httpsConnection; } else { httpConnection = (HttpURLConnection) url.openConnection(); httpConnection.connect(); } int responseCode = httpConnection.getResponseCode(); if (responseCode != HttpURLConnection.HTTP_OK) { String responseMessage = httpConnection.getResponseMessage(); throw new LumifyException(responseCode + " (" + responseMessage + ") while accessing: " + aggregateClassificationConfiguration.getServiceUrl()); } String content = IOUtils.toString(httpConnection.getInputStream(), httpConnection.getContentEncoding()); LOGGER.debug("aggregate classification response content is: %s", content); return content; } catch (Exception e) { throw new LumifyException("exception while making the aggregate classification request", e); } } private String[] getUniqueVisibilitySources(Iterable<Vertex> vertices) { Set<String> visibilitySourceSet = new HashSet<String>(); for (Vertex vertex : vertices) { VisibilityJson visibilityJson = LumifyProperties.VISIBILITY_JSON.getPropertyValue(vertex); if (visibilityJson != null) { String visibilitySource = visibilityJson.getSource(); if (visibilitySource != null && visibilitySource.trim().length() > 0) { visibilitySourceSet.add(visibilitySource); } } } return visibilitySourceSet.toArray(new String[visibilitySourceSet.size()]); } private URL getURL(String[] visibilitySources) throws MalformedURLException { String serviceUrl = aggregateClassificationConfiguration.getServiceUrl(); String parameterName = aggregateClassificationConfiguration.getParameterName(); StringBuilder sb = new StringBuilder(); sb.append(serviceUrl); for (int i = 0; i < visibilitySources.length; i++) { sb.append(i == 0 ? "?" : "&").append(parameterName).append("=").append(visibilitySources[i]); } return new URL(sb.toString()); } private SSLSocketFactory getSSLSocketFactory() throws GeneralSecurityException, IOException { KeyManager[] keyManagers = new KeyManager[]{}; TrustManager[] trustManagers = getTrustManagers(); SSLContext sslContext = SSLContext.getInstance("TLSv1"); sslContext.init(keyManagers, trustManagers, null); return sslContext.getSocketFactory(); } private TrustManager[] getTrustManagers() throws KeyStoreException, IOException, CertificateException, NoSuchAlgorithmException { File trustStoreFile = new File(aggregateClassificationConfiguration.getTrustStorePath()); FileInputStream trustStoreFileInputStream = new FileInputStream(trustStoreFile); KeyStore trustStore = KeyStore.getInstance("JKS"); // TODO: choose the type by file extension char[] trustStorePassword = aggregateClassificationConfiguration.getTrustStorePassword().toCharArray(); trustStore.load(trustStoreFileInputStream, trustStorePassword); TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); trustManagerFactory.init(trustStore); return trustManagerFactory.getTrustManagers(); } private HostnameVerifier getHostnameVerifier() { HostnameVerifier hostnameVerifier = new HostnameVerifier() { @Override public boolean verify(String s, SSLSession sslSession) { return true; } }; return hostnameVerifier; } }