/*******************************************************************************
* Copyright (c) 2010 Trustwave Holdings, Inc.
*******************************************************************************/
package com.trustwave.deface.utils;
import java.io.IOException;
import java.io.StringWriter;
import java.io.Writer;
import javax.faces.application.StateManager;
import javax.faces.context.FacesContext;
import javax.faces.context.ResponseWriter;
public class WriteBehindStateWriter extends Writer
{
// length of the state marker
private static final String SAVESTATE_FIELD_MARKER = "~deface.field.marker~";
private static final int STATE_MARKER_LEN = SAVESTATE_FIELD_MARKER.length();
public static final String SAVESTATE_FIELD_DELIMITER = "~";
private static final ThreadLocal<WriteBehindStateWriter> CUR_WRITER =
new ThreadLocal<WriteBehindStateWriter>();
private Writer out;
private Writer orig;
private StringWriter fWriter;
private boolean stateWritten;
private int bufSize;
private char[] buf;
private FacesContext context;
// -------------------------------------------------------- Constructors
public WriteBehindStateWriter(FacesContext context, int bufSize) {
this.out = new StringWriter();
this.orig = out;
this.context = context;
this.bufSize = bufSize;
this.buf = new char[bufSize];
CUR_WRITER.set(this);
}
// ------------------------------------------------- Methods from Writer
public String toString()
{
return out.toString();
}
public void write(int c) throws IOException {
out.write(c);
}
public void write(char cbuf[]) throws IOException {
out.write(cbuf);
}
public void write(String str) throws IOException {
out.write(str);
}
public void write(String str, int off, int len) throws IOException {
out.write(str, off, len);
}
public void write(char cbuf[], int off, int len) throws IOException {
out.write(cbuf, off, len);
}
public void flush() throws IOException {
// no-op
}
public void close() throws IOException {
// no-op
}
// ------------------------------------------------------ Public Methods
public static WriteBehindStateWriter getCurrentInstance() {
return CUR_WRITER.get();
}
public void release() {
CUR_WRITER.remove();
}
public void writingState() {
if (!stateWritten) {
this.stateWritten = true;
out = fWriter = new StringWriter(1024);
}
}
public boolean stateWritten() {
return stateWritten;
}
/**
* <p> Write directly from our FastStringWriter to the provided
* writer.</p>
* @throws IOException if an error occurs
*/
public void flushToWriter() throws IOException {
// Save the state to a new instance of StringWriter to
// avoid multiple serialization steps if the view contains
// multiple forms.
StateManager stateManager = context.getApplication().getStateManager();
ResponseWriter origWriter = context.getResponseWriter();
StringWriter state =
new StringWriter((stateManager.isSavingStateInClient(
context)) ? bufSize : 128);
context.setResponseWriter(origWriter.cloneWithWriter(state));
stateManager.writeState(context, stateManager.saveView(context));
context.setResponseWriter(origWriter);
StringBuffer builder = fWriter.getBuffer();
// begin writing...
int totalLen = builder.length();
StringBuffer stateBuilder = state.getBuffer();
int stateLen = stateBuilder.length();
int pos = 0;
int tildeIdx = getNextDelimiterIndex(builder, pos);
while (pos < totalLen) {
if (tildeIdx != -1) {
if (tildeIdx > pos && (tildeIdx - pos) > bufSize) {
// there's enough content before the first ~
// to fill the entire buffer
builder.getChars(pos, (pos + bufSize), buf, 0);
orig.write(buf);
pos += bufSize;
} else {
// write all content up to the first '~'
builder.getChars(pos, tildeIdx, buf, 0);
int len = (tildeIdx - pos);
orig.write(buf, 0, len);
// now check to see if the state saving string is
// at the begining of pos, if so, write our
// state out.
if (builder.indexOf(
SAVESTATE_FIELD_MARKER,
pos) == tildeIdx) {
// buf is effectively zero'd out at this point
int statePos = 0;
while (statePos < stateLen) {
if ((stateLen - statePos) > bufSize) {
// enough state to fill the buffer
stateBuilder.getChars(statePos,
(statePos + bufSize),
buf,
0);
orig.write(buf);
statePos += bufSize;
} else {
int slen = (stateLen - statePos);
stateBuilder.getChars(statePos,
stateLen,
buf,
0);
orig.write(buf, 0, slen);
statePos += slen;
}
}
// push us past the last '~' at the end of the marker
pos += (len + STATE_MARKER_LEN);
tildeIdx = getNextDelimiterIndex(builder, pos);
} else {
pos = tildeIdx;
tildeIdx = getNextDelimiterIndex(builder,
tildeIdx + 1);
}
}
} else {
// we've written all of the state field markers.
// finish writing content
if (totalLen - pos > bufSize) {
// there's enough content to fill the buffer
builder.getChars(pos, (pos + bufSize), buf, 0);
orig.write(buf);
pos += bufSize;
} else {
// we're near the end of the response
builder.getChars(pos, totalLen, buf, 0);
int len = (totalLen - pos);
orig.write(buf, 0, len);
pos += (len + 1);
}
}
}
// all state has been written. Have 'out' point to the
// response so that all subsequent writes will make it to the
// browser.
out = orig;
}
private static int getNextDelimiterIndex(StringBuffer builder,
int offset) {
return builder.indexOf(SAVESTATE_FIELD_DELIMITER,
offset);
}
}