/* * Copyright (c) 2010-2012 Sonatype, Inc. All rights reserved. * * This program is licensed to you under the Apache License Version 2.0, * and you may not use this file except in compliance with the Apache License Version 2.0. * You may obtain a copy of the Apache License Version 2.0 at http://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, * software distributed under the Apache License Version 2.0 is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the Apache License Version 2.0 for the specific language governing permissions and limitations there under. */ package com.ning.http.multipart; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; import java.nio.channels.WritableByteChannel; import java.util.ArrayList; import java.util.List; import com.ning.http.client.RandomAccessBody; public class MultipartBody implements RandomAccessBody { private final byte[] boundary; private final long contentLength; private final List<com.ning.http.client.Part> parts; private final List<RandomAccessFile> files; private int startPart; ByteArrayInputStream currentStream; int currentStreamPosition; boolean endWritten; boolean doneWritingParts; FileLocation fileLocation; FilePart currentFilePart; FileChannel currentFileChannel; enum FileLocation { NONE, START, MIDDLE, END } public MultipartBody(final List<com.ning.http.client.Part> parts, final String boundary, final String contentLength) { this.boundary = MultipartEncodingUtil.getAsciiBytes(boundary .substring("multipart/form-data; boundary=".length())); this.contentLength = Long.parseLong(contentLength); this.parts = parts; files = new ArrayList<RandomAccessFile>(); startPart = 0; currentStreamPosition = -1; endWritten = false; doneWritingParts = false; fileLocation = FileLocation.NONE; currentFilePart = null; } @Override public void close() throws IOException { for (final RandomAccessFile file : files) { file.close(); } } @Override public long getContentLength() { return contentLength; } @Override public long read(final ByteBuffer buffer) throws IOException { try { int overallLength = 0; final int maxLength = buffer.capacity(); if (startPart == parts.size() && endWritten) { return overallLength; } boolean full = false; while (!full && !doneWritingParts) { com.ning.http.client.Part part = null; if (startPart < parts.size()) { part = parts.get(startPart); } if (currentFileChannel != null) { overallLength += currentFileChannel.read(buffer); if (currentFileChannel.position() == currentFileChannel .size()) { currentFileChannel.close(); currentFileChannel = null; } if (overallLength == maxLength) { full = true; } } else if (currentStreamPosition > -1) { overallLength += writeToBuffer(buffer, maxLength - overallLength); if (overallLength == maxLength) { full = true; } if (startPart == parts.size() && currentStream.available() == 0) { doneWritingParts = true; } } else if (part instanceof StringPart) { final StringPart currentPart = (StringPart) part; initializeStringPart(currentPart); startPart++; } else if (part instanceof com.ning.http.client.StringPart) { final StringPart currentPart = generateClientStringpart(part); initializeStringPart(currentPart); startPart++; } else if (part instanceof FilePart) { if (fileLocation == FileLocation.NONE) { currentFilePart = (FilePart) part; initializeFilePart(currentFilePart); } else if (fileLocation == FileLocation.START) { initializeFileBody(currentFilePart); } else if (fileLocation == FileLocation.MIDDLE) { initializeFileEnd(currentFilePart); } else if (fileLocation == FileLocation.END) { startPart++; if (startPart == parts.size() && currentStream.available() == 0) { doneWritingParts = true; } } } else if (part instanceof com.ning.http.client.FilePart) { if (fileLocation == FileLocation.NONE) { currentFilePart = generateClientFilePart(part); initializeFilePart(currentFilePart); } else if (fileLocation == FileLocation.START) { initializeFileBody(currentFilePart); } else if (fileLocation == FileLocation.MIDDLE) { initializeFileEnd(currentFilePart); } else if (fileLocation == FileLocation.END) { startPart++; if (startPart == parts.size() && currentStream.available() == 0) { doneWritingParts = true; } } } else if (part instanceof com.ning.http.client.ByteArrayPart) { final com.ning.http.client.ByteArrayPart bytePart = (com.ning.http.client.ByteArrayPart) part; if (fileLocation == FileLocation.NONE) { currentFilePart = generateClientByteArrayPart(bytePart); initializeFilePart(currentFilePart); } else if (fileLocation == FileLocation.START) { initializeByteArrayBody(currentFilePart); } else if (fileLocation == FileLocation.MIDDLE) { initializeFileEnd(currentFilePart); } else if (fileLocation == FileLocation.END) { startPart++; if (startPart == parts.size() && currentStream.available() == 0) { doneWritingParts = true; } } } } if (doneWritingParts) { if (currentStreamPosition == -1) { final ByteArrayOutputStream endWriter = new ByteArrayOutputStream(); Part.sendMessageEnd(endWriter, boundary); initializeBuffer(endWriter); } if (currentStreamPosition > -1) { overallLength += writeToBuffer(buffer, maxLength - overallLength); if (currentStream.available() == 0) { currentStream.close(); currentStreamPosition = -1; endWritten = true; } } } return overallLength; } catch (final Exception e) { return 0; } } private void initializeByteArrayBody(final FilePart filePart) throws IOException { final ByteArrayOutputStream output = generateByteArrayBody(filePart); initializeBuffer(output); fileLocation = FileLocation.MIDDLE; } private void initializeFileEnd(final FilePart currentPart) throws IOException { final ByteArrayOutputStream output = generateFileEnd(currentPart); initializeBuffer(output); fileLocation = FileLocation.END; } private void initializeFileBody(final FilePart currentPart) throws IOException { if (FilePartSource.class.isAssignableFrom(currentPart.getSource() .getClass())) { final FilePartSource source = (FilePartSource) currentPart .getSource(); final File file = source.getFile(); final RandomAccessFile raf = new RandomAccessFile(file, "r"); files.add(raf); currentFileChannel = raf.getChannel(); } else { final PartSource partSource = currentPart.getSource(); final InputStream stream = partSource.createInputStream(); final byte[] bytes = new byte[(int) partSource.getLength()]; stream.read(bytes); currentStream = new ByteArrayInputStream(bytes); currentStreamPosition = 0; } fileLocation = FileLocation.MIDDLE; } private void initializeFilePart(final FilePart filePart) throws IOException { filePart.setPartBoundary(boundary); final ByteArrayOutputStream output = generateFileStart(filePart); initializeBuffer(output); fileLocation = FileLocation.START; } private void initializeStringPart(final StringPart currentPart) throws IOException { currentPart.setPartBoundary(boundary); final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); Part.sendPart(outputStream, currentPart, boundary); initializeBuffer(outputStream); } private int writeToBuffer(final ByteBuffer buffer, final int length) throws IOException { final int available = currentStream.available(); final int writeLength = Math.min(available, length); final byte[] bytes = new byte[writeLength]; currentStream.read(bytes); buffer.put(bytes); if (available <= length) { currentStream.close(); currentStreamPosition = -1; } else { currentStreamPosition += writeLength; } return writeLength; } private void initializeBuffer(final ByteArrayOutputStream outputStream) throws IOException { currentStream = new ByteArrayInputStream(outputStream.toByteArray()); currentStreamPosition = 0; } @Override public long transferTo(final long position, final long count, final WritableByteChannel target) throws IOException { long overallLength = 0; if (startPart == parts.size()) { return contentLength; } int tempPart = startPart; for (final com.ning.http.client.Part part : parts) { if (part instanceof Part) { overallLength += handleMultiPart(target, (Part) part); } else { overallLength += handleClientPart(target, part); } tempPart++; } final ByteArrayOutputStream endWriter = new ByteArrayOutputStream(); Part.sendMessageEnd(endWriter, boundary); overallLength += writeToTarget(target, endWriter); startPart = tempPart; return overallLength; } private long handleClientPart(final WritableByteChannel target, final com.ning.http.client.Part part) throws IOException { if (part.getClass().equals(com.ning.http.client.StringPart.class)) { final StringPart currentPart = generateClientStringpart(part); return handleStringPart(target, currentPart); } else if (part.getClass().equals(com.ning.http.client.FilePart.class)) { final FilePart filePart = generateClientFilePart(part); return handleFilePart(target, filePart); } else if (part.getClass().equals( com.ning.http.client.ByteArrayPart.class)) { final com.ning.http.client.ByteArrayPart bytePart = (com.ning.http.client.ByteArrayPart) part; final FilePart filePart = generateClientByteArrayPart(bytePart); return handleByteArrayPart(target, filePart, bytePart.getData()); } return 0; } private FilePart generateClientByteArrayPart( final com.ning.http.client.ByteArrayPart bytePart) { final ByteArrayPartSource source = new ByteArrayPartSource( bytePart.getFileName(), bytePart.getData()); final FilePart filePart = new FilePart(bytePart.getName(), source, bytePart.getMimeType(), bytePart.getCharSet()); return filePart; } private FilePart generateClientFilePart(final com.ning.http.client.Part part) throws FileNotFoundException { final com.ning.http.client.FilePart currentPart = (com.ning.http.client.FilePart) part; final FilePart filePart = new FilePart(currentPart.getName(), currentPart.getFile(), currentPart.getMimeType(), currentPart.getCharSet()); return filePart; } private StringPart generateClientStringpart( final com.ning.http.client.Part part) { final com.ning.http.client.StringPart stringPart = (com.ning.http.client.StringPart) part; final StringPart currentPart = new StringPart(stringPart.getName(), stringPart.getValue(), stringPart.getCharset()); return currentPart; } private long handleByteArrayPart(final WritableByteChannel target, final FilePart filePart, final byte[] data) throws IOException { final ByteArrayOutputStream output = generateByteArrayBody(filePart); return writeToTarget(target, output); } private ByteArrayOutputStream generateByteArrayBody(final FilePart filePart) throws IOException { final ByteArrayOutputStream output = new ByteArrayOutputStream(); Part.sendPart(output, filePart, boundary); return output; } private long handleFileEnd(final WritableByteChannel target, final FilePart filePart) throws IOException { final ByteArrayOutputStream endOverhead = generateFileEnd(filePart); return this.writeToTarget(target, endOverhead); } private ByteArrayOutputStream generateFileEnd(final FilePart filePart) throws IOException { final ByteArrayOutputStream endOverhead = new ByteArrayOutputStream(); filePart.sendEnd(endOverhead); return endOverhead; } private long handleFileHeaders(final WritableByteChannel target, final FilePart filePart) throws IOException { filePart.setPartBoundary(boundary); final ByteArrayOutputStream overhead = generateFileStart(filePart); return writeToTarget(target, overhead); } private ByteArrayOutputStream generateFileStart(final FilePart filePart) throws IOException { final ByteArrayOutputStream overhead = new ByteArrayOutputStream(); filePart.setPartBoundary(boundary); filePart.sendStart(overhead); filePart.sendDispositionHeader(overhead); filePart.sendContentTypeHeader(overhead); filePart.sendTransferEncodingHeader(overhead); filePart.sendEndOfHeader(overhead); return overhead; } private long handleFilePart(final WritableByteChannel target, final FilePart filePart) throws IOException { final FilePartStallHandler handler = new FilePartStallHandler( filePart.getStalledTime(), filePart); handler.start(); if (FilePartSource.class.isAssignableFrom(filePart.getSource() .getClass())) { int length = 0; length += handleFileHeaders(target, filePart); final FilePartSource source = (FilePartSource) filePart.getSource(); final File file = source.getFile(); final RandomAccessFile raf = new RandomAccessFile(file, "r"); files.add(raf); final FileChannel fc = raf.getChannel(); final long l = file.length(); int fileLength = 0; long nWrite = 0; synchronized (fc) { while (fileLength != l) { if (handler.isFailed()) { throw new FileUploadStalledException(); } try { nWrite = fc.transferTo(fileLength, l, target); if (nWrite == 0) { try { fc.wait(50); } catch (final InterruptedException e) { } } else { handler.writeHappened(); } } catch (final IOException ex) { final String message = ex.getMessage(); // http://bugs.sun.com/view_bug.do?bug_id=5103988 if (message != null && message .equalsIgnoreCase("Resource temporarily unavailable")) { try { fc.wait(1000); } catch (final InterruptedException e) { } continue; } else { throw ex; } } fileLength += nWrite; } } handler.completed(); fc.close(); length += handleFileEnd(target, filePart); return length; } else { return handlePartSource(target, filePart); } } private long handlePartSource(final WritableByteChannel target, final FilePart filePart) throws IOException { int length = 0; length += handleFileHeaders(target, filePart); final PartSource partSource = filePart.getSource(); final InputStream stream = partSource.createInputStream(); try { int nRead = 0; while (nRead != -1) { // Do not buffer the entire monster in memory. final byte[] bytes = new byte[8192]; nRead = stream.read(bytes); if (nRead > 0) { final ByteArrayOutputStream bos = new ByteArrayOutputStream( nRead); bos.write(bytes, 0, nRead); writeToTarget(target, bos); } } } finally { stream.close(); } length += handleFileEnd(target, filePart); return length; } private long handleStringPart(final WritableByteChannel target, final StringPart currentPart) throws IOException { currentPart.setPartBoundary(boundary); final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); Part.sendPart(outputStream, currentPart, boundary); return writeToTarget(target, outputStream); } private long handleMultiPart(final WritableByteChannel target, final Part currentPart) throws IOException { currentPart.setPartBoundary(boundary); if (currentPart.getClass().equals(StringPart.class)) { return handleStringPart(target, (StringPart) currentPart); } else if (currentPart.getClass().equals(FilePart.class)) { final FilePart filePart = (FilePart) currentPart; return handleFilePart(target, filePart); } return 0; } private long writeToTarget(final WritableByteChannel target, final ByteArrayOutputStream byteWriter) throws IOException { int written = 0; int maxSpin = 0; synchronized (byteWriter) { final ByteBuffer message = ByteBuffer .wrap(byteWriter.toByteArray()); while ((target.isOpen()) && (written < byteWriter.size())) { final long nWrite = target.write(message); written += nWrite; if (nWrite == 0 && maxSpin++ < 10) { try { byteWriter.wait(1000); } catch (final InterruptedException e) { } } else { if (maxSpin >= 10) { throw new IOException("Unable to write on channel " + target); } maxSpin = 0; } } } return written; } }