/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.wink.server.internal.servlet.contentencode;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.List;
import java.util.zip.GZIPInputStream;
import java.util.zip.InflaterInputStream;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.core.HttpHeaders;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A servlet filter that wraps the HttpServletRequest so that a request body
 * sent with a "gzip" or "deflate" Content-Encoding request header is
 * automatically GZIP-decoded or inflated before the rest of the chain reads it.
 * The wrapped request also hides the Content-Encoding header from downstream
 * components. Register it in your web.xml like:
 *
 * <pre>
 * &lt;filter&gt;
 *     &lt;filter-name&gt;ContentEncodingRequestFilter&lt;/filter-name&gt;
 *     &lt;filter-class&gt;org.apache.wink.server.internal.servlet.contentencode.ContentEncodingRequestFilter&lt;/filter-class&gt;
 * &lt;/filter&gt;
 *
 * &lt;filter-mapping&gt;
 *     &lt;filter-name&gt;ContentEncodingRequestFilter&lt;/filter-name&gt;
 *     &lt;url-pattern&gt;/*&lt;/url-pattern&gt;
 * &lt;/filter-mapping&gt;
 * </pre>
 */
public class ContentEncodingRequestFilter implements Filter {

    private static final Logger logger =
        LoggerFactory.getLogger(ContentEncodingRequestFilter.class);

    /** Canonical name of the GZIP content-coding. */
    private static final String GZIP = "gzip"; //$NON-NLS-1$

    /** Legacy alias that HTTP recipients should treat as equivalent to "gzip". */
    private static final String X_GZIP = "x-gzip"; //$NON-NLS-1$

    /** Canonical name of the "deflate" content-coding. */
    private static final String DEFLATE = "deflate"; //$NON-NLS-1$

    /** Charset the Servlet specification requires when a request declares none. */
    private static final String DEFAULT_CHARSET = "ISO-8859-1"; //$NON-NLS-1$

    /**
     * No-op; this filter has no configuration.
     *
     * @param arg0 the filter configuration supplied by the container (only logged)
     */
    public void init(FilterConfig arg0) throws ServletException {
        logger.trace("init({}) entry", arg0); //$NON-NLS-1$
        /* do nothing */
        logger.trace("init() exit"); //$NON-NLS-1$
    }

    /**
     * No-op; this filter holds no resources.
     */
    public void destroy() {
        logger.trace("destroy() entry"); //$NON-NLS-1$
        /* do nothing */
        logger.trace("destroy() exit"); //$NON-NLS-1$
    }

    /**
     * Reads the Content-Encoding request header.
     *
     * @param httpServletRequest the incoming request
     * @return the header value with surrounding whitespace removed, or {@code null}
     *         if the header is absent
     */
    private String getContentEncoding(HttpServletRequest httpServletRequest) {
        String contentEncoding = httpServletRequest.getHeader(HttpHeaders.CONTENT_ENCODING);
        if (contentEncoding == null) {
            return null;
        }
        // Strings are immutable: the trimmed value must be returned, not discarded.
        return contentEncoding.trim();
    }

    /**
     * Maps a Content-Encoding value onto the canonical coding this filter can decode.
     * Content-coding names are case-insensitive in HTTP, and "x-gzip" is accepted as
     * an alias of "gzip".
     *
     * @param contentEncoding a (trimmed) Content-Encoding value; may be {@code null}
     * @return {@link #GZIP}, {@link #DEFLATE}, or {@code null} when the value is
     *         absent or not a supported coding
     */
    private static String toSupportedCoding(String contentEncoding) {
        if (contentEncoding == null) {
            return null;
        }
        if (GZIP.equalsIgnoreCase(contentEncoding) || X_GZIP.equalsIgnoreCase(contentEncoding)) {
            return GZIP;
        }
        if (DEFLATE.equalsIgnoreCase(contentEncoding)) {
            return DEFLATE;
        }
        return null;
    }

    /**
     * Passes the request down the chain. When it is an HTTP request whose
     * Content-Encoding is a supported coding, the chain receives a request wrapper
     * that decodes the body. Otherwise the original request is passed through
     * untouched.
     *
     * @param servletRequest the incoming request
     * @param servletResponse the outgoing response (never wrapped)
     * @param chain the remaining filter chain
     * @throws IOException propagated from the chain
     * @throws ServletException propagated from the chain
     */
    public void doFilter(ServletRequest servletRequest,
                         ServletResponse servletResponse,
                         FilterChain chain) throws IOException, ServletException {
        if (logger.isTraceEnabled()) {
            logger.trace("doFilter({}, {}, {}) entry", new Object[] {servletRequest, //$NON-NLS-1$
                servletResponse, chain});
        }
        if (servletRequest instanceof HttpServletRequest
            && servletResponse instanceof HttpServletResponse) {
            HttpServletRequest httpServletRequest = (HttpServletRequest)servletRequest;
            String contentEncoding = getContentEncoding(httpServletRequest);
            logger.trace("Content-Encoding was {}", contentEncoding); //$NON-NLS-1$
            String coding = toSupportedCoding(contentEncoding);
            if (coding != null) {
                logger
                    .trace("Wrapping HttpServletRequest because Content-Encoding was set to gzip or deflate"); //$NON-NLS-1$
                HttpServletRequest wrappedRequest =
                    new HttpServletRequestContentEncodingWrapperImpl(httpServletRequest, coding);
                logger.trace("Invoking chain with wrapped HttpServletRequest"); //$NON-NLS-1$
                chain.doFilter(wrappedRequest, servletResponse);
                logger.trace("doFilter exit()"); //$NON-NLS-1$
                return;
            }
        }
        logger
            .trace("Invoking normal chain since Content-Encoding request header was not understood"); //$NON-NLS-1$
        chain.doFilter(servletRequest, servletResponse);
        logger.trace("doFilter exit()"); //$NON-NLS-1$
    }

    /**
     * A ServletInputStream that delegates every operation to a wrapped (decoding)
     * InputStream.
     * <p>
     * {@code readLine} is deliberately NOT overridden. The inherited
     * ServletInputStream implementation reads through {@link #read()} and stops
     * after a '\n', which is its documented contract. Delegating it to a bulk
     * {@code read(byte[], int, int)} would ignore line terminators.
     */
    static class DecoderServletInputStream extends ServletInputStream {

        private final InputStream is;

        /**
         * @param is the stream every call is delegated to
         */
        public DecoderServletInputStream(InputStream is) {
            this.is = is;
        }

        @Override
        public int available() throws IOException {
            return is.available();
        }

        @Override
        public void close() throws IOException {
            is.close();
        }

        @Override
        public synchronized void mark(int readlimit) {
            is.mark(readlimit);
        }

        @Override
        public boolean markSupported() {
            return is.markSupported();
        }

        @Override
        public int read() throws IOException {
            return is.read();
        }

        @Override
        public int read(byte[] b, int off, int len) throws IOException {
            return is.read(b, off, len);
        }

        @Override
        public int read(byte[] b) throws IOException {
            return is.read(b);
        }

        @Override
        public synchronized void reset() throws IOException {
            is.reset();
        }

        @Override
        public long skip(long n) throws IOException {
            return is.skip(n);
        }
    }

    /**
     * Decodes a GZIP-encoded body. The GZIPInputStream constructor reads the GZIP
     * header eagerly, which is why this constructor can throw IOException.
     */
    static class GZIPDecoderInputStream extends DecoderServletInputStream {
        public GZIPDecoderInputStream(InputStream is) throws IOException {
            super(new GZIPInputStream(is));
        }
    }

    /**
     * Decodes a "deflate"-encoded body using a default InflaterInputStream.
     */
    static class InflaterDecoderInputStream extends DecoderServletInputStream {
        public InflaterDecoderInputStream(InputStream is) {
            super(new InflaterInputStream(is));
        }
    }

    /**
     * Request wrapper that exposes the decoded body through both
     * {@link #getInputStream()} and {@link #getReader()}, and hides the
     * Content-Encoding header from downstream components.
     * <p>
     * NOTE(review): Content-Length is not overridden, so it still reports the
     * length of the encoded body as sent by the client. Do not rely on it to size
     * the decoded stream.
     */
    static class HttpServletRequestContentEncodingWrapperImpl extends HttpServletRequestWrapper {

        /** Lazily created decoding stream; cached so every caller shares one stream. */
        private ServletInputStream inputStream;

        /** Lazily created reader over {@link #inputStream}. */
        private BufferedReader reader;

        private final String contentEncoding;

        /**
         * @param request the request to wrap
         * @param contentEncoding the content-coding of the body ("gzip", "x-gzip" or
         *            "deflate", matched case-insensitively); any other value leaves
         *            the body undecoded
         */
        public HttpServletRequestContentEncodingWrapperImpl(HttpServletRequest request,
                                                            String contentEncoding) {
            super(request);
            this.contentEncoding = contentEncoding;
        }

        /**
         * Returns the decoded request body, creating the decoder on first use.
         *
         * @return the decoding stream (the same instance on every call)
         * @throws IOException if the underlying stream cannot be obtained or, for
         *             GZIP, if the GZIP header cannot be read
         */
        @Override
        public ServletInputStream getInputStream() throws IOException {
            logger.trace("getInputStream() entry"); //$NON-NLS-1$
            if (inputStream == null) {
                ServletInputStream raw = super.getInputStream();
                String coding = toSupportedCoding(contentEncoding);
                ServletInputStream decoded;
                if (GZIP.equals(coding)) {
                    logger.trace("Wrapping ServletInputStream with GZIPDecoder"); //$NON-NLS-1$
                    decoded = new GZIPDecoderInputStream(raw);
                } else if (DEFLATE.equals(coding)) {
                    logger.trace("Wrapping ServletInputStream with Inflater"); //$NON-NLS-1$
                    decoded = new InflaterDecoderInputStream(raw);
                } else {
                    decoded = raw;
                }
                // Cache only after the decoder was built successfully. If the GZIP
                // header read throws, a retry must not hand out the raw, still-encoded stream.
                inputStream = decoded;
            }
            logger.trace("getInputStream() exit - returning {}", inputStream); //$NON-NLS-1$
            return inputStream;
        }

        /**
         * Returns a reader over the decoded body. Without this override,
         * {@code getReader()} would fall through to the wrapped request and expose
         * the still-encoded bytes.
         *
         * @return a reader over the decoded body, using the request's character
         *         encoding or ISO-8859-1 when none is declared
         * @throws java.io.UnsupportedEncodingException if the declared character
         *             encoding is not supported
         * @throws IOException if the decoded stream cannot be created
         */
        @Override
        public BufferedReader getReader() throws IOException {
            if (reader == null) {
                String charset = getCharacterEncoding();
                if (charset == null) {
                    charset = DEFAULT_CHARSET;
                }
                reader = new BufferedReader(new InputStreamReader(getInputStream(), charset));
            }
            return reader;
        }

        /**
         * Hides the Content-Encoding header because the body is decoded.
         */
        @Override
        public String getHeader(String name) {
            if (HttpHeaders.CONTENT_ENCODING.equalsIgnoreCase(name)) {
                return null;
            }
            return super.getHeader(name);
        }

        /**
         * Hides the Content-Encoding header because the body is decoded.
         */
        @SuppressWarnings("unchecked")
        @Override
        public Enumeration<String> getHeaders(String name) {
            if (HttpHeaders.CONTENT_ENCODING.equalsIgnoreCase(name)) {
                // A proper empty enumeration: nextElement() throws NoSuchElementException
                // per the Enumeration contract, instead of returning null.
                return Collections.enumeration(Collections.<String> emptyList());
            }
            return super.getHeaders(name);
        }

        /**
         * Returns the wrapped request's header names minus Content-Encoding.
         *
         * @return the filtered names, or {@code null} if the wrapped request returns
         *         {@code null} (the Servlet API allows this when header access is
         *         not permitted)
         */
        @SuppressWarnings("unchecked")
        @Override
        public Enumeration getHeaderNames() {
            final Enumeration<String> headers = super.getHeaderNames();
            if (headers == null) {
                return null;
            }
            List<String> httpHeaders = new ArrayList<String>();
            while (headers.hasMoreElements()) {
                String header = headers.nextElement();
                if (!HttpHeaders.CONTENT_ENCODING.equalsIgnoreCase(header)) {
                    httpHeaders.add(header);
                }
            }
            return Collections.enumeration(httpHeaders);
        }
    }
}