/*
* Copyright 2016 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.adapters.servlet;
import org.keycloak.adapters.spi.AdapterSessionStore;
import org.keycloak.adapters.spi.HttpFacade;
import org.keycloak.adapters.spi.KeycloakAccount;
import org.keycloak.common.util.Encode;
import org.keycloak.common.util.MultivaluedHashMap;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpSession;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.security.Principal;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
* @version $Revision: 1 $
*/
public class FilterSessionStore implements AdapterSessionStore {
public static final String REDIRECT_URI = "__REDIRECT_URI";
public static final String SAVED_METHOD = "__SAVED_METHOD";
public static final String SAVED_HEADERS = "__SAVED_HEADERS";
public static final String SAVED_BODY = "__SAVED_BODY";
protected final HttpServletRequest request;
protected final HttpFacade facade;
protected final int maxBuffer;
protected byte[] restoredBuffer = null;
protected boolean needRequestRestore;
public FilterSessionStore(HttpServletRequest request, HttpFacade facade, int maxBuffer) {
this.request = request;
this.facade = facade;
this.maxBuffer = maxBuffer;
}
public void clearSavedRequest(HttpSession session) {
session.removeAttribute(REDIRECT_URI);
session.removeAttribute(SAVED_METHOD);
session.removeAttribute(SAVED_HEADERS);
session.removeAttribute(SAVED_BODY);
}
public void servletRequestLogout() {
}
public static String getCharsetFromContentType(String contentType) {
if (contentType == null)
return (null);
int start = contentType.indexOf("charset=");
if (start < 0)
return (null);
String encoding = contentType.substring(start + 8);
int end = encoding.indexOf(';');
if (end >= 0)
encoding = encoding.substring(0, end);
encoding = encoding.trim();
if ((encoding.length() > 2) && (encoding.startsWith("\""))
&& (encoding.endsWith("\"")))
encoding = encoding.substring(1, encoding.length() - 1);
return (encoding.trim());
}
public HttpServletRequestWrapper buildWrapper(HttpSession session, final KeycloakAccount account) {
if (needRequestRestore) {
final String method = (String)session.getAttribute(SAVED_METHOD);
final byte[] body = (byte[])session.getAttribute(SAVED_BODY);
final MultivaluedHashMap<String, String> headers = (MultivaluedHashMap<String, String>)session.getAttribute(SAVED_HEADERS);
clearSavedRequest(session);
HttpServletRequestWrapper wrapper = new HttpServletRequestWrapper(request) {
protected MultivaluedHashMap<String, String> parameters;
MultivaluedHashMap<String, String> getParams() {
if (parameters != null) return parameters;
if (body == null) return new MultivaluedHashMap<String, String>();
String contentType = getContentType();
contentType = contentType.toLowerCase();
if (contentType.startsWith("application/x-www-form-urlencoded")) {
ByteArrayInputStream is = new ByteArrayInputStream(body);
try {
parameters = parseForm(is);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
return parameters;
}
@Override
public boolean isUserInRole(String role) {
return account.getRoles().contains(role);
}
@Override
public Principal getUserPrincipal() {
return account.getPrincipal();
}
@Override
public String getMethod() {
if (needRequestRestore) {
return method;
} else {
return super.getMethod();
}
}
@Override
public String getHeader(String name) {
if (needRequestRestore && headers != null) {
return headers.getFirst(name.toLowerCase());
}
return super.getHeader(name);
}
@Override
public Enumeration<String> getHeaders(String name) {
if (needRequestRestore && headers != null) {
List<String> values = headers.getList(name.toLowerCase());
if (values == null) return Collections.emptyEnumeration();
else return Collections.enumeration(values);
}
return super.getHeaders(name);
}
@Override
public Enumeration<String> getHeaderNames() {
if (needRequestRestore && headers != null) {
return Collections.enumeration(headers.keySet());
}
return super.getHeaderNames();
}
@Override
public ServletInputStream getInputStream() throws IOException {
if (needRequestRestore && body != null) {
final ByteArrayInputStream is = new ByteArrayInputStream(body);
return new ServletInputStream() {
@Override
public int read() throws IOException {
return is.read();
}
};
}
return super.getInputStream();
}
@Override
public void logout() throws ServletException {
servletRequestLogout();
}
@Override
public long getDateHeader(String name) {
if (!needRequestRestore) return super.getDateHeader(name);
return -1;
}
@Override
public int getIntHeader(String name) {
if (!needRequestRestore) return super.getIntHeader(name);
String value = getHeader(name);
if (value == null) return -1;
return Integer.valueOf(value);
}
@Override
public String[] getParameterValues(String name) {
if (!needRequestRestore) return super.getParameterValues(name);
MultivaluedHashMap<String, String> formParams = getParams();
if (formParams == null) {
return super.getParameterValues(name);
}
String[] values = request.getParameterValues(name);
List<String> list = new LinkedList<>();
if (values != null) {
for (String val : values) list.add(val);
}
List<String> vals = formParams.get(name);
if (vals != null) list.addAll(vals);
return list.toArray(new String[list.size()]);
}
@Override
public Enumeration<String> getParameterNames() {
if (!needRequestRestore) return super.getParameterNames();
MultivaluedHashMap<String, String> formParams = getParams();
if (formParams == null) {
return super.getParameterNames();
}
Set<String> names = new HashSet<>();
Enumeration<String> qnames = super.getParameterNames();
while (qnames.hasMoreElements()) names.add(qnames.nextElement());
names.addAll(formParams.keySet());
return Collections.enumeration(names);
}
@Override
public Map<String, String[]> getParameterMap() {
if (!needRequestRestore) return super.getParameterMap();
MultivaluedHashMap<String, String> formParams = getParams();
if (formParams == null) {
return super.getParameterMap();
}
Map<String, String[]> map = new HashMap<>();
Enumeration<String> names = getParameterNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
String[] values = getParameterValues(name);
if (values != null) {
map.put(name, values);
}
}
return map;
}
@Override
public String getParameter(String name) {
if (!needRequestRestore) return super.getParameter(name);
String param = super.getParameter(name);
if (param != null) return param;
MultivaluedHashMap<String, String> formParams = getParams();
if (formParams == null) {
return null;
}
return formParams.getFirst(name);
}
@Override
public BufferedReader getReader() throws IOException {
if (!needRequestRestore) return super.getReader();
return new BufferedReader(new InputStreamReader(getInputStream()));
}
@Override
public int getContentLength() {
if (!needRequestRestore) return super.getContentLength();
String header = getHeader("content-length");
if (header == null) return -1;
return Integer.valueOf(header);
}
@Override
public String getContentType() {
if (!needRequestRestore) return super.getContentType();
return getHeader("content-type");
}
@Override
public String getCharacterEncoding() {
if (!needRequestRestore) return super.getCharacterEncoding();
return getCharsetFromContentType(getContentType());
}
};
return wrapper;
} else {
return new HttpServletRequestWrapper(request) {
@Override
public boolean isUserInRole(String role) {
return account.getRoles().contains(role);
}
@Override
public Principal getUserPrincipal() {
if (account == null) return null;
return account.getPrincipal();
}
@Override
public void logout() throws ServletException {
servletRequestLogout();
}
};
}
}
public String getRedirectUri() {
HttpSession session = request.getSession(true);
return (String)session.getAttribute(REDIRECT_URI);
}
@Override
public boolean restoreRequest() {
HttpSession session = request.getSession(false);
if (session == null) return false;
return session.getAttribute(REDIRECT_URI) != null;
}
public static MultivaluedHashMap<String, String> parseForm(InputStream entityStream)
throws IOException
{
char[] buffer = new char[100];
StringBuffer buf = new StringBuffer();
BufferedReader reader = new BufferedReader(new InputStreamReader(entityStream));
int wasRead = 0;
do
{
wasRead = reader.read(buffer, 0, 100);
if (wasRead > 0) buf.append(buffer, 0, wasRead);
} while (wasRead > -1);
String form = buf.toString();
MultivaluedHashMap<String, String> formData = new MultivaluedHashMap<String, String>();
if ("".equals(form)) return formData;
String[] params = form.split("&");
for (String param : params)
{
if (param.indexOf('=') >= 0)
{
String[] nv = param.split("=");
String val = nv.length > 1 ? nv[1] : "";
formData.add(Encode.decode(nv[0]), Encode.decode(val));
}
else
{
formData.add(Encode.decode(param), "");
}
}
return formData;
}
@Override
public void saveRequest() {
HttpSession session = request.getSession(true);
session.setAttribute(REDIRECT_URI, facade.getRequest().getURI());
session.setAttribute(SAVED_METHOD, request.getMethod());
MultivaluedHashMap<String, String> headers = new MultivaluedHashMap<>();
Enumeration<String> names = request.getHeaderNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
Enumeration<String> values = request.getHeaders(name);
while (values.hasMoreElements()) {
headers.add(name.toLowerCase(), values.nextElement());
}
}
session.setAttribute(SAVED_HEADERS, headers);
if (request.getMethod().equalsIgnoreCase("GET")) {
return;
}
ByteArrayOutputStream os = new ByteArrayOutputStream();
byte[] buffer = new byte[4096];
int bytesRead;
int totalRead = 0;
try {
InputStream is = request.getInputStream();
while ( (bytesRead = is.read(buffer) ) >= 0) {
os.write(buffer);
totalRead += bytesRead;
if (totalRead > maxBuffer) {
throw new RuntimeException("max buffer reached on a saved request");
}
}
} catch (IOException e) {
throw new RuntimeException(e);
}
byte[] body = os.toByteArray();
// Only save the request body if there is something to save
if (body.length > 0) {
session.setAttribute(SAVED_BODY, body);
}
}
}