package org.ohdsi.webapi.shiro; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.BufferedReader; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStreamReader; import java.io.OutputStream; import java.io.OutputStreamWriter; import java.io.PrintWriter; import java.util.logging.Level; import java.util.logging.Logger; import java.util.zip.GZIPInputStream; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletException; import javax.servlet.ServletOutputStream; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.WriteListener; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponseWrapper; import org.apache.shiro.web.util.WebUtils; /** * * @author gennadiy.anisimov */ public abstract class ProcessResponseContentFilter implements Filter { @Override public void init(FilterConfig fc) throws ServletException { } @Override public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { if (!shouldProcess(request, response)) { chain.doFilter(request, response); return; } if (response.getCharacterEncoding() == null) { response.setCharacterEncoding("UTF-8"); } HttpServletResponseCopier responseCopier = new HttpServletResponseCopier((HttpServletResponse) response); chain.doFilter(request, responseCopier); responseCopier.flushBuffer(); byte[] responseBytes = responseCopier.getCopy(); HttpServletResponse httpResponse = WebUtils.toHttp(response); String responseString; String contentEncoding = httpResponse.getHeader("Content-Encoding"); if ("gzip".equalsIgnoreCase(contentEncoding)) { responseString = this.readGZip(responseBytes, response.getCharacterEncoding()); } else { responseString = new String(responseBytes, response.getCharacterEncoding()); } this.processResponseContent(responseString); } private void processResponseContent(String content) { try { this.doProcessResponseContent(content); } catch (Exception ex) { Logger.getLogger(ProcessResponseContentFilter.class.getName()).log(Level.SEVERE, "Failed to process response content", ex); } } protected abstract boolean shouldProcess(ServletRequest request, ServletResponse response); protected abstract void doProcessResponseContent(String content) throws Exception; protected String parseJsonField(String json, String field) throws IOException { ObjectMapper mapper = new ObjectMapper(); JsonNode rootNode = mapper.readValue(json, JsonNode.class); JsonNode fieldNode = rootNode.get(field); String fieldValue = fieldNode.asText(); return fieldValue; } @Override public void destroy() { } private String readGZip(byte[] data, String encoding) { String decompressed = ""; try { GZIPInputStream stream = new GZIPInputStream(new ByteArrayInputStream(data)); BufferedReader reader = new BufferedReader(new InputStreamReader(stream, encoding)); String line; while ((line = reader.readLine()) != null) { decompressed += line; } } catch (IOException ex) { Logger.getLogger(ProcessResponseContentFilter.class.getName()).log(Level.SEVERE, "Failed decompress gzipped response content", ex); } return decompressed; } protected class HttpServletResponseCopier extends HttpServletResponseWrapper { private ServletOutputStream outputStream; private PrintWriter writer; private ServletOutputStreamCopier copier; public HttpServletResponseCopier(HttpServletResponse response) { super(response); } @Override public ServletOutputStream getOutputStream() throws IOException { if (writer != null) { throw new IllegalStateException("getWriter() has already been called on this response."); } if (outputStream == null) { outputStream = getResponse().getOutputStream(); copier = new ServletOutputStreamCopier(outputStream); } return copier; } @Override public PrintWriter getWriter() throws IOException { if (outputStream != null) { throw new IllegalStateException("getOutputStream() has already been called on this response."); } if (writer == null) { copier = new ServletOutputStreamCopier(getResponse().getOutputStream()); writer = new PrintWriter(new OutputStreamWriter(copier, getResponse().getCharacterEncoding()), true); } return writer; } @Override public void flushBuffer() throws IOException { if (writer != null) { writer.flush(); } else if (outputStream != null) { copier.flush(); } } public byte[] getCopy() { if (copier != null) { return copier.getCopy(); } else { return new byte[0]; } } } public class ServletOutputStreamCopier extends ServletOutputStream { private OutputStream outputStream; private ByteArrayOutputStream copy; public ServletOutputStreamCopier(OutputStream outputStream) { this.outputStream = outputStream; this.copy = new ByteArrayOutputStream(1024); } @Override public void write(int b) throws IOException { outputStream.write(b); copy.write(b); } public byte[] getCopy() { return copy.toByteArray(); } @Override public boolean isReady() { return true; } @Override public void setWriteListener(WriteListener wl) { } } }