// Copyright (C) 2012 jOVAL.org. All rights reserved. // This software is licensed under the AGPL 3.0 license available at http://www.joval.org/agpl_v3.txt package jwsmv.http; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.InputStream; import java.io.IOException; import java.io.OutputStream; import java.net.HttpURLConnection; import java.net.InetSocketAddress; import java.net.ProtocolException; import java.net.Proxy; import java.net.Socket; import java.net.URL; import java.nio.charset.Charset; import java.security.Permission; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.StringTokenizer; import javax.net.ssl.SSLSocketFactory; import jwsmv.util.RFC822; /** * An HTTP 1.1 connection implementation that re-uses a single socket connection. This is useful when a single TCP connection * is needed to communicate repeatedly with a particular URL, for example, when performing NTLM authentication negotiation. * * Thanks to James Marshall for his concise discussion of HTTP/1.1: * @see http://www.jmarshall.com/easy/http/#http1.1c2 * * @author David A. Solin * @version %I% %G% */ public class HttpSocketConnection extends AbstractConnection { private static int defaultChunkLength = 512; private boolean secure, tunnelFailure; private Socket socket; private Proxy proxy; private String host; // host:port private boolean gotResponse; private HSOutputStream stream; /** * Create a direct connection. */ public HttpSocketConnection(URL url) { this(url, null); } /** * Create a connection through a proxy. */ public HttpSocketConnection(URL url, Proxy proxy) throws IllegalArgumentException { super(url); if (url.getProtocol().equalsIgnoreCase("HTTPS")) { secure = true; } else if (url.getProtocol().equalsIgnoreCase("HTTP")) { secure = false; } else { throw new IllegalArgumentException("Unsupported protocol: " + url.getProtocol()); } setProxy(proxy); StringBuffer sb = new StringBuffer(url.getHost()); if (url.getPort() == -1) { if (secure) { sb.append(":443"); } else { sb.append(":80"); } } else { sb.append(":").append(Integer.toString(url.getPort())); } host = sb.toString(); reset(); } // Overrides for HttpURLConnection @Override public Permission getPermission() throws IOException { return new java.net.SocketPermission(host, "connect"); } @Override public boolean usingProxy() { return proxy == null ? false : proxy.type() != Proxy.Type.DIRECT; } /** * If a fixed content length has not been set, this method causes the connection to use chunked encoding. */ @Override public OutputStream getOutputStream() throws IOException { if (doOutput) { if (stream == null) { switch(fixedContentLength) { case -1: stream = new HSChunkedOutputStream(chunkLength); break; default: stream = new HSOutputStream(fixedContentLength); break; } } connect(); return stream; } else { throw new IllegalStateException("Output not allowed"); } } @Override public void setRequestProperty(String key, String value) { if (connected) { throw new IllegalStateException("Already connected"); } setMapProperty(key, value, requestProperties); } @Override public void addRequestProperty(String key, String value) { if (connected) { throw new IllegalStateException("Already connected"); } addMapProperty(key, value, requestProperties); } @Override public String getRequestProperty(String key) { for (Map.Entry<String, List<String>> entry : requestProperties.entrySet()) { if (key.equalsIgnoreCase(entry.getKey())) { return entry.getValue().get(0); } } return null; } @Override public Map<String, List<String>> getRequestProperties() { Map<String, List<String>> map = new HashMap<String, List<String>>(); for (Map.Entry<String, List<String>> entry : requestProperties.entrySet()) { map.put(entry.getKey(), Collections.unmodifiableList(entry.getValue())); } return Collections.unmodifiableMap(map); } @Override public void disconnect() { if (connected) { try { socket.close(); } catch (IOException e) { } connected = false; } } @Override public void connect() throws IOException { if (connected) { return; } if (socket == null || socket.isClosed()) { switch(proxy.type()) { case SOCKS: socket = new Socket(proxy); socket.connect(new InetSocketAddress(url.getHost(), url.getPort())); break; case HTTP: socket = new Socket(); socket.connect(proxy.address()); break; case DIRECT: default: socket = new Socket(url.getHost(), url.getPort()); break; } socket.setSoTimeout(TIMEOUT); } if (secure && proxy.type() == Proxy.Type.HTTP) { // // Establish a tunnel through the proxy // write(new StringBuffer("CONNECT ").append(host).append(" HTTP/1.1").toString()); write(CRLF); String temp = getRequestProperty("Proxy-Authorization"); if (temp != null) { write(new KVP("Proxy-Authorization", temp)); } temp = getRequestProperty("User-Agent"); if (temp != null) { write(new KVP("User-Agent", temp)); } if (method.equalsIgnoreCase("GET") && ifModifiedSince > 0) { write(new KVP("If-Modified-Since", RFC822.toString(ifModifiedSince))); } write(new KVP("Connection", "Keep-Alive")); write(CRLF); InputStream in = socket.getInputStream(); Map<String, List<String>> map = new HashMap<String, List<String>>(); KVP pair = null; while((pair = readKVP(in)) != null) { if (pair.key().length() == 0) { parseResponse(pair.value()); } if (responseCode != HTTP_OK) { if (orderedHeaderFields.size() > 0) { addMapProperties(pair, map); } orderedHeaderFields.add(pair); } } if (responseCode == HTTP_OK) { // // Establish a socket tunnel // int port = Integer.parseInt(host.substring(host.indexOf(":") + 1)); socket = ((SSLSocketFactory)SSLSocketFactory.getDefault()).createSocket(socket, url.getHost(), port, true); } else { stream = new HSDevNull(); headerFields = Collections.unmodifiableMap(map); gotResponse = true; } } else { StringBuffer req = new StringBuffer(getRequestMethod()).append(" "); switch(proxy.type()) { case SOCKS: case DIRECT: String path = url.getPath(); if (!path.startsWith("/")) { req.append("/"); } req.append(path); break; case HTTP: req.append(url.toString()); break; } req.append(" HTTP/1.1"); setRequestProperty("Connection", "Keep-Alive"); setRequestProperty("Host", host); if (doOutput) { if (fixedContentLength != -1) { setRequestProperty("Content-Length", Integer.toString(fixedContentLength)); } else { if (chunkLength == -1) { chunkLength = defaultChunkLength; } setRequestProperty("Transfer-Encoding", "chunked"); } } write(req.toString()); write(CRLF); for (Map.Entry<String, List<String>> entry : requestProperties.entrySet()) { KVP pair = new KVP(entry); write(pair); } write(CRLF); } connected = true; } // Internal /** * Check to determine whether the socket is connected. */ boolean connected() { return socket != null && socket.isConnected(); } /** * Set a proxy. */ void setProxy(Proxy proxy) { if (proxy == null) { proxy = Proxy.NO_PROXY; } if (connected()) { if (!proxy.equals(this.proxy)) { throw new IllegalStateException("Cannot change proxies when connected"); } } else { if (proxy == null) { this.proxy = Proxy.NO_PROXY; } else { this.proxy = proxy; } } } /** * Reset the connection to a pristine state. */ void reset() { initialize(); setRequestProperty("User-Agent", "jWSMV HTTP Client"); if (stream != null) { try { stream.close(); } catch (IOException e) { disconnect(); } stream = null; } responseData = null; gotResponse = false; } /** * Read the response over the socket. */ @Override void getResponse() throws IOException { if (gotResponse) return; connect(); try { if (stream == null) { switch(fixedContentLength) { case -1: // // connect() would have assumed chunked transfer-encoding, so write a final 0-length chunk. // write("0"); write(CRLF); write(CRLF); // fall-thru case 0: break; default: throw new IllegalStateException("You promised to write " + fixedContentLength + " bytes!"); } } else if (!stream.complete()) { throw new IllegalStateException("You must write " + stream.remaining() + " more bytes!"); } else { stream.close(); } orderedHeaderFields = new ArrayList<KVP>(); Map<String, List<String>> map = new HashMap<String, List<String>>(); boolean chunked = false; InputStream in = socket.getInputStream(); KVP pair = null; while((pair = readKVP(in)) != null) { if (orderedHeaderFields.size() == 0) { parseResponse(pair.value()); } else { addMapProperties(pair, map); } orderedHeaderFields.add(pair); if ("Content-Length".equalsIgnoreCase(pair.key())) { contentLength = Integer.parseInt(pair.value()); } else if ("Content-Type".equalsIgnoreCase(pair.key())) { contentType = pair.value(); } else if ("Content-Encoding".equalsIgnoreCase(pair.key())) { contentEncoding = pair.value(); } else if ("Transfer-Encoding".equalsIgnoreCase(pair.key())) { chunked = pair.value().equalsIgnoreCase("chunked"); } else if ("Date".equalsIgnoreCase(pair.key())) { try { date = RFC822.valueOf(pair.value()); } catch (IllegalArgumentException e) { } } else if ("Last-Modified".equalsIgnoreCase(pair.key())) { try { lastModified = RFC822.valueOf(pair.value()); } catch (IllegalArgumentException e) { } } else if ("Expires".equalsIgnoreCase(pair.key())) { try { expiration = RFC822.valueOf(pair.value()); } catch (IllegalArgumentException e) { } } } if (chunked) { HSBufferedOutputStream buffer = new HSBufferedOutputStream(); int len = 0; while((len = readChunkLength(in)) > 0) { byte[] bytes = new byte[len]; assume(len == in.read(bytes, 0, len)); buffer.write(bytes); assume(in.read() == '\r'); assume(in.read() == '\n'); } responseData = new HSBufferedInputStream(buffer); contentLength = ((HSBufferedInputStream)responseData).size(); // // Read footers (if any) // while((pair = readKVP(in)) != null) { orderedHeaderFields.add(pair); addMapProperties(pair, map); } } else { byte[] bytes = new byte[contentLength]; for (int offset=0; offset < contentLength; ) { offset += in.read(bytes, offset, contentLength - offset); } responseData = new HSBufferedInputStream(bytes); } headerFields = Collections.unmodifiableMap(map); } finally { gotResponse = true; if (headerFields == null || "Close".equalsIgnoreCase(getHeaderField("Connection"))) { disconnect(); } } } // Private private void write(KVP header) throws IOException { write(header.toString()); write(CRLF); } private void write(String s) throws IOException { write(s.getBytes("US-ASCII")); } private void write(int ch) throws IOException { socket.getOutputStream().write(ch); } private void write(byte[] bytes) throws IOException { write(bytes, 0, bytes.length); } private void write(byte[] bytes, int offset, int len) throws IOException { socket.getOutputStream().write(bytes, offset, len); socket.getOutputStream().flush(); } /** * Parse the HTTP response line. */ private void parseResponse(String line) throws IllegalArgumentException { StringTokenizer tok = new StringTokenizer(line, " "); if (tok.countTokens() < 2) { throw new IllegalArgumentException(line); } String httpVersion = tok.nextToken(); responseCode = Integer.parseInt(tok.nextToken()); if (tok.hasMoreTokens()) { responseMessage = tok.nextToken("\r\n"); } } /** * Read a line ending in CRLF that indicates the length of the next chunk. */ private int readChunkLength(InputStream in) throws IOException { StringBuffer sb = new StringBuffer(); boolean done = false; boolean cr = false; while(!done) { int ch = in.read(); switch(ch) { case -1: throw new IOException("Connection was closed!"); case '\r': if (sb.length() == 0) { cr = true; } break; case '\n': if (cr) { done = true; } break; default: sb.append((char)ch); break; } } String line = sb.toString(); int ptr = line.indexOf(";"); if (ptr > 0) { return Integer.parseInt(line.substring(0,ptr), 16); } else { return Integer.parseInt(line, 16); } } /** * Like assert, but always enabled. */ private void assume(boolean test) throws AssertionError { if (!test) throw new AssertionError(); } /** * ByteArrayOutputStream that provides access to the underlying memory buffer. */ class HSBufferedOutputStream extends ByteArrayOutputStream { HSBufferedOutputStream() { super(); } /** * Access the underlying buffer -- NOT A COPY. */ byte[] getBuf() { return buf; } } /** * InputStream that resets this connection when closed. */ class HSBufferedInputStream extends ByteArrayInputStream { private boolean closed; HSBufferedInputStream(HSBufferedOutputStream out) { super(out.getBuf(), 0, out.size()); closed = false; } HSBufferedInputStream(byte[] buffer) { super(buffer); closed = false; } int size() { return count; } @Override public void close() throws IOException { if (!closed) { super.close(); reset(); closed = true; } } } /** * An OutputStream for a fixed-length stream. */ class HSOutputStream extends OutputStream { private int size; int ptr; boolean closed; HSOutputStream(int size) { this.size = size; ptr = 0; closed = false; } boolean complete() { return remaining() == 0; } int remaining() { return size - ptr; } // InputStream overrides @Override public void write(int ch) throws IOException { if (closed) throw new IOException("stream closed"); ptr++; if (ptr > size) { throw new IOException("Buffer overflow " + ptr); } HttpSocketConnection.this.write(ch); } @Override public void write(byte[] b) throws IOException { if (closed) throw new IOException("stream closed"); write(b, 0, b.length); } @Override public void write(byte[] b, int off, int len) throws IOException { if (closed) throw new IOException("stream closed"); ptr = ptr + len; if (ptr > size) { throw new IOException("Buffer overflow " + ptr + ", size=" + size); } HttpSocketConnection.this.write(b, off, len); } @Override public void flush() throws IOException { if (closed) throw new IOException("stream closed"); socket.getOutputStream().flush(); } @Override public void close() throws IOException { if (!closed) { if (complete()) { flush(); closed = true; } else { throw new IOException("You need to write " + remaining() + " more bytes!"); } } } } /** * A safe place (i.e., nowhere) to write output in the event of a failure to establish a CONNECT tunnel through * an HTTP proxy. */ class HSDevNull extends HSOutputStream { HSDevNull() { super(0); } @Override public void write(int ch) {} @Override public void write(byte[] b) {} @Override public void write(byte[] b, int offset, int len) {} @Override public void flush() {} @Override public void close() {} } /** * An OutputStream for chunked stream encoding. */ class HSChunkedOutputStream extends HSOutputStream { private byte[] buffer; HSChunkedOutputStream(int chunkSize) { super(chunkSize); buffer = new byte[chunkSize]; } @Override boolean complete() { return ptr == 0; } @Override public void write(int ch) throws IOException { if (closed) throw new IOException("stream closed"); int end = ptr + 1; if (end <= buffer.length) { buffer[ptr++] = (byte)(ch & 0xFF); if (end == buffer.length) { flush(); } } else { throw new IOException("Buffer overrun " + ptr); } } @Override public void write(byte[] buff) throws IOException { if (closed) throw new IOException("stream closed"); write(buff, 0, buff.length); } @Override public void write(byte[] buff, int offset, int len) throws IOException { if (closed) throw new IOException("stream closed"); len = Math.min(buff.length - offset, len); int end = ptr + len; if (end <= buffer.length) { System.arraycopy(buff, offset, buffer, ptr, len); ptr = end; if (end == buffer.length) { flush(); } } else { int remainder = buffer.length - ptr; write(buff, offset, remainder); write(buff, offset + remainder, len - remainder); } } @Override public void flush() throws IOException { if (closed) throw new IOException("stream closed"); if (ptr > 0) { HttpSocketConnection.this.write(Integer.toHexString(ptr)); HttpSocketConnection.this.write(CRLF); HttpSocketConnection.this.write(buffer, 0, ptr); HttpSocketConnection.this.write(CRLF); buffer = new byte[buffer.length]; ptr = 0; } } @Override public void close() throws IOException { if (!closed) { flush(); HttpSocketConnection.this.write("0"); HttpSocketConnection.this.write(CRLF); HttpSocketConnection.this.write(CRLF); closed = true; } } } /** * An output stream to nowhere. */ class DevNull extends OutputStream { DevNull() {} @Override public void write(int ch) {} } }