package de.is24.infrastructure.gridfs.http.web.filter;
import de.is24.infrastructure.gridfs.http.exception.BadRequestException;
import org.springframework.http.converter.FormHttpMessageConverter;
import org.springframework.http.server.ServletServerHttpRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.IOException;
import java.io.InputStream;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static java.util.Arrays.asList;
import static java.util.Collections.enumeration;
import static java.util.Collections.list;
import static java.util.Collections.unmodifiableMap;
import static org.apache.commons.lang.ArrayUtils.addAll;
/**
 * Request wrapper that reports its content type as {@code application/x-www-form-urlencoded} and exposes
 * the parameters decoded from the request body in addition to the parameters the wrapped request already
 * provides.
 *
 * <p>The body is read from {@link #getInputStream()} and decoded with the supplied
 * {@link FormHttpMessageConverter} while the wrapper is being constructed. The resulting parameter map is
 * computed once and is unmodifiable.
 */
public class FormEncodedHttpServletRequestWrapper extends HttpServletRequestWrapper {
  public static final String CONTENT_TYPE = "Content-Type";
  public static final String CONTENT_TYPE_FORM_URLENCODED = "application/x-www-form-urlencoded";

  /** Unmodifiable merge of the wrapped request's parameters and the parameters decoded from the body. */
  private final Map<String, String[]> parameters;

  /**
   * Wraps {@code request} and eagerly decodes its body as form-urlencoded parameters.
   *
   * @param request the request to wrap
   * @param messageConverter converter used to decode the request body
   * @throws BadRequestException if the body cannot be read or decoded
   */
  public FormEncodedHttpServletRequestWrapper(HttpServletRequest request, FormHttpMessageConverter messageConverter) {
    super(request);
    this.parameters = mergeParameters(messageConverter);
  }

  /**
   * Returns the form-urlencoded content type for the Content-Type header and delegates every other header.
   * Header names are matched case-insensitively, as the servlet API specifies for {@code getHeader}.
   */
  @Override
  public String getHeader(String name) {
    if (isContentTypeHeader(name)) {
      return CONTENT_TYPE_FORM_URLENCODED;
    }
    return super.getHeader(name);
  }

  /**
   * Returns the wrapped request's header names plus {@link #CONTENT_TYPE}. Any differently-cased variant
   * of the Content-Type name from the wrapped request is dropped so that consumers iterating the names
   * (and then calling {@link #getHeaders(String)}) see exactly one, overridden, Content-Type entry.
   */
  @Override
  public Enumeration<String> getHeaderNames() {
    Set<String> headerNames = new LinkedHashSet<>();
    Enumeration<String> wrappedHeaderNames = super.getHeaderNames();
    // The servlet API allows null here when the container denies header access; treat it as "no headers".
    if (wrappedHeaderNames != null) {
      for (String headerName : list(wrappedHeaderNames)) {
        if (!isContentTypeHeader(headerName)) {
          headerNames.add(headerName);
        }
      }
    }
    headerNames.add(CONTENT_TYPE);
    return enumeration(headerNames);
  }

  /**
   * Returns a single form-urlencoded value for the Content-Type header (matched case-insensitively) and
   * delegates every other header.
   */
  @Override
  public Enumeration<String> getHeaders(String name) {
    if (isContentTypeHeader(name)) {
      return enumeration(asList(CONTENT_TYPE_FORM_URLENCODED));
    }
    return super.getHeaders(name);
  }

  /** Returns the first merged value for {@code name}, or {@code null} if the parameter has no values. */
  @Override
  public String getParameter(String name) {
    String[] values = parameters.get(name);
    return (values != null && values.length > 0) ? values[0] : null;
  }

  /** Returns the unmodifiable merged parameter map. */
  @Override
  public Map<String, String[]> getParameterMap() {
    return parameters;
  }

  /** Returns the names of all merged parameters. */
  @Override
  public Enumeration<String> getParameterNames() {
    return enumeration(parameters.keySet());
  }

  /** Returns all merged values for {@code name}, or {@code null} if the parameter is absent. */
  @Override
  public String[] getParameterValues(String name) {
    return parameters.get(name);
  }

  /** Case-insensitive check for the Content-Type header name; {@code null} never matches. */
  private static boolean isContentTypeHeader(String name) {
    return CONTENT_TYPE.equalsIgnoreCase(name);
  }

  /**
   * Builds the merged parameter map: the wrapped request's parameters first, then the body parameters
   * appended to (or added alongside) them.
   */
  private Map<String, String[]> mergeParameters(FormHttpMessageConverter messageConverter) {
    // Keep this order: the container's parameter map is read BEFORE the body is parsed.
    // NOTE(review): for requests the container parses as forms itself, reading its parameters may
    // already consume the body, which keeps the same values from being added twice — verify before reordering.
    Map<String, String[]> mergedParameterMap = new HashMap<>(super.getParameterMap());
    mergeParameters(mergedParameterMap, parseRequestBody(messageConverter));
    return unmodifiableMap(mergedParameterMap);
  }

  /**
   * Adds each body parameter to {@code mergedParameterMap}; when a name already exists, the body values are
   * appended after the existing ones.
   */
  private static void mergeParameters(Map<String, String[]> mergedParameterMap,
                                      Set<Map.Entry<String, List<String>>> requestBodyParameters) {
    for (Map.Entry<String, List<String>> entry : requestBodyParameters) {
      String name = entry.getKey();
      String[] bodyValues = toParameterValues(entry.getValue());
      String[] existingValues = mergedParameterMap.get(name);
      if (existingValues == null) {
        mergedParameterMap.put(name, bodyValues);
      } else {
        mergedParameterMap.put(name, (String[]) addAll(existingValues, bodyValues));
      }
    }
  }

  /**
   * Copies decoded body values into an array, replacing {@code null} with {@code ""}.
   *
   * <p>NOTE(review): FormHttpMessageConverter records a pair without '=' (e.g. {@code "flag"}) with a
   * {@code null} value; servlet containers report such parameters as {@code ""}. Without this,
   * {@link #getParameter(String)} would return {@code null} for a name listed by {@link #getParameterNames()}.
   */
  private static String[] toParameterValues(List<String> values) {
    String[] result = new String[values.size()];
    for (int i = 0; i < result.length; i++) {
      String value = values.get(i);
      result[i] = (value == null) ? "" : value;
    }
    return result;
  }

  /**
   * Decodes the request body with {@code messageConverter}, reading from this wrapper's input stream.
   *
   * @throws BadRequestException if reading the body fails, or if decoding rejects the client's data
   */
  private Set<Map.Entry<String, List<String>>> parseRequestBody(FormHttpMessageConverter messageConverter) {
    try {
      ServletServerHttpRequest inputMessage = new ServletServerHttpRequest(this) {
        @Override
        public InputStream getBody() throws IOException {
          // Read the raw stream directly instead of letting ServletServerHttpRequest rebuild the body
          // from request parameters.
          return FormEncodedHttpServletRequestWrapper.this.getInputStream();
        }
      };
      return messageConverter.read(null, inputMessage).entrySet();
    } catch (IOException | IllegalArgumentException e) {
      // IllegalArgumentException: URL decoding of malformed %-escapes in client input; this is a client
      // error (400), not a server failure.
      throw new BadRequestException("Could not parse form url encoded request.", e);
    }
  }
}