package info.guardianproject.otr; import info.guardianproject.otr.app.im.IDataListener; import info.guardianproject.otr.app.im.app.ImApp; import info.guardianproject.otr.app.im.engine.Address; import info.guardianproject.otr.app.im.engine.ChatSession; import info.guardianproject.otr.app.im.engine.DataHandler; import info.guardianproject.otr.app.im.engine.Message; import info.guardianproject.util.Debug; import info.guardianproject.util.LogCleaner; import info.guardianproject.util.SystemServices; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; import java.io.UnsupportedEncodingException; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.Map; import java.util.Map.Entry; import java.util.Set; import java.util.UUID; import org.apache.http.HttpException; import org.apache.http.HttpMessage; import org.apache.http.HttpRequest; import org.apache.http.HttpRequestFactory; import org.apache.http.HttpResponse; import org.apache.http.HttpResponseFactory; import org.apache.http.MethodNotSupportedException; import org.apache.http.ProtocolVersion; import org.apache.http.RequestLine; import org.apache.http.impl.DefaultHttpResponseFactory; import org.apache.http.impl.io.AbstractSessionInputBuffer; import org.apache.http.impl.io.AbstractSessionOutputBuffer; import org.apache.http.impl.io.HttpRequestParser; import org.apache.http.impl.io.HttpRequestWriter; import org.apache.http.impl.io.HttpResponseParser; import org.apache.http.impl.io.HttpResponseWriter; import org.apache.http.io.HttpMessageWriter; import org.apache.http.io.SessionInputBuffer; import org.apache.http.message.BasicHttpRequest; import org.apache.http.message.BasicHttpResponse; import org.apache.http.message.BasicLineFormatter; import org.apache.http.message.BasicLineParser; import org.apache.http.message.BasicStatusLine; import org.apache.http.message.LineFormatter; import org.apache.http.message.LineParser; import org.apache.http.params.BasicHttpParams; import org.apache.http.params.HttpParams; import android.os.Environment; import android.os.RemoteException; import android.util.Log; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; import com.google.common.collect.Maps; import com.google.common.collect.Sets; public class OtrDataHandler implements DataHandler { public static final String URI_PREFIX_OTR_IN_BAND = "otr-in-band:/storage/"; private static final int MAX_OUTSTANDING = 3; private static final int MAX_CHUNK_LENGTH = 32768; private static final int MAX_TRANSFER_LENGTH = 1024*1024*64; private static final byte[] EMPTY_BODY = new byte[0]; private static final String TAG = "GB.OtrDataHandler"; private static final ProtocolVersion PROTOCOL_VERSION = new ProtocolVersion("HTTP", 1, 1); private static HttpParams params = new BasicHttpParams(); private static HttpRequestFactory requestFactory = new MyHttpRequestFactory(); private static HttpResponseFactory responseFactory = new DefaultHttpResponseFactory(); private LineParser lineParser = new BasicLineParser(PROTOCOL_VERSION); private LineFormatter lineFormatter = new BasicLineFormatter(); private ChatSession mChatSession; private IDataListener mDataListener; public OtrDataHandler(ChatSession chatSession) { this.mChatSession = chatSession; } public void setDataListener (IDataListener dataListener) { mDataListener = dataListener; } public static class MyHttpRequestFactory implements HttpRequestFactory { public MyHttpRequestFactory() { super(); } public HttpRequest newHttpRequest(final RequestLine requestline) throws MethodNotSupportedException { if (requestline == null) { throw new IllegalArgumentException("Request line may not be null"); } //String method = requestline.getMethod(); return new BasicHttpRequest(requestline); } public HttpRequest newHttpRequest(final String method, final String uri) throws MethodNotSupportedException { return new BasicHttpRequest(method, uri); } } static class MemorySessionInputBuffer extends AbstractSessionInputBuffer { public MemorySessionInputBuffer(byte[] value) { init(new ByteArrayInputStream(value), 1000, params); } @Override public boolean isDataAvailable(int timeout) throws IOException { throw new UnsupportedOperationException(); } } static class MemorySessionOutputBuffer extends AbstractSessionOutputBuffer { ByteArrayOutputStream outputStream; public MemorySessionOutputBuffer() { outputStream = new ByteArrayOutputStream(1000); init(outputStream, 1000, params); } public byte[] getOutput() { return outputStream.toByteArray(); } } public void onIncomingRequest(Address requestThem, Address requestUs, byte[] value) { SessionInputBuffer inBuf = new MemorySessionInputBuffer(value); HttpRequestParser parser = new HttpRequestParser(inBuf, lineParser, requestFactory, params); HttpRequest req; try { req = (HttpRequest)parser.parse(); } catch (IOException e) { throw new RuntimeException(e); } catch (HttpException e) { e.printStackTrace(); return; } String requestMethod = req.getRequestLine().getMethod(); String uid = req.getFirstHeader("Request-Id").getValue(); String url = req.getRequestLine().getUri(); if (requestMethod.equals("OFFER")) { debug("incoming OFFER " + url); if (!url.startsWith(URI_PREFIX_OTR_IN_BAND)) { debug("Unknown url scheme " + url); sendResponse(requestUs, 400, "Unknown scheme", uid, EMPTY_BODY); return; } sendResponse(requestUs, 200, "OK", uid, EMPTY_BODY); if (!req.containsHeader("File-Length")) { sendResponse(requestUs, 400, "File-Length must be supplied", uid, EMPTY_BODY); return; } int length = Integer.parseInt(req.getFirstHeader("File-Length").getValue()); if (!req.containsHeader("File-Hash-SHA1")) { sendResponse(requestUs, 400, "File-Hash-SHA1 must be supplied", uid, EMPTY_BODY); return; } String sum = req.getFirstHeader("File-Hash-SHA1").getValue(); String type = null; if (req.containsHeader("Mime-Type")) { type = req.getFirstHeader("Mime-Type").getValue(); } debug("Incoming sha1sum " + sum); Transfer transfer = new Transfer(url, type, length, requestUs, sum); transferCache.put(url, transfer); // Handle offer // TODO ask user to confirm we want this boolean accept = false; if (mDataListener != null) { try { accept = mDataListener.onTransferRequested(requestThem.getAddress(),requestUs.getAddress(),transfer.url); if (accept) transfer.perform(); } catch (RemoteException e) { LogCleaner.error(ImApp.LOG_TAG, "error approving OTRDATA transfer request", e); } } } else if (requestMethod.equals("GET") && url.startsWith(URI_PREFIX_OTR_IN_BAND)) { debug("incoming GET " + url); ByteArrayOutputStream byteBuffer = new ByteArrayOutputStream(); int reqEnd; try { Offer offer = offerCache.getIfPresent(url); if (offer == null) { sendResponse(requestUs, 400, "No such offer made", uid, EMPTY_BODY); return; } if (!req.containsHeader("Range")) { sendResponse(requestUs, 400, "Range must start with bytes=", uid, EMPTY_BODY); return; } String rangeHeader = req.getFirstHeader("Range").getValue(); String[] spec = rangeHeader.split("="); if (spec.length != 2 || !spec[0].equals("bytes")) { sendResponse(requestUs, 400, "Range must start with bytes=", uid, EMPTY_BODY); return; } String[] startEnd = spec[1].split("-"); if (startEnd.length != 2) { sendResponse(requestUs, 400, "Range must be START-END", uid, EMPTY_BODY); return; } int start = Integer.parseInt(startEnd[0]); int end = Integer.parseInt(startEnd[1]); if (end - start + 1 > MAX_CHUNK_LENGTH) { sendResponse(requestUs, 400, "Range must be at most " + MAX_CHUNK_LENGTH, uid, EMPTY_BODY); return; } File fileGet = new File(offer.getUri()); FileInputStream is = new FileInputStream(fileGet); readIntoByteBuffer(byteBuffer, is, start, end); if (mDataListener != null) { float percent = ((float)end) / ((float)fileGet.length()); if (percent < .98f) { mDataListener.onTransferProgress(requestThem.getAddress(), offer.getUri(), percent); } else { String mimeType = null; if (req.getFirstHeader("Mime-Type") != null) mimeType = req.getFirstHeader("Mime-Type").getValue(); mDataListener.onTransferComplete(requestThem.getAddress(), offer.getUri(), mimeType, offer.getUri()); } } } catch (UnsupportedEncodingException e) { // throw new RuntimeException(e); sendResponse(requestUs, 400, "Unsupported encoding", uid, EMPTY_BODY); return; } catch (IOException e) { //throw new RuntimeException(e); sendResponse(requestUs, 400, "IOException", uid, EMPTY_BODY); return; } catch (NumberFormatException e) { sendResponse(requestUs, 400, "Range is not numeric", uid, EMPTY_BODY); return; } catch (Exception e) { sendResponse(requestUs, 500, "Unknown error", uid, EMPTY_BODY); return; } byte[] body = byteBuffer.toByteArray(); debug("Sent sha1 is " + sha1sum(body)); sendResponse(requestUs, 200, "OK", uid, body); } else { debug("Unknown method / url " + requestMethod + " " + url); sendResponse(requestUs, 400, "OK", uid, EMPTY_BODY); } } private void readIntoByteBuffer(ByteArrayOutputStream byteBuffer, FileInputStream is, int start, int end) throws IOException { if (start != is.skip(start)) { return; } int size = end - start + 1; int buffersize = 1024; byte[] buffer = new byte[buffersize]; int len = 0; while((len = is.read(buffer)) != -1){ if (len > size) { len = size; } byteBuffer.write(buffer, 0, len); size -= len; } } private void readIntoByteBuffer(ByteArrayOutputStream byteBuffer, SessionInputBuffer sib) throws IOException { int buffersize = 1024; byte[] buffer = new byte[buffersize]; int len = 0; while((len = sib.read(buffer)) != -1){ byteBuffer.write(buffer, 0, len); } } private void sendResponse(Address us, int code, String statusString, String uid, byte[] body) { MemorySessionOutputBuffer outBuf = new MemorySessionOutputBuffer(); HttpMessageWriter writer = new HttpResponseWriter(outBuf, lineFormatter, params); HttpMessage response = new BasicHttpResponse(new BasicStatusLine(PROTOCOL_VERSION, code, statusString)); response.addHeader("Request-Id", uid); try { writer.write(response); outBuf.write(body); outBuf.flush(); } catch (IOException e) { throw new RuntimeException(e); } catch (HttpException e) { throw new RuntimeException(e); } byte[] data = outBuf.getOutput(); Message message = new Message(""); message.setFrom(us); debug("send response"); mChatSession.sendDataAsync(message, true, data); } public void onIncomingResponse(Address from, Address to, byte[] value) { SessionInputBuffer buffer = new MemorySessionInputBuffer(value); HttpResponseParser parser = new HttpResponseParser(buffer, lineParser, responseFactory, params); HttpResponse res; try { res = (HttpResponse) parser.parse(); } catch (IOException e) { throw new RuntimeException(e); } catch (HttpException e) { e.printStackTrace(); return; } String uid = res.getFirstHeader("Request-Id").getValue(); Request request = requestCache.getIfPresent(uid); if (request == null) { debug("Unknown request ID " + uid); return; } if (request.isSeen()) { debug("Already seen request ID " + uid); return; } request.seen(); int statusCode = res.getStatusLine().getStatusCode(); if (statusCode != 200) { debug("got status " + statusCode + ": " + res.getStatusLine().getReasonPhrase()); // TODO handle error return; } // TODO handle success try { ByteArrayOutputStream byteBuffer = new ByteArrayOutputStream(); readIntoByteBuffer(byteBuffer, buffer); debug("Received sha1 @" + request.start + " is " + sha1sum(byteBuffer.toByteArray())); if (request.method.equals("GET")) { Transfer transfer = transferCache.getIfPresent(request.url); if (transfer == null) { debug("Transfer expired for url " + request.url); return; } transfer.chunkReceived(request, byteBuffer.toByteArray()); if (transfer.isDone()) { byte[] data = transfer.getData(); debug("Transfer complete for " + request.url); if (transfer.checkSum()) { debug("Received file len=" + data.length + " sha1=" + sha1sum(data)); File fileShare = writeDataToStorage(transfer.url, data); if (mDataListener != null) mDataListener.onTransferComplete( mChatSession.getParticipant().getAddress().getAddress(), transfer.url, transfer.type, fileShare.getCanonicalPath()); } else { if (mDataListener != null) mDataListener.onTransferFailed( mChatSession.getParticipant().getAddress().getAddress(), transfer.url, "checksum"); Log.e(TAG, "Wrong checksum for file len= " + data.length + " sha1=" + sha1sum(data)); } } else { if (mDataListener != null) mDataListener.onTransferProgress(mChatSession.getParticipant().getAddress().getAddress(), transfer.url, ((float)transfer.chunksReceived) / transfer.chunks); transfer.perform(); debug("Progress " + transfer.chunksReceived + " / " + transfer.chunks); } } } catch (IOException e) { debug("Could not read line from response"); } catch (RemoteException e) { debug("Could not read remote exception"); } } private File writeDataToStorage (String url, byte[] data) { //String nickname = getNickName(username); File sdCard = Environment.getExternalStorageDirectory(); String[] path = url.split("/"); //String sanitizedPeer = SystemServices.sanitize(username); String sanitizedPath = SystemServices.sanitize(path[path.length - 1]); File fileDownloadsDir = Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOWNLOADS); fileDownloadsDir.mkdirs(); File file = new File(fileDownloadsDir, sanitizedPath); try { OutputStream output = (new FileOutputStream(file)); output.write(data); output.close(); return file; } catch (IOException e) { OtrDebugLogger.log("error writing file", e); return null; } } /** * @param headers may be null */ @Override public void offerData(Address us, String localUri, Map<String, String> headers) { // TODO stash localUri and intended recipient long length = new File(localUri).length(); if (length > MAX_TRANSFER_LENGTH) { throw new RuntimeException("Length too large " + length); } if (headers == null) headers = Maps.newHashMap(); headers.put("File-Length", String.valueOf(length)); ByteArrayOutputStream byteBuffer = new ByteArrayOutputStream(); try { FileInputStream is = new FileInputStream(localUri); readIntoByteBuffer(byteBuffer, is, 0, (int)(length - 1)); } catch (IOException e) { throw new RuntimeException(e); } headers.put("File-Hash-SHA1", sha1sum(byteBuffer.toByteArray())); String[] paths = localUri.split("/"); String url = URI_PREFIX_OTR_IN_BAND + SystemServices.sanitize(paths[paths.length - 1]); Request request = new Request("OFFER", url); offerCache.put(url, new Offer(localUri)); sendRequest(us, "OFFER", url, headers, EMPTY_BODY, request); } public Request performGetData(Address us, String url, Map<String, String> headers, int start, int end) { String rangeSpec = "bytes=" + start + "-" + end; debug("Getting range " + rangeSpec); headers.put("Range", rangeSpec); Request requestMemo = new Request("GET", url, start, end); sendRequest(us, "GET", url, headers, EMPTY_BODY, requestMemo); return requestMemo; } static class Offer { private String mUri; public Offer(String uri) { this.mUri = uri; } public String getUri() { return mUri; } } static class Request { public Request(String method, String url, int start, int end) { this.method = method; this.url = url; this.start = start; this.end = end; } public Request(String method, String url) { this(method, url, -1, -1); } public String method; public String url; public int start; public int end; public byte[] data; public boolean seen = false; public boolean isSeen() { return seen; } public void seen() { seen = true; } } public class Transfer { public String url; public String type; public int chunks = 0; public int chunksReceived = 0; private int length = 0; private int current = 0; private Address us; private Set<Request> outstanding; private byte[] buffer; private String sum; public Transfer(String url, String type, int length, Address us, String sum) { this.url = url; this.type = type; this.length = length; this.us = us; this.sum = sum; if (length > MAX_TRANSFER_LENGTH || length <= 0) { throw new RuntimeException("Invalid transfer size " + length); } chunks = ((length - 1) / MAX_CHUNK_LENGTH) + 1; buffer = new byte[length]; outstanding = Sets.newHashSet(); } public boolean checkSum() { return sum.equals(sha1sum(buffer)); } public boolean perform() { // TODO global throttle rather than this local hack while (outstanding.size() < MAX_OUTSTANDING) { if (current >= length) return false; int end = current + MAX_CHUNK_LENGTH - 1; if (end >= length) { end = length - 1; } Map<String, String> headers = Maps.newHashMap(); Request request= performGetData(us, url, headers, current, end); outstanding.add(request); current = end + 1; } return true; } public byte[] getData() { // TODO Auto-generated method stub return buffer; } public boolean isDone() { return chunksReceived == chunks; } public void chunkReceived(Request request, byte[] bs) { chunksReceived++; System.arraycopy(bs, 0, buffer, request.start, bs.length); outstanding.remove(request); } public String getSum() { return sum; } } Cache<String, Offer> offerCache = CacheBuilder.newBuilder().maximumSize(100).build(); Cache<String, Request> requestCache = CacheBuilder.newBuilder().maximumSize(100).build(); Cache<String, Transfer> transferCache = CacheBuilder.newBuilder().maximumSize(100).build(); private void sendRequest(Address us, String method, String url, Map<String, String> headers, byte[] body, Request requestMemo) { MemorySessionOutputBuffer outBuf = new MemorySessionOutputBuffer(); HttpMessageWriter writer = new HttpRequestWriter(outBuf, lineFormatter, params); HttpMessage req = new BasicHttpRequest(method, url, PROTOCOL_VERSION); String uid = UUID.randomUUID().toString(); req.addHeader("Request-Id", uid); if (headers != null) { for (Entry<String, String> entry : headers.entrySet()) { req.addHeader(entry.getKey(), entry.getValue()); } } try { writer.write(req); outBuf.write(body); outBuf.flush(); } catch (IOException e) { throw new RuntimeException(e); } catch (HttpException e) { throw new RuntimeException(e); } byte[] data = outBuf.getOutput(); Message message = new Message(""); message.setFrom(us); debug("send request " + method + " " + url); requestCache.put(uid, requestMemo); mChatSession.sendDataAsync(message, false, data); } private static String hexChr(int b) { return Integer.toHexString(b & 0xF); } private static String toHex(int b) { return hexChr((b & 0xF0) >> 4) + hexChr(b & 0x0F); } private String sha1sum(byte[] bytes) { MessageDigest digest; try { digest = MessageDigest.getInstance("SHA1"); } catch (NoSuchAlgorithmException e) { throw new RuntimeException(e); } digest.update(bytes, 0, bytes.length); byte[] sha1sum = digest.digest(); String display = ""; for(byte b : sha1sum) display += toHex(b); return display; } private void debug (String msg) { if (Debug.DEBUG_ENABLED) Log.d(ImApp.LOG_TAG,msg); } }