/*
* Copyright 2002-2008 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.multipart.commons;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import junit.framework.TestCase;
import org.apache.commons.fileupload.FileItem;
import org.apache.commons.fileupload.FileItemFactory;
import org.apache.commons.fileupload.FileUpload;
import org.apache.commons.fileupload.servlet.ServletFileUpload;
import org.springframework.beans.MutablePropertyValues;
import org.springframework.mock.web.MockFilterConfig;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockServletContext;
import org.springframework.mock.web.PassThroughFilterChain;
import org.springframework.web.bind.ServletRequestDataBinder;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.StaticWebApplicationContext;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.MultipartHttpServletRequest;
import org.springframework.web.multipart.support.ByteArrayMultipartFileEditor;
import org.springframework.web.multipart.support.MultipartFilter;
import org.springframework.web.multipart.support.StringMultipartFileEditor;
import org.springframework.web.util.WebUtils;
/**
* @author Juergen Hoeller
* @since 08.10.2003
*/
public class CommonsMultipartResolverTests extends TestCase {
/**
 * Runs the shared application-context scenario with the {@code lazy} flag switched off.
 */
public void testWithApplicationContext() throws Exception {
    final boolean lazy = false;
    doTestWithApplicationContext(lazy);
}
/**
 * Runs the shared application-context scenario with the {@code lazy} flag switched on.
 */
public void testWithApplicationContextAndLazyResolution() throws Exception {
    final boolean lazy = true;
    doTestWithApplicationContext(lazy);
}
/**
 * Shared scenario for the application-context tests. Configures a
 * {@link MockCommonsMultipartResolver}, resolves a mock multipart POST and then
 * verifies, in order: the resolver settings, the request parameters, the
 * uploaded files, file transfer, data binding with multipart editors, cleanup,
 * and binding of empty uploads.
 *
 * @param lazy if {@code true}, the resolver is switched to lazy multipart
 * resolution before the request is resolved
 * @throws Exception if resolving, transferring or binding fails
 */
private void doTestWithApplicationContext(boolean lazy) throws Exception {
    StaticWebApplicationContext wac = new StaticWebApplicationContext();
    wac.setServletContext(new MockServletContext());
    wac.getServletContext().setAttribute(WebUtils.TEMP_DIR_CONTEXT_ATTRIBUTE, new File("mytemp"));
    wac.refresh();

    // --- resolver configuration ---
    MockCommonsMultipartResolver resolver = new MockCommonsMultipartResolver();
    resolver.setMaxUploadSize(1000);
    resolver.setMaxInMemorySize(100);
    resolver.setDefaultEncoding("enc");
    if (lazy) {
        // Previously passed 'false' here, which left the lazy variant of this
        // test running exactly the same eager code path as the non-lazy one.
        resolver.setResolveLazily(true);
    }
    resolver.setServletContext(wac.getServletContext());
    assertEquals(1000, resolver.getFileUpload().getSizeMax());
    assertEquals(100, resolver.getFileItemFactory().getSizeThreshold());
    assertEquals("enc", resolver.getFileUpload().getHeaderEncoding());
    assertTrue(resolver.getFileItemFactory().getRepository().getAbsolutePath().endsWith("mytemp"));

    // --- multipart request resolution ---
    MockHttpServletRequest originalRequest = new MockHttpServletRequest();
    originalRequest.setMethod("POST");
    originalRequest.setContentType("multipart/form-data");
    originalRequest.addHeader("Content-type", "multipart/form-data");
    originalRequest.addParameter("getField", "getValue");
    assertTrue(resolver.isMultipart(originalRequest));
    MultipartHttpServletRequest request = resolver.resolveMultipart(originalRequest);

    // --- parameter names: field3/field4 from the mock upload, getField from the original request ---
    Set parameterNames = new HashSet();
    Enumeration parameterEnum = request.getParameterNames();
    while (parameterEnum.hasMoreElements()) {
        parameterNames.add(parameterEnum.nextElement());
    }
    assertEquals(3, parameterNames.size());
    assertTrue(parameterNames.contains("field3"));
    assertTrue(parameterNames.contains("field4"));
    assertTrue(parameterNames.contains("getField"));

    // --- parameter values: field4 carries two values, getParameter yields "value4" ---
    assertEquals("value3", request.getParameter("field3"));
    List parameterValues = Arrays.asList(request.getParameterValues("field3"));
    assertEquals(1, parameterValues.size());
    assertTrue(parameterValues.contains("value3"));
    assertEquals("value4", request.getParameter("field4"));
    parameterValues = Arrays.asList(request.getParameterValues("field4"));
    assertEquals(2, parameterValues.size());
    assertTrue(parameterValues.contains("value4"));
    assertTrue(parameterValues.contains("value5"));
    assertEquals("value4", request.getParameter("field4"));
    assertEquals("getValue", request.getParameter("getField"));

    // --- parameter map: same content, accessed through getParameterMap() ---
    List parameterMapKeys = new ArrayList();
    List parameterMapValues = new ArrayList();
    for (Iterator parameterMapIter = request.getParameterMap().keySet().iterator(); parameterMapIter.hasNext();) {
        String key = (String) parameterMapIter.next();
        parameterMapKeys.add(key);
        parameterMapValues.add(request.getParameterMap().get(key));
    }
    assertEquals(3, parameterMapKeys.size());
    assertEquals(3, parameterMapValues.size());
    int field3Index = parameterMapKeys.indexOf("field3");
    int field4Index = parameterMapKeys.indexOf("field4");
    int getFieldIndex = parameterMapKeys.indexOf("getField");
    assertTrue(field3Index != -1);
    assertTrue(field4Index != -1);
    assertTrue(getFieldIndex != -1);
    parameterValues = Arrays.asList((String[]) parameterMapValues.get(field3Index));
    assertEquals(1, parameterValues.size());
    assertTrue(parameterValues.contains("value3"));
    parameterValues = Arrays.asList((String[]) parameterMapValues.get(field4Index));
    assertEquals(2, parameterValues.size());
    assertTrue(parameterValues.contains("value4"));
    assertTrue(parameterValues.contains("value5"));
    parameterValues = Arrays.asList((String[]) parameterMapValues.get(getFieldIndex));
    assertEquals(1, parameterValues.size());
    assertTrue(parameterValues.contains("getValue"));

    // --- uploaded files: names, map access, metadata and content ---
    Set fileNames = new HashSet();
    Iterator fileIter = request.getFileNames();
    while (fileIter.hasNext()) {
        fileNames.add(fileIter.next());
    }
    assertEquals(3, fileNames.size());
    assertTrue(fileNames.contains("field1"));
    assertTrue(fileNames.contains("field2"));
    assertTrue(fileNames.contains("field2x"));
    CommonsMultipartFile file1 = (CommonsMultipartFile) request.getFile("field1");
    CommonsMultipartFile file2 = (CommonsMultipartFile) request.getFile("field2");
    CommonsMultipartFile file2x = (CommonsMultipartFile) request.getFile("field2x");
    Map fileMap = request.getFileMap();
    assertEquals(3, fileMap.size());
    assertTrue(fileMap.containsKey("field1"));
    assertTrue(fileMap.containsKey("field2"));
    assertTrue(fileMap.containsKey("field2x"));
    assertEquals(file1, fileMap.get("field1"));
    assertEquals(file2, fileMap.get("field2"));
    assertEquals(file2x, fileMap.get("field2x"));
    assertEquals("type1", file1.getContentType());
    assertEquals("type2", file2.getContentType());
    assertEquals("type2", file2x.getContentType());
    // The mock supplies "C:/field2.txt" and "C:\\field2x.txt"; only the bare file name is expected back.
    assertEquals("field1.txt", file1.getOriginalFilename());
    assertEquals("field2.txt", file2.getOriginalFilename());
    assertEquals("field2x.txt", file2x.getOriginalFilename());
    assertEquals("text1", new String(file1.getBytes()));
    assertEquals("text2", new String(file2.getBytes()));
    assertEquals(5, file1.getSize());
    assertEquals(5, file2.getSize());
    assertTrue(file1.getInputStream() instanceof ByteArrayInputStream);
    assertTrue(file2.getInputStream() instanceof ByteArrayInputStream);

    // --- transfer: MockFileItem.write only records the destination file ---
    File transfer1 = new File("C:/transfer1");
    File transfer2 = new File("C:/transfer2");
    file1.transferTo(transfer1);
    file2.transferTo(transfer2);
    assertEquals(transfer1, ((MockFileItem) file1.getFileItem()).writtenFile);
    assertEquals(transfer2, ((MockFileItem) file2.getFileItem()).writtenFile);

    // --- binding to MultipartFile / byte[] properties ---
    MultipartTestBean1 mtb1 = new MultipartTestBean1();
    assertEquals(null, mtb1.getField1());
    assertEquals(null, mtb1.getField2());
    ServletRequestDataBinder binder = new ServletRequestDataBinder(mtb1, "mybean");
    binder.registerCustomEditor(byte[].class, new ByteArrayMultipartFileEditor());
    binder.bind(request);
    assertEquals(file1, mtb1.getField1());
    assertEquals(new String(file2.getBytes()), new String(mtb1.getField2()));

    // --- binding to String properties, field2 decoded as UTF-16 ---
    MultipartTestBean2 mtb2 = new MultipartTestBean2();
    assertEquals(null, mtb2.getField1());
    assertEquals(null, mtb2.getField2());
    binder = new ServletRequestDataBinder(mtb2, "mybean");
    binder.registerCustomEditor(String.class, "field1", new StringMultipartFileEditor());
    binder.registerCustomEditor(String.class, "field2", new StringMultipartFileEditor("UTF-16"));
    binder.bind(request);
    assertEquals(new String(file1.getBytes()), mtb2.getField1());
    assertEquals(new String(file2.getBytes(), "UTF-16"), mtb2.getField2());

    // --- cleanup must delete the underlying file items ---
    resolver.cleanupMultipart(request);
    assertTrue(((MockFileItem) file1.getFileItem()).deleted);
    assertTrue(((MockFileItem) file2.getFileItem()).deleted);

    // --- empty uploads: the mock now yields empty file names and contents ---
    resolver.setEmpty(true);
    request = resolver.resolveMultipart(originalRequest);
    // With bindEmptyMultipartFiles off, the previously bound value must survive.
    binder.setBindEmptyMultipartFiles(false);
    String firstBound = mtb2.getField1();
    binder.bind(request);
    assertTrue(mtb2.getField1().length() > 0);
    assertEquals(firstBound, mtb2.getField1());
    // With bindEmptyMultipartFiles on, the empty upload overwrites it with an empty String.
    request = resolver.resolveMultipart(originalRequest);
    binder.setBindEmptyMultipartFiles(true);
    binder.bind(request);
    assertTrue(mtb2.getField1().length() == 0);
}
/**
 * Sends a multipart POST through a {@link MultipartFilter} initialised against a
 * root web application context that contains a "filterMultipartResolver" bean,
 * and checks that the first two files seen by the end of the chain are flagged
 * as deleted once the filter call has returned.
 */
public void testWithServletContextAndFilter() throws Exception {
    // Context setup: the resolver bean is registered before refresh(), the root
    // context attribute is published afterwards.
    MockServletContext servletContext = new MockServletContext();
    StaticWebApplicationContext context = new StaticWebApplicationContext();
    context.setServletContext(servletContext);
    context.registerSingleton(
            "filterMultipartResolver", MockCommonsMultipartResolver.class, new MutablePropertyValues());
    servletContext.setAttribute(WebUtils.TEMP_DIR_CONTEXT_ATTRIBUTE, new File("mytemp"));
    context.refresh();
    servletContext.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, context);

    // A resolver built straight from the servlet context must use the temp dir as repository.
    CommonsMultipartResolver contextResolver = new CommonsMultipartResolver(servletContext);
    File repository = contextResolver.getFileItemFactory().getRepository();
    assertTrue(repository.getAbsolutePath().endsWith("mytemp"));

    // NOTE(review): the two init parameters below look like deliberately
    // unusable settings that init() is expected to tolerate - confirm.
    MockFilterConfig config = new MockFilterConfig(servletContext, "filter");
    config.addInitParameter("class", "notWritable");
    config.addInitParameter("unknownParam", "someValue");
    final MultipartFilter multipartFilter = new MultipartFilter();
    multipartFilter.init(config);

    // Terminal chain: records every file of the multipart request it receives.
    final List collectedFiles = new ArrayList();
    final FilterChain collectingChain = new FilterChain() {
        public void doFilter(ServletRequest req, ServletResponse res) {
            MultipartHttpServletRequest multipartRequest = (MultipartHttpServletRequest) req;
            collectedFiles.addAll(multipartRequest.getFileMap().values());
        }
    };
    // Routes the request through the same filter a second time before the terminal chain.
    FilterChain passThroughChain = new PassThroughFilterChain(multipartFilter, collectingChain);

    MockHttpServletRequest post = new MockHttpServletRequest();
    MockHttpServletResponse response = new MockHttpServletResponse();
    post.setMethod("POST");
    post.setContentType("multipart/form-data");
    post.addHeader("Content-type", "multipart/form-data");
    multipartFilter.doFilter(post, response, passThroughChain);

    CommonsMultipartFile first = (CommonsMultipartFile) collectedFiles.get(0);
    CommonsMultipartFile second = (CommonsMultipartFile) collectedFiles.get(1);
    MockFileItem firstItem = (MockFileItem) first.getFileItem();
    assertTrue(firstItem.deleted);
    MockFileItem secondItem = (MockFileItem) second.getFileItem();
    assertTrue(secondItem.deleted);
}
/**
 * Same filter scenario as the default-bean-name test, but the resolver bean is
 * registered as "myMultipartResolver" and the filter is pointed at it through
 * the "multipartResolverBeanName" init parameter. The filter subclass used here
 * fails with a {@link ServletException} if its {@code doFilterInternal} is
 * entered a second time.
 */
public void testWithServletContextAndFilterWithCustomBeanName() throws Exception {
    // Context setup: here the context is refreshed before the resolver bean is registered.
    MockServletContext servletContext = new MockServletContext();
    StaticWebApplicationContext context = new StaticWebApplicationContext();
    context.setServletContext(servletContext);
    context.refresh();
    context.registerSingleton(
            "myMultipartResolver", MockCommonsMultipartResolver.class, new MutablePropertyValues());
    servletContext.setAttribute(WebUtils.TEMP_DIR_CONTEXT_ATTRIBUTE, new File("mytemp"));
    servletContext.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, context);

    CommonsMultipartResolver contextResolver = new CommonsMultipartResolver(servletContext);
    File repository = contextResolver.getFileItemFactory().getRepository();
    assertTrue(repository.getAbsolutePath().endsWith("mytemp"));

    MockFilterConfig config = new MockFilterConfig(servletContext, "filter");
    config.addInitParameter("multipartResolverBeanName", "myMultipartResolver");

    // Terminal chain: records the files of any multipart request, ignores other requests.
    final List collectedFiles = new ArrayList();
    FilterChain collectingChain = new FilterChain() {
        public void doFilter(ServletRequest req, ServletResponse res) {
            if (!(req instanceof MultipartHttpServletRequest)) {
                return;
            }
            MultipartHttpServletRequest multipartRequest = (MultipartHttpServletRequest) req;
            collectedFiles.addAll(multipartRequest.getFileMap().values());
        }
    };

    MultipartFilter guardedFilter = new MultipartFilter() {
        private boolean enteredBefore = false;
        protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
                FilterChain chain) throws ServletException, IOException {
            // NOTE(review): the superclass is invoked twice per entry, exactly as
            // before; the reason is not evident from this file.
            super.doFilterInternal(request, response, chain);
            super.doFilterInternal(request, response, chain);
            if (!enteredBefore) {
                enteredBefore = true;
            }
            else {
                throw new ServletException("Should not have been invoked twice");
            }
        }
    };
    guardedFilter.init(config);

    MockHttpServletRequest post = new MockHttpServletRequest();
    post.setMethod("POST");
    post.setContentType("multipart/form-data");
    post.addHeader("Content-type", "multipart/form-data");
    HttpServletResponse response = new MockHttpServletResponse();
    guardedFilter.doFilter(post, response, collectingChain);

    CommonsMultipartFile first = (CommonsMultipartFile) collectedFiles.get(0);
    CommonsMultipartFile second = (CommonsMultipartFile) collectedFiles.get(1);
    MockFileItem firstItem = (MockFileItem) first.getFileItem();
    assertTrue(firstItem.deleted);
    MockFileItem secondItem = (MockFileItem) second.getFileItem();
    assertTrue(secondItem.deleted);
}
/**
 * Resolver whose file upload never parses the request body. Its
 * {@code parseRequest} always returns the same six {@link MockFileItem}s:
 * three file items (field1, field2, field2x) and three form fields
 * (field3 once, field4 twice). When the {@code empty} flag is set, the three
 * file items get an empty file name and empty content.
 */
public static class MockCommonsMultipartResolver extends CommonsMultipartResolver {

    private boolean empty;

    /**
     * Switches between regular and empty file items for subsequent parse calls.
     * @param empty {@code true} to produce file items with empty name and content
     */
    protected void setEmpty(boolean empty) {
        this.empty = empty;
    }

    /**
     * Returns a {@link ServletFileUpload} with a canned {@code parseRequest}.
     * The given factory is not used.
     */
    protected FileUpload newFileUpload(FileItemFactory fileItemFactory) {
        return new ServletFileUpload() {
            public List parseRequest(HttpServletRequest request) {
                // Parsing an already resolved request again is treated as an error.
                if (request instanceof MultipartHttpServletRequest) {
                    throw new IllegalStateException("Already a multipart request");
                }
                // The flag is read at parse time, so setEmpty affects later calls.
                final boolean blank = empty;
                String name1 = blank ? "" : "field1.txt";
                String name2 = blank ? "" : "C:/field2.txt";
                String name2x = blank ? "" : "C:\\field2x.txt";
                String content1 = blank ? "" : "text1";
                String content2 = blank ? "" : "text2";

                List items = new ArrayList();
                // File items: a non-null name marks them as uploads.
                items.add(new MockFileItem("field1", "type1", name1, content1));
                items.add(new MockFileItem("field2", "type2", name2, content2));
                items.add(new MockFileItem("field2x", "type2", name2x, content2));
                // Form fields: null content type and null name.
                items.add(new MockFileItem("field3", null, null, "value3"));
                items.add(new MockFileItem("field4", null, null, "value4"));
                items.add(new MockFileItem("field4", null, null, "value5"));
                return items;
            }
        };
    }
}
/**
 * In-memory {@link FileItem} backed by a String value. Instead of touching the
 * file system, {@link #write(File)} remembers the target file and
 * {@link #delete()} raises a flag, so tests can inspect both afterwards.
 * An item counts as a form field exactly when it was created with a null name.
 */
private static class MockFileItem implements FileItem {

    private String fieldName;

    private final String contentType;

    private final String name;

    private final String value;

    // Recorded by write(File); stays null until then.
    private File writtenFile;

    // Raised by delete().
    private boolean deleted;

    public MockFileItem(String fieldName, String contentType, String name, String value) {
        this.fieldName = fieldName;
        this.contentType = contentType;
        this.name = name;
        this.value = value;
    }

    // ----- identity and metadata -----

    public String getFieldName() {
        return this.fieldName;
    }

    public void setFieldName(String s) {
        this.fieldName = s;
    }

    public String getContentType() {
        return this.contentType;
    }

    public String getName() {
        return this.name;
    }

    public boolean isFormField() {
        return this.name == null;
    }

    public void setFormField(boolean b) {
        throw new UnsupportedOperationException();
    }

    public boolean isInMemory() {
        return true;
    }

    // ----- content access -----

    /** Returns the number of characters in the backing value. */
    public long getSize() {
        return this.value.length();
    }

    /** Returns the backing value encoded with the platform default charset. */
    public byte[] get() {
        return this.value.getBytes();
    }

    public InputStream getInputStream() throws IOException {
        byte[] content = this.value.getBytes();
        return new ByteArrayInputStream(content);
    }

    public String getString() {
        return this.value;
    }

    /** Decodes the default-charset bytes of the value with the given encoding. */
    public String getString(String encoding) throws UnsupportedEncodingException {
        byte[] content = this.value.getBytes();
        return new String(content, encoding);
    }

    public OutputStream getOutputStream() throws IOException {
        throw new UnsupportedOperationException();
    }

    // ----- recorded side effects -----

    /** Records the target file; nothing is written. */
    public void write(File file) throws Exception {
        this.writtenFile = file;
    }

    public File getWrittenFile() {
        return this.writtenFile;
    }

    /** Marks this item as deleted. */
    public void delete() {
        this.deleted = true;
    }

    public boolean isDeleted() {
        return this.deleted;
    }
}
/**
 * Binding target with a {@link MultipartFile} property ({@code field1}) and a
 * byte array property ({@code field2}). Both start out as {@code null}.
 */
public class MultipartTestBean1 {

    private MultipartFile field1;

    private byte[] field2;

    public MultipartFile getField1() {
        return this.field1;
    }

    public void setField1(MultipartFile field1) {
        this.field1 = field1;
    }

    public byte[] getField2() {
        return this.field2;
    }

    public void setField2(byte[] field2) {
        this.field2 = field2;
    }
}
/**
 * Binding target with two String properties, {@code field1} and {@code field2}.
 * Both start out as {@code null}.
 */
public class MultipartTestBean2 {

    private String field1;

    private String field2;

    public String getField1() {
        return this.field1;
    }

    public void setField1(String field1) {
        this.field1 = field1;
    }

    public String getField2() {
        return this.field2;
    }

    public void setField2(String field2) {
        this.field2 = field2;
    }
}
}