/* * Copyright 2013-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.cloud.netflix.zuul.filters.pre; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.OutputStream; import java.lang.reflect.Field; import javax.servlet.ServletInputStream; import javax.servlet.ServletRequest; import javax.servlet.ServletRequestWrapper; import javax.servlet.http.HttpServletRequest; import org.springframework.cloud.netflix.zuul.util.RequestContentDataExtractor; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpOutputMessage; import org.springframework.http.InvalidMediaTypeException; import org.springframework.http.MediaType; import org.springframework.http.converter.FormHttpMessageConverter; import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter; import org.springframework.util.Assert; import org.springframework.util.MultiValueMap; import org.springframework.util.ReflectionUtils; import org.springframework.web.servlet.DispatcherServlet; import com.netflix.zuul.ZuulFilter; import com.netflix.zuul.context.RequestContext; import com.netflix.zuul.http.HttpServletRequestWrapper; import com.netflix.zuul.http.ServletInputStreamWrapper; import static org.springframework.cloud.netflix.zuul.filters.support.FilterConstants.FORM_BODY_WRAPPER_FILTER_ORDER; import static org.springframework.cloud.netflix.zuul.filters.support.FilterConstants.PRE_TYPE; /** * Pre {@link ZuulFilter} that parses form data and reencodes it for downstream services * * @author Dave Syer */ public class FormBodyWrapperFilter extends ZuulFilter { private FormHttpMessageConverter formHttpMessageConverter; private Field requestField; private Field servletRequestField; public FormBodyWrapperFilter() { this(new AllEncompassingFormHttpMessageConverter()); } public FormBodyWrapperFilter(FormHttpMessageConverter formHttpMessageConverter) { this.formHttpMessageConverter = formHttpMessageConverter; this.requestField = ReflectionUtils.findField(HttpServletRequestWrapper.class, "req", HttpServletRequest.class); this.servletRequestField = ReflectionUtils.findField(ServletRequestWrapper.class, "request", ServletRequest.class); Assert.notNull(this.requestField, "HttpServletRequestWrapper.req field not found"); Assert.notNull(this.servletRequestField, "ServletRequestWrapper.request field not found"); this.requestField.setAccessible(true); this.servletRequestField.setAccessible(true); } @Override public String filterType() { return PRE_TYPE; } @Override public int filterOrder() { return FORM_BODY_WRAPPER_FILTER_ORDER; } @Override public boolean shouldFilter() { RequestContext ctx = RequestContext.getCurrentContext(); HttpServletRequest request = ctx.getRequest(); String contentType = request.getContentType(); // Don't use this filter on GET method if (contentType == null) { return false; } // Only use this filter for form data and only for multipart data in a // DispatcherServlet handler try { MediaType mediaType = MediaType.valueOf(contentType); return MediaType.APPLICATION_FORM_URLENCODED.includes(mediaType) || (isDispatcherServletRequest(request) && MediaType.MULTIPART_FORM_DATA.includes(mediaType)); } catch (InvalidMediaTypeException ex) { return false; } } private boolean isDispatcherServletRequest(HttpServletRequest request) { return request.getAttribute( DispatcherServlet.WEB_APPLICATION_CONTEXT_ATTRIBUTE) != null; } @Override public Object run() { RequestContext ctx = RequestContext.getCurrentContext(); HttpServletRequest request = ctx.getRequest(); FormBodyRequestWrapper wrapper = null; if (request instanceof HttpServletRequestWrapper) { HttpServletRequest wrapped = (HttpServletRequest) ReflectionUtils .getField(this.requestField, request); wrapper = new FormBodyRequestWrapper(wrapped); ReflectionUtils.setField(this.requestField, request, wrapper); if (request instanceof ServletRequestWrapper) { ReflectionUtils.setField(this.servletRequestField, request, wrapper); } } else { wrapper = new FormBodyRequestWrapper(request); ctx.setRequest(wrapper); } if (wrapper != null) { ctx.getZuulRequestHeaders().put("content-type", wrapper.getContentType()); } return null; } private class FormBodyRequestWrapper extends Servlet30RequestWrapper { private HttpServletRequest request; private byte[] contentData; private MediaType contentType; private int contentLength; public FormBodyRequestWrapper(HttpServletRequest request) { super(request); this.request = request; } @Override public String getContentType() { if (this.contentData == null) { buildContentData(); } return this.contentType.toString(); } @Override public int getContentLength() { if (super.getContentLength() <= 0) { return super.getContentLength(); } if (this.contentData == null) { buildContentData(); } return this.contentLength; } public long getContentLengthLong() { return getContentLength(); } @Override public ServletInputStream getInputStream() throws IOException { if (this.contentData == null) { buildContentData(); } return new ServletInputStreamWrapper(this.contentData); } private synchronized void buildContentData() { try { MultiValueMap<String, Object> builder = RequestContentDataExtractor.extract(this.request); FormHttpOutputMessage data = new FormHttpOutputMessage(); this.contentType = MediaType.valueOf(this.request.getContentType()); data.getHeaders().setContentType(this.contentType); FormBodyWrapperFilter.this.formHttpMessageConverter.write(builder, this.contentType, data); // copy new content type including multipart boundary this.contentType = data.getHeaders().getContentType(); this.contentData = data.getInput(); this.contentLength = this.contentData.length; } catch (Exception e) { throw new IllegalStateException("Cannot convert form data", e); } } private class FormHttpOutputMessage implements HttpOutputMessage { private HttpHeaders headers = new HttpHeaders(); private ByteArrayOutputStream output = new ByteArrayOutputStream(); @Override public HttpHeaders getHeaders() { return this.headers; } @Override public OutputStream getBody() throws IOException { return this.output; } public byte[] getInput() throws IOException { this.output.flush(); return this.output.toByteArray(); } } } }