package org.ovirt.engine.core.uutils.servlet; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.HttpURLConnection; import java.net.MalformedURLException; import java.net.URL; import java.security.GeneralSecurityException; import java.security.KeyStore; import java.util.List; import java.util.Map; import javax.net.ssl.TrustManagerFactory; import javax.servlet.ServletConfig; import javax.servlet.ServletException; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.ovirt.engine.core.uutils.net.HttpURLConnectionBuilder; public class ProxyServletBase extends HttpServlet { private static final long serialVersionUID = 5331291232426186121L; private Boolean verifyHost = true; private Boolean verifyChain = true; private String httpsProtocol; private String trustManagerAlgorithm; private String trustStore; private String trustStoreType; private String trustStorePassword = "changeit"; private Integer readTimeout; private String url; protected static long copy(final InputStream input, final OutputStream output) throws IOException { final byte[] buffer = new byte[8*1024]; long count = 0; int n; while ((n = input.read(buffer)) != -1) { output.write(buffer, 0, n); count += n; } return count; } protected void setVerifyHost(Boolean verifyHost) { this.verifyHost = verifyHost; } protected void setVerifyChain(Boolean verifyChain) { this.verifyChain = verifyChain; } protected void setHttpsProtocol(String httpsProtocol) { this.httpsProtocol = httpsProtocol; } protected void setTrustManagerAlgorithm(String trustManagerAlgorithm) { this.trustManagerAlgorithm = trustManagerAlgorithm; } protected void setTrustStore(String trustStore) { this.trustStore = trustStore; } protected void setTrustStoreType(String trustStoreType) { this.trustStoreType = trustStoreType; } protected void setTrustStorePassword(String trustStorePassword) { this.trustStorePassword = trustStorePassword; } protected void setReadTimeout(Integer readTimeout) { this.readTimeout = readTimeout; } protected void setUrl(String url) { this.url = url; } protected HttpURLConnection create(URL url) throws IOException, GeneralSecurityException { return new HttpURLConnectionBuilder(url).setHttpsProtocol(httpsProtocol) .setReadTimeout(readTimeout) .setTrustManagerAlgorithm(trustManagerAlgorithm) .setTrustStore(trustStore) .setTrustStorePassword(trustStorePassword) .setTrustStoreType(trustStoreType) .setURL(url) .setVerifyChain(verifyChain) .setVerifyHost(verifyHost).create(); } private String mergeQuery(String url, String queryString) throws MalformedURLException { String ret = url; if (queryString != null) { URL u = new URL(ret); if (u.getQuery() == null) { ret += "?"; } else { ret += "&"; } ret += queryString; } return ret; } @Override public void init(ServletConfig config) throws ServletException { super.init(config); try { if (verifyHost == null) { verifyHost = true; } if (verifyChain == null) { verifyChain = true; } if (trustManagerAlgorithm == null) { trustManagerAlgorithm = TrustManagerFactory.getDefaultAlgorithm(); } if (trustStoreType == null) { trustStoreType = KeyStore.getDefaultType(); } if (httpsProtocol == null) { httpsProtocol = "TLSv1"; } } catch (Exception e) { throw new ServletException(e); } } @Override protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { if (url == null) { response.sendError(response.SC_NOT_FOUND, "Cannot proxy, no URL is configured."); } else { HttpURLConnection connection = null; try { connection = create(new URL(mergeQuery(url, request.getQueryString()))); connection.setDoInput(true); connection.setDoOutput(false); response.setStatus(connection.getResponseCode()); for (Map.Entry<String, List<String>> entry : connection.getHeaderFields().entrySet()) { if (entry.getKey() != null) { boolean first = true; for (String value : entry.getValue()) { if (first) { first = false; response.setHeader(entry.getKey(), value); } else { response.addHeader(entry.getKey(), value); } } } } copy(connection.getInputStream(), response.getOutputStream()); connection.connect(); } catch (Exception e) { throw new ServletException(e); } finally { if (connection != null) { connection.disconnect(); } } } } }