package org.jboss.seam.web;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UnsupportedEncodingException;
import java.rmi.server.UID;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
/**
* Request wrapper for supporting multipart requests, used for file uploading.
*
* @author Shane Bryzak
*/
public class MultipartRequestImpl extends HttpServletRequestWrapper implements MultipartRequest
{
private static final String PARAM_NAME = "name";
private static final String PARAM_FILENAME = "filename";
private static final String PARAM_CONTENT_TYPE = "Content-Type";
private static final int BUFFER_SIZE = 2048;
private static final int CHUNK_SIZE = 512;
private boolean createTempFiles;
private String encoding = null;
private Map<String,Param> parameters = null;
private enum ReadState { BOUNDARY, HEADERS, DATA }
private static final byte CR = 0x0d;
private static final byte LF = 0x0a;
private static final byte[] CR_LF = {CR,LF};
private abstract class Param
{
private String name;
public Param(String name)
{
this.name = name;
}
public String getName()
{
return name;
}
public abstract void appendData(byte[] data, int start, int length)
throws IOException;
}
private class ValueParam extends Param
{
private Object value = null;
private ByteArrayOutputStream buf = new ByteArrayOutputStream();
public ValueParam(String name)
{
super(name);
}
@Override
public void appendData(byte[] data, int start, int length)
throws IOException
{
buf.write(data, start, length);
}
public void complete()
throws UnsupportedEncodingException
{
String val = encoding == null ? new String(buf.toByteArray()) :
new String(buf.toByteArray(), encoding);
if (value == null)
{
value = val;
}
else
{
if (!(value instanceof List))
{
List<String> v = new ArrayList<String>();
v.add((String) value);
value = v;
}
((List) value).add(val);
}
buf.reset();
}
public Object getValue()
{
return value;
}
}
private class FileParam extends Param
{
private String filename;
private String contentType;
private int fileSize;
private ByteArrayOutputStream bOut = null;
private FileOutputStream fOut = null;
private File tempFile = null;
public FileParam(String name)
{
super(name);
}
public String getFilename()
{
return filename;
}
public void setFilename(String filename)
{
this.filename = filename;
}
public String getContentType()
{
return contentType;
}
public void setContentType(String contentType)
{
this.contentType = contentType;
}
public int getFileSize()
{
return fileSize;
}
public void createTempFile()
{
try
{
tempFile = File.createTempFile(new UID().toString().replace(":", "-"), ".upload");
tempFile.deleteOnExit();
fOut = new FileOutputStream(tempFile);
}
catch (IOException ex)
{
throw new FileUploadException("Could not create temporary file");
}
}
@Override
public void appendData(byte[] data, int start, int length)
throws IOException
{
if (fOut != null)
{
fOut.write(data, start, length);
fOut.flush();
}
else
{
if (bOut == null) bOut = new ByteArrayOutputStream();
bOut.write(data, start, length);
}
fileSize += length;
}
public byte[] getData()
{
if (fOut != null)
{
try
{
fOut.close();
}
catch (IOException ex) {}
fOut = null;
}
if (bOut != null)
{
return bOut.toByteArray();
}
else if (tempFile != null)
{
if (tempFile.exists())
{
try
{
FileInputStream fIn = new FileInputStream(tempFile);
ByteArrayOutputStream bOut = new ByteArrayOutputStream();
byte[] buf = new byte[512];
int read = fIn.read(buf);
while (read != -1)
{
bOut.write(buf, 0, read);
read = fIn.read(buf);
}
bOut.flush();
fIn.close();
tempFile.delete();
return bOut.toByteArray();
}
catch (IOException ex) { /* too bad? */}
}
}
return null;
}
public InputStream getInputStream()
{
if (fOut != null)
{
try
{
fOut.close();
}
catch (IOException ex) {}
fOut = null;
}
if (bOut!=null)
{
return new ByteArrayInputStream(bOut.toByteArray());
}
else if (tempFile!=null)
{
try
{
return new FileInputStream(tempFile) {
@Override
public void close() throws IOException
{
super.close();
tempFile.delete();
}
};
}
catch (FileNotFoundException ex) { }
}
return null;
}
}
private HttpServletRequest request;
public MultipartRequestImpl(HttpServletRequest request, boolean createTempFiles,
int maxRequestSize)
{
super(request);
this.request = request;
this.createTempFiles = createTempFiles;
String contentLength = request.getHeader("Content-Length");
if (contentLength != null && maxRequestSize > 0 &&
Integer.parseInt(contentLength) > maxRequestSize)
{
throw new FileUploadException("Multipart request is larger than allowed size");
}
}
private void parseRequest()
{
byte[] boundaryMarker = getBoundaryMarker(request.getContentType());
if (boundaryMarker == null)
{
throw new FileUploadException("The request was rejected because "
+ "no multipart boundary was found");
}
encoding = request.getCharacterEncoding();
parameters = new HashMap<String,Param>();
try
{
byte[] buffer = new byte[BUFFER_SIZE];
Map<String,String> headers = new HashMap<String,String>();
ReadState readState = ReadState.BOUNDARY;
InputStream input = request.getInputStream();
int read = input.read(buffer);
int pos = 0;
Param p = null;
// This is a fail-safe to prevent infinite loops from occurring in some environments
int loopCounter = 20;
while (read > 0 && loopCounter > 0)
{
for (int i = 0; i < read; i++)
{
switch (readState)
{
case BOUNDARY:
{
if (checkSequence(buffer, i, boundaryMarker) && checkSequence(buffer, i + 2, CR_LF))
{
readState = ReadState.HEADERS;
i += 2;
pos = i + 1;
}
break;
}
case HEADERS:
{
if (checkSequence(buffer, i, CR_LF))
{
String param = (encoding == null) ?
new String(buffer, pos, i - pos - 1) :
new String(buffer, pos, i - pos - 1, encoding);
parseParams(param, ";", headers);
if (checkSequence(buffer, i + CR_LF.length, CR_LF))
{
readState = ReadState.DATA;
i += CR_LF.length;
pos = i + 1;
String paramName = headers.get(PARAM_NAME);
if (paramName != null)
{
if (headers.containsKey(PARAM_FILENAME))
{
FileParam fp = new FileParam(paramName);
if (createTempFiles) fp.createTempFile();
fp.setContentType(headers.get(PARAM_CONTENT_TYPE));
fp.setFilename(headers.get(PARAM_FILENAME));
p = fp;
}
else
{
if (parameters.containsKey(paramName))
{
p = parameters.get(paramName);
}
else
{
p = new ValueParam(paramName);
}
}
if (!parameters.containsKey(paramName))
{
parameters.put(paramName, p);
}
}
headers.clear();
}
else
{
pos = i + 1;
}
}
break;
}
case DATA:
{
// If we've encountered another boundary...
if (checkSequence(buffer, i - boundaryMarker.length - CR_LF.length, CR_LF) &&
checkSequence(buffer, i, boundaryMarker))
{
// Write any data before the boundary (that hasn't already been written) to the param
if (pos < i - boundaryMarker.length - CR_LF.length - 1)
{
p.appendData(buffer, pos, i - pos - boundaryMarker.length - CR_LF.length - 1);
}
if (p instanceof ValueParam) ((ValueParam) p).complete();
if (checkSequence(buffer, i + CR_LF.length, CR_LF))
{
i += CR_LF.length;
pos = i + 1;
}
else
{
pos = i;
}
readState = ReadState.HEADERS;
}
// Otherwise write whatever data we have to the param
else if (i > (pos + boundaryMarker.length + CHUNK_SIZE + CR_LF.length))
{
p.appendData(buffer, pos, CHUNK_SIZE);
pos += CHUNK_SIZE;
}
break;
}
}
}
if (pos < read)
{
// move the bytes that weren't read to the start of the buffer
int bytesNotRead = read - pos;
System.arraycopy(buffer, pos, buffer, 0, bytesNotRead);
read = input.read(buffer, bytesNotRead, buffer.length - bytesNotRead);
// Decrement loopCounter if no data was readable
if (read == 0)
{
loopCounter--;
}
read += bytesNotRead;
}
else
{
read = input.read(buffer);
}
pos = 0;
}
}
catch (IOException ex)
{
throw new FileUploadException("IO Error parsing multipart request", ex);
}
}
private byte[] getBoundaryMarker(String contentType)
{
Map<String, Object> params = parseParams(contentType, ";");
String boundaryStr = (String) params.get("boundary");
if (boundaryStr == null) return null;
try
{
return boundaryStr.getBytes("ISO-8859-1");
}
catch (UnsupportedEncodingException e)
{
return boundaryStr.getBytes();
}
}
/**
* Checks if a specified sequence of bytes ends at a specific position
* within a byte array.
*
* @param data
* @param pos
* @param seq
* @return boolean indicating if the sequence was found at the specified position
*/
private boolean checkSequence(byte[] data, int pos, byte[] seq)
{
if (pos - seq.length < -1 || pos >= data.length)
return false;
for (int i = 0; i < seq.length; i++)
{
if (data[(pos - seq.length) + i + 1] != seq[i])
return false;
}
return true;
}
private static final Pattern PARAM_VALUE_PATTERN = Pattern
.compile("^\\s*([^\\s=]+)\\s*[=:]\\s*(.+)\\s*$");
private Map parseParams(String paramStr, String separator)
{
Map<String,String> paramMap = new HashMap<String, String>();
parseParams(paramStr, separator, paramMap);
return paramMap;
}
private void parseParams(String paramStr, String separator, Map paramMap)
{
String[] parts = paramStr.split("[" + separator + "]");
for (String part : parts)
{
Matcher m = PARAM_VALUE_PATTERN.matcher(part);
if (m.matches())
{
String key = m.group(1);
String value = m.group(2);
// Strip double quotes
if (value.startsWith("\"") && value.endsWith("\""))
value = value.substring(1, value.length() - 1);
paramMap.put(key, value);
}
}
}
private Param getParam(String name)
{
if (parameters == null)
parseRequest();
return parameters.get(name);
}
@Override
public Enumeration getParameterNames()
{
if (parameters == null)
parseRequest();
return Collections.enumeration(parameters.keySet());
}
public byte[] getFileBytes(String name)
{
Param p = getParam(name);
return (p != null && p instanceof FileParam) ?
((FileParam) p).getData() : null;
}
public InputStream getFileInputStream(String name)
{
Param p = getParam(name);
return (p != null && p instanceof FileParam) ?
((FileParam) p).getInputStream() : null;
}
public String getFileContentType(String name)
{
Param p = getParam(name);
return (p != null && p instanceof FileParam) ?
((FileParam) p).getContentType() : null;
}
public String getFileName(String name)
{
Param p = getParam(name);
return (p != null && p instanceof FileParam) ?
((FileParam) p).getFilename() : null;
}
public int getFileSize(String name)
{
Param p = getParam(name);
return (p != null && p instanceof FileParam) ?
((FileParam) p).getFileSize() : -1;
}
@Override
public String getParameter(String name)
{
Param p = getParam(name);
if (p != null && p instanceof ValueParam)
{
ValueParam vp = (ValueParam) p;
if (vp.getValue() instanceof String) return (String) vp.getValue();
}
else if (p != null && p instanceof FileParam)
{
return "---BINARY DATA---";
}
else
{
return super.getParameter(name);
}
return null;
}
@Override
public String[] getParameterValues(String name)
{
Param p = getParam(name);
if (p != null && p instanceof ValueParam)
{
ValueParam vp = (ValueParam) p;
if (vp.getValue() instanceof List)
{
List vals = (List) vp.getValue();
String[] values = new String[vals.size()];
vals.toArray(values);
return values;
}
else
{
return new String[] {(String) vp.getValue()};
}
}
else
{
return super.getParameterValues(name);
}
}
@Override
public Map getParameterMap()
{
if (parameters == null)
parseRequest();
Map<String,Object> params = new HashMap<String,Object>(super.getParameterMap());
for (String name : parameters.keySet())
{
Param p = parameters.get(name);
if (p instanceof ValueParam)
{
ValueParam vp = (ValueParam) p;
if (vp.getValue() instanceof String)
{
params.put(name, vp.getValue());
}
else if (vp.getValue() instanceof List)
{
params.put(name, getParameterValues(name));
}
}
}
return params;
}
}