/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.usergrid.rest.filters;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PushbackInputStream;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.ws.rs.HttpMethod;
import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.MediaType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.Assert;
/**
 * Filter for setting default Accept and Content-Type headers to application/json when they are
 * undefined (or unusable) from the client.
 * <p>
 * The request is wrapped so that overridden headers are visible through {@code getHeader},
 * {@code getHeaders}, {@code getHeaderNames} and {@code getContentType}. Header names are matched
 * case-insensitively, as HTTP requires. To decide whether the body looks like JSON, the first byte
 * of the request body is read and then pushed back, so downstream consumers still see the complete
 * body.
 *
 * @author tnine
 */
public class ContentTypeFilter implements Filter {

    private static final Logger logger = LoggerFactory.getLogger( ContentTypeFilter.class );

    /**
     * Charset used by {@code getReader()} when the request declares none. JSON text exchanged
     * between systems is required to be UTF-8 (RFC 8259).
     */
    private static final String DEFAULT_READER_ENCODING = "UTF-8";


    /*
     * (non-Javadoc)
     *
     * @see javax.servlet.Filter#init(javax.servlet.FilterConfig)
     */
    @Override
    public void init( FilterConfig filterConfig ) throws ServletException {
        if ( logger.isTraceEnabled() ) {
            logger.trace( "Starting content type filter" );
        }
    }


    /**
     * Wraps HTTP requests with default JSON headers and passes them down the chain. Non-HTTP
     * requests are passed through untouched.
     *
     * @see javax.servlet.Filter#doFilter(javax.servlet.ServletRequest,
     * javax.servlet.ServletResponse, javax.servlet.FilterChain)
     */
    @Override
    public void doFilter( ServletRequest request, ServletResponse response, FilterChain chain )
            throws IOException, ServletException {

        if ( !( request instanceof HttpServletRequest ) ) {
            chain.doFilter( request, response );
            return;
        }

        HeaderWrapperRequest newRequest = new HeaderWrapperRequest( ( HttpServletRequest ) request );
        newRequest.adapt();

        try {
            chain.doFilter( newRequest, response );
        }
        catch ( Exception e ) {
            // NOTE(review): every downstream failure is logged and swallowed here, so the container
            // never sees it (no 500 / error page). Kept to preserve existing behavior - confirm
            // this is intentional.
            logger.error( "Error in filter", e );
        }
    }


    /*
     * (non-Javadoc)
     *
     * @see javax.servlet.Filter#destroy()
     */
    @Override
    public void destroy() {
        // nothing to release
    }


    /**
     * Normalizes a header name into the key used for the override map. Header names are
     * case-insensitive, so all lookups and stores go through this method.
     */
    private static String headerKey( String name ) {
        return name.toLowerCase( Locale.ROOT );
    }


    /**
     * Returns whether {@code contentType} denotes the media type {@code expected}. Only the
     * type/subtype is compared (case-insensitively); parameters such as {@code charset} are ignored.
     *
     * @param contentType the raw Content-Type value, may be {@code null}
     * @param expected a bare media type such as {@code application/json}
     *
     * @return {@code true} if the type/subtype portions match
     */
    private static boolean isMediaType( String contentType, String expected ) {
        if ( contentType == null ) {
            return false;
        }
        int paramStart = contentType.indexOf( ';' );
        String bare = paramStart >= 0 ? contentType.substring( 0, paramStart ) : contentType;
        return bare.trim().equalsIgnoreCase( expected );
    }


    /**
     * Request wrapper that layers header overrides on top of the original request and exposes a
     * pushback-capable body stream.
     */
    private static class HeaderWrapperRequest extends HttpServletRequestWrapper {

        /** Request body with one byte of pushback, so {@link #adapt()} can peek at the first byte. */
        private final PushbackInputStream inputStream;

        /** Servlet view over {@link #inputStream}; returned from {@link #getInputStream()}. */
        private final ServletInputStream servletInputStream;

        private final HttpServletRequest origRequest;

        /** Created on the first {@link #getReader()} call and reused afterwards. */
        private BufferedReader reader = null;

        /** Header overrides, keyed by {@link ContentTypeFilter#headerKey(String)}. */
        private final Map<String, String> newHeaders = new HashMap<String, String>();


        /**
         * @param request the request to wrap
         *
         * @throws IOException if the request's input stream cannot be obtained
         */
        public HeaderWrapperRequest( HttpServletRequest request ) throws IOException {
            super( request );
            origRequest = request;
            inputStream = new PushbackInputStream( request.getInputStream() );
            servletInputStream = new DelegatingServletInputStream( inputStream );
        }


        /**
         * Installs the default header overrides.
         * <ul>
         * <li>If the client sent no Accept header, Accept is set to application/json.</li>
         * <li>If there is no body and the method is POST or PUT (and not form-urlencoded),
         * Accept and Content-Type are set to application/json.</li>
         * <li>If the body starts with '{' or '[' and the Content-Type is not already
         * application/json, Accept and Content-Type are set to application/json.</li>
         * </ul>
         *
         * @throws IOException if the first body byte cannot be read or pushed back
         */
        private void adapt() throws IOException {
            String path = origRequest.getRequestURI();
            String method = origRequest.getMethod();

            if ( logger.isTraceEnabled() ) {
                logger.trace( "Content path is '{}'", path );
            }

            // first ensure that an Accept header is set (getHeaders may return null when the
            // container does not allow header access)
            @SuppressWarnings( "rawtypes" ) Enumeration acceptHeaders = origRequest.getHeaders( HttpHeaders.ACCEPT );
            if ( acceptHeaders == null || !acceptHeaders.hasMoreElements() ) {
                setHeader( HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON );
            }

            // next, ensure that one and only one content-type is set; peek at the first body
            // byte, which is pushed back below so the body stays intact
            int initial = inputStream.read();

            if ( initial == -1 ) {
                // request has no body, set type to application/json
                if ( ( HttpMethod.POST.equals( method ) || HttpMethod.PUT.equals( method ) )
                        && !isMediaType( getContentType(), MediaType.APPLICATION_FORM_URLENCODED ) ) {
                    if ( logger.isTraceEnabled() ) {
                        logger.trace( "Setting content type to application/json for POST or PUT with no content at path '{}'", path );
                    }
                    setHeader( HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON );
                    setHeader( HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON );
                }
                return;
            }

            char firstChar = ( char ) initial;
            if ( ( firstChar == '{' || firstChar == '[' )
                    && !isMediaType( getContentType(), MediaType.APPLICATION_JSON ) ) {
                // request appears to be JSON so set type to application/json
                if ( logger.isTraceEnabled() ) {
                    logger.trace( "Setting content type to application/json for POST or PUT with json content at path '{}'", path );
                }
                setHeader( HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON );
                setHeader( HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON );
            }

            inputStream.unread( initial );
        }


        /**
         * Overrides (or adds) a header; the name is matched case-insensitively.
         *
         * @param name header name
         * @param value header value
         */
        public void setHeader( String name, String value ) {
            newHeaders.put( headerKey( name ), value );
        }


        /** Returns the override for {@code name} (case-insensitive), or {@code null} if none. */
        private String lookupOverride( String name ) {
            return name == null ? null : newHeaders.get( headerKey( name ) );
        }


        /*
         * (non-Javadoc)
         *
         * @see javax.servlet.http.HttpServletRequestWrapper#getHeader(java.lang.String)
         */
        @Override
        public String getHeader( String name ) {
            String overridden = lookupOverride( name );
            return overridden != null ? overridden : super.getHeader( name );
        }


        /*
         * (non-Javadoc)
         *
         * @see javax.servlet.http.HttpServletRequestWrapper#getHeaders(java.lang.String)
         */
        @Override
        @SuppressWarnings( "rawtypes" )
        public Enumeration getHeaders( String name ) {
            String overridden = lookupOverride( name );
            if ( overridden != null ) {
                // an override replaces all original values of the header
                return Collections.enumeration( Collections.singletonList( overridden ) );
            }
            return super.getHeaders( name );
        }


        /**
         * Returns the original header names plus any overridden names that the original request
         * did not already carry (compared case-insensitively, so no header is listed twice).
         *
         * @see javax.servlet.http.HttpServletRequestWrapper#getHeaderNames()
         */
        @Override
        @SuppressWarnings( "rawtypes" )
        public Enumeration getHeaderNames() {
            Set<String> names = new LinkedHashSet<String>();
            Set<String> presentKeys = new HashSet<String>();

            Enumeration original = super.getHeaderNames();
            if ( original != null ) {
                while ( original.hasMoreElements() ) {
                    String headerName = original.nextElement().toString();
                    names.add( headerName );
                    presentKeys.add( headerKey( headerName ) );
                }
            }

            for ( String key : newHeaders.keySet() ) {
                if ( !presentKeys.contains( key ) ) {
                    names.add( key );
                }
            }
            return Collections.enumeration( names );
        }


        /**
         * Returns the overridden Content-Type when this wrapper set one, so it agrees with
         * {@link #getHeader(String)}; otherwise the original request's content type.
         *
         * @see javax.servlet.ServletRequestWrapper#getContentType()
         */
        @Override
        public String getContentType() {
            String overridden = lookupOverride( HttpHeaders.CONTENT_TYPE );
            return overridden != null ? overridden : super.getContentType();
        }


        /*
         * (non-Javadoc)
         *
         * @see javax.servlet.ServletRequestWrapper#getInputStream()
         */
        @Override
        public ServletInputStream getInputStream() throws IOException {
            return servletInputStream;
        }


        /**
         * Returns a reader over the (pushback-restored) body. The request's declared character
         * encoding is used when present; otherwise UTF-8.
         *
         * @throws java.io.UnsupportedEncodingException if the declared encoding is not supported
         * @see javax.servlet.ServletRequestWrapper#getReader()
         */
        @Override
        public BufferedReader getReader() throws IOException {
            if ( reader == null ) {
                String encoding = getCharacterEncoding();
                reader = new BufferedReader( new InputStreamReader( servletInputStream,
                        encoding != null ? encoding : DEFAULT_READER_ENCODING ) );
            }
            return reader;
        }

        // NOTE, for full override we need to implement the other getHeader
        // methods (getIntHeader, getDateHeader). We won't use them, so they are not implemented here
    }


    /**
     * Delegating implementation of {@link javax.servlet.ServletInputStream}.
     *
     * @author Juergen Hoeller, Todd Nine
     * @since 1.0.2
     */
    private static class DelegatingServletInputStream extends ServletInputStream {

        private final InputStream sourceStream;


        /**
         * Create a DelegatingServletInputStream for the given source stream.
         *
         * @param sourceStream the source stream (never <code>null</code>)
         */
        public DelegatingServletInputStream( InputStream sourceStream ) {
            Assert.notNull( sourceStream, "Source InputStream must not be null" );
            this.sourceStream = sourceStream;
        }


        /** Return the underlying source stream (never <code>null</code>). */
        public final InputStream getSourceStream() {
            return this.sourceStream;
        }


        @Override
        public int read() throws IOException {
            return this.sourceStream.read();
        }


        /** Delegates bulk reads so callers are not forced through the one-byte-at-a-time default. */
        @Override
        public int read( byte[] b, int off, int len ) throws IOException {
            return this.sourceStream.read( b, off, len );
        }


        @Override
        public int available() throws IOException {
            return this.sourceStream.available();
        }


        @Override
        public void close() throws IOException {
            super.close();
            this.sourceStream.close();
        }
    }
}