/*
* This file is part of the OWASP Proxy, a free intercepting proxy library.
* Copyright (C) 2008-2010 Rogan Dawes <rogan@dawes.za.net>
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to:
* The Free Software Foundation, Inc.,
* 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*
*/
package org.owasp.proxy.http.client;
import java.io.BufferedOutputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.net.ProxySelector;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketTimeoutException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import org.owasp.proxy.daemon.AddressResolver;
import org.owasp.proxy.http.MessageFormatException;
import org.owasp.proxy.http.MessageUtils;
import org.owasp.proxy.http.MutableRequestHeader;
import org.owasp.proxy.http.MutableResponseHeader;
import org.owasp.proxy.http.StreamingRequest;
import org.owasp.proxy.http.StreamingResponse;
import org.owasp.proxy.io.ChunkedInputStream;
import org.owasp.proxy.io.EofNotifyingInputStream;
import org.owasp.proxy.io.FixedLengthInputStream;
import org.owasp.proxy.io.TimingInputStream;
import org.owasp.proxy.ssl.DefaultClientContextSelector;
import org.owasp.proxy.ssl.SSLContextSelector;
import org.owasp.proxy.util.AsciiString;
public class HttpClient {
private static Logger logger = Logger.getLogger(HttpClient.class.getName());
public static final ProxySelector NO_PROXY = new ProxySelector() {
@Override
public void connectFailed(URI uri, SocketAddress sa, IOException ioe) {
}
@Override
public List<Proxy> select(URI uri) {
return Arrays.asList(Proxy.NO_PROXY);
}
};
private static final InputStream NO_CONTENT = new ByteArrayInputStream(
new byte[0]);
public enum State {
DISCONNECTED, CONNECTED, REQUEST_HEADER_SENT, REQUEST_CONTENT_SENT, RESPONSE_HEADER_READ, RESPONSE_CONTINUE, RESPONSE_CONTENT_IN_PROGRESS, RESPONSE_CONTENT_READ
}
private SSLContextSelector contextSelector = new DefaultClientContextSelector();
private ProxySelector proxySelector = null;
private AddressResolver resolver = null;
protected Socket socket = null;
private InetSocketAddress target = null;
private boolean direct = true;
private State state = State.DISCONNECTED;
private boolean expectResponseContent;
private InputStream responseContent = null;
private long requestSubmissionTime, responseHeaderStartTime,
responseHeaderEndTime;
private int soTimeout = 10000;
private String[] enabledProtocols = { "TLSv1", "SSLv3" };
public HttpClient() {
String s = System.getProperty("https.protocols");
if (s != null && s.length() > 0) {
String[] split = s.split(",");
List<String> protos = new ArrayList<String>();
for (String proto : split) {
String candidate = proto.trim();
if (candidate.length() > 0) {
protos.add(candidate);
}
}
if (protos.size() != 0) {
enabledProtocols = protos.toArray(new String[protos.size()]);
}
}
}
public void setProxySelector(ProxySelector proxySelector) {
this.proxySelector = proxySelector;
}
public ProxySelector getProxySelector() {
if (proxySelector == null)
return NO_PROXY;
return proxySelector;
}
public void setSslEnabledProtocols(String[] protocols) {
if (protocols == null) {
enabledProtocols = null;
} else {
enabledProtocols = new String[protocols.length];
System.arraycopy(protocols, 0, enabledProtocols, 0,
protocols.length);
}
}
public String[] getSslEnabledProtocols() {
if (enabledProtocols == null)
return null;
String[] protocols = new String[enabledProtocols.length];
System.arraycopy(enabledProtocols, 0, protocols, 0,
enabledProtocols.length);
return protocols;
}
public void setSoTimeout(int timeout) {
this.soTimeout = timeout;
}
public int getSoTimeout() {
return soTimeout;
}
public void setSslContextSelector(SSLContextSelector contextSelector) {
this.contextSelector = contextSelector;
}
public void setAddressResolver(AddressResolver resolver) {
this.resolver = resolver;
}
public State getState() {
return state;
}
protected void validateTarget(SocketAddress target) throws IOException {
}
private URI constructUri(boolean ssl, String host, int port)
throws IOException {
StringBuilder buff = new StringBuilder();
if (ssl) {
buff.append("https");
} else {
buff.append("http");
}
buff.append("://").append(host).append(":").append(port);
try {
return new URI(buff.toString());
} catch (URISyntaxException use) {
IOException ioe = new IOException("Unable to construct a URI");
ioe.initCause(use);
throw ioe;
}
}
private boolean isConnected(InetSocketAddress target) {
if (socket == null || socket.isClosed() || socket.isInputShutdown()) {
return false;
}
if (target.equals(this.target)) {
try {
// FIXME: This only works because we don't implement pipelining!
int oldtimeout = socket.getSoTimeout();
try {
socket.setSoTimeout(1);
byte[] buff = new byte[1024];
int got = socket.getInputStream().read(buff);
if (got == -1) {
return false;
}
if (got > 0) {
logger.warning("Unexpected data read from socket ("
+ got + " bytes):\n"
+ AsciiString.create(buff, 0, got));
socket.close();
return false;
}
} catch (SocketTimeoutException e) {
return true;
} finally {
socket.setSoTimeout(oldtimeout);
}
} catch (IOException ioe) {
logger.fine("Connection looks closed! Opening a new one");
return false;
}
}
return false;
}
private StreamingResponse proxyConnect(InetSocketAddress target)
throws IOException, MessageFormatException {
MutableRequestHeader req = new MutableRequestHeader.Impl();
req.setStartLine("CONNECT " + target.getHostName() + ":"
+ target.getPort() + " HTTP/1.0");
OutputStream out = socket.getOutputStream();
out.write(req.getHeader());
out.flush();
return readResponse(socket.getInputStream());
}
public StreamingResponse connect(String host, int port, boolean ssl)
throws IOException {
return connect(new InetSocketAddress(host, port), ssl);
}
public StreamingResponse connect(InetSocketAddress target, boolean ssl)
throws IOException {
if (resolver != null) {
InetAddress addr = resolver.getAddress(target.getHostName());
target = new InetSocketAddress(addr, target.getPort());
}
if (target.isUnresolved()) {
target = new InetSocketAddress(target.getHostName(), target
.getPort());
}
URI uri = constructUri(ssl, target.getHostName(), target.getPort());
List<Proxy> proxies = getProxySelector().select(uri);
if (isConnected(target)) {
if (state == State.RESPONSE_CONTENT_READ
|| state == State.CONNECTED) {
return null;
}
disconnect();
} else if (socket != null && !socket.isClosed()) {
try {
socket.close();
} catch (IOException ioe) {
ioe.printStackTrace();
}
}
this.target = target;
socket = null;
IOException lastAttempt = null;
for (Proxy proxy : proxies) {
direct = true;
SocketAddress addr = proxy == Proxy.NO_PROXY ? target : proxy
.address();
try {
validateTarget(addr);
if (proxy.type() == Proxy.Type.HTTP) {
socket = new Socket(Proxy.NO_PROXY);
socket.setSoTimeout(soTimeout);
socket.connect(addr);
if (ssl) {
try {
StreamingResponse proxyResponse = proxyConnect(target);
if (!"200".equals(proxyResponse.getStatus())) {
return proxyResponse;
}
} catch (MessageFormatException mfe) {
IOException ioe = new IOException(
"Malformed proxy response");
ioe.initCause(mfe);
throw ioe;
}
layerSsl(target);
} else {
direct = false;
}
} else {
socket = new Socket(proxy);
socket.setSoTimeout(soTimeout);
socket.connect(target);
if (ssl) {
layerSsl(target);
}
}
} catch (IOException ioe) {
getProxySelector().connectFailed(uri, addr, ioe);
lastAttempt = ioe;
if (socket != null) {
socket.close();
socket = null;
}
}
if (socket != null && socket.isConnected()) {
// success
state = State.CONNECTED;
return null;
}
}
if (lastAttempt != null) {
throw lastAttempt;
}
throw new IOException("Couldn't connect to server");
}
private void layerSsl(InetSocketAddress target) throws IOException {
if (contextSelector == null) {
throw new IllegalStateException(
"SSL Context Selector is null, SSL is not supported!");
}
SSLContext sslContext = contextSelector.select(target);
SSLSocketFactory factory = sslContext.getSocketFactory();
SSLSocket sslsocket = (SSLSocket) factory.createSocket(socket, socket
.getInetAddress().getHostName(), socket.getPort(), true);
sslsocket.setEnabledProtocols(enabledProtocols); // should be set by settings
sslsocket.setUseClientMode(true);
sslsocket.setSoTimeout(soTimeout);
sslsocket.startHandshake();
socket = sslsocket;
}
public void sendRequestHeader(byte[] header) throws IOException,
MessageFormatException {
if (state == State.RESPONSE_CONTINUE) {
throw new IllegalStateException(
"Cannot start a new request when the "
+ "previous request content has not yet been sent");
}
if (state != State.CONNECTED && state != State.RESPONSE_CONTENT_READ) {
throw new IllegalStateException(
"Illegal state. Can't send request headers when state is "
+ state);
}
OutputStream os = new BufferedOutputStream(socket.getOutputStream());
int resourceStart = -1;
String method = null;
for (int i = 0; i < header.length; i++) {
if (method == null && Character.isWhitespace(header[i])) {
method = AsciiString.create(header, 0, i - 1);
}
if (method != null && !Character.isWhitespace(header[i])
&& resourceStart == -1) {
resourceStart = i;
break;
}
if (header[i] == '\r' || header[i] == '\n') {
throw new MessageFormatException(
"Encountered CR or LF when parsing the URI!", header);
}
}
expectResponseContent = !"HEAD".equals(method);
if (!direct) {
if (resourceStart > 0) {
os.write(header, 0, resourceStart);
os.write(("http://" + target.getHostName() + ":" + target
.getPort()).getBytes());
os.write(header, resourceStart, header.length - resourceStart);
} else {
throw new MessageFormatException("Couldn't parse the URI!",
header);
}
} else {
os.write(header);
}
os.flush();
state = State.REQUEST_HEADER_SENT;
requestSubmissionTime = System.currentTimeMillis();
}
public void sendRequestContent(byte[] content) throws IOException {
if (content != null) {
sendRequestContent(new ByteArrayInputStream(content));
} else {
sendRequestContent((InputStream) null);
}
}
public void sendRequestContent(InputStream content) throws IOException {
if (state != State.REQUEST_HEADER_SENT
&& state != State.RESPONSE_CONTINUE) {
throw new IllegalStateException(
"Ilegal state. Can't send request content when state is "
+ state);
}
if (content != null) {
OutputStream os = socket.getOutputStream();
byte[] buff = new byte[1024];
int got;
while ((got = content.read(buff)) > 0) {
os.write(buff, 0, got);
}
os.flush();
} else if (state == State.RESPONSE_CONTINUE)
throw new IllegalStateException(
"Cannot send null content after a 100 Continue response!");
state = State.REQUEST_CONTENT_SENT;
requestSubmissionTime = System.currentTimeMillis();
}
private StreamingResponse readResponse(InputStream in) throws IOException,
MessageFormatException {
InputStream is = socket.getInputStream();
HeaderByteArrayOutputStream header = new HeaderByteArrayOutputStream();
StreamingResponse response = new StreamingResponse.Impl();
int i = -1;
try {
while (!header.isEndOfHeader() && (i = is.read()) > -1) {
header.write(i);
}
response.setHeaderTime(System.currentTimeMillis());
} catch (SocketTimeoutException ste) {
if (header.size() > 0) {
MessageFormatException mfe = new MessageFormatException(
"Timeout reading response header", header.toByteArray());
mfe.initCause(ste);
throw mfe;
}
throw ste;
}
if (!header.isEndOfHeader() && i == -1) {
if (header.size() > 0)
throw new MessageFormatException("Invalid header ", header
.toByteArray());
throw new IOException("Unexpected end of stream reading header");
}
response.setHeader(header.toByteArray());
response.setContent(is);
return response;
}
/**
* returns the bytes of the response header.
*
* NB: The response header may be a "100 Continue" response. Callers MUST check if the response code is "100", and
* be prepared to call getHeader() again to retrieve the real response headers, BEFORE calling getResponseContent().
*
* @return
* @throws IOException
* @throws MessageFormatException
*/
public byte[] getResponseHeader() throws IOException,
MessageFormatException {
if (state != State.REQUEST_HEADER_SENT
&& state != State.REQUEST_CONTENT_SENT) {
throw new IllegalStateException(
"Ilegal state. Can't read response header when state is "
+ state);
}
InputStream is = socket.getInputStream();
HeaderByteArrayOutputStream header = new HeaderByteArrayOutputStream();
int i = -1;
try {
responseHeaderStartTime = responseHeaderEndTime = 0;
while (!header.isEndOfHeader() && (i = is.read()) > -1) {
if (responseHeaderStartTime == 0) {
responseHeaderStartTime = System.currentTimeMillis();
}
header.write(i);
}
responseHeaderEndTime = System.currentTimeMillis();
} catch (SocketTimeoutException ste) {
logger.fine("Timeout reading response header. Had read "
+ header.size() + " bytes");
if (header.size() > 0) {
logger.fine(AsciiString.create(header.toByteArray()));
}
throw ste;
}
if (i == -1) {
throw new IOException("Unexpected end of stream reading header");
}
MutableResponseHeader.Impl rh = new MutableResponseHeader.Impl();
rh.setHeader(header.toByteArray());
String status = rh.getStatus();
if (status.equals("100")) {
state = State.RESPONSE_CONTINUE;
} else {
state = State.RESPONSE_HEADER_READ;
responseContent = getContentStream(rh, is);
}
return rh.getHeader();
}
private InputStream getContentStream(MutableResponseHeader header,
InputStream is) throws IOException, MessageFormatException {
String status = header.getStatus();
if ("204".equals(status) || "304".equals(status)
|| !expectResponseContent) {
return NO_CONTENT;
} else {
String transferCoding = header.getHeader("Transfer-Encoding");
String contentLength = header.getHeader("Content-Length");
if (transferCoding != null
&& transferCoding.trim().equalsIgnoreCase("chunked")) {
is = new ChunkedInputStream(is, true);
} else if (contentLength != null) {
try {
is = new FixedLengthInputStream(is, Integer
.parseInt(contentLength.trim()));
} catch (NumberFormatException nfe) {
IOException ioe = new IOException(
"Invalid content-length header: " + contentLength);
ioe.initCause(nfe);
throw ioe;
}
}
return is;
}
}
public InputStream getResponseContent() throws IOException {
if (state == State.RESPONSE_CONTINUE) {
return null;
}
if (state != State.RESPONSE_HEADER_READ) {
throw new IllegalStateException(
"Illegal state. Can't read response body when state is "
+ state);
}
state = State.RESPONSE_CONTENT_IN_PROGRESS;
return new EofNotifyingInputStream(responseContent) {
@Override
public void eof() {
state = State.RESPONSE_CONTENT_READ;
}
};
}
public void disconnect() throws IOException {
try {
if (socket != null && !socket.isClosed()) {
socket.close();
}
} finally {
socket = null;
state = State.DISCONNECTED;
}
}
public long getRequestTime() {
return requestSubmissionTime;
}
public long getResponseHeaderStartTime() {
return responseHeaderStartTime;
}
public long getResponseHeaderEndTime() {
return responseHeaderEndTime;
}
public StreamingResponse fetchResponse(StreamingRequest request)
throws IOException, MessageFormatException {
StreamingResponse response = connect(request.getTarget(), request
.isSsl());
if (response != null)
return response;
response = new StreamingResponse.Impl();
sendRequestHeader(request.getHeader());
request.setTime(getRequestTime());
if (MessageUtils.isExpectContinue(request)) {
socket.setSoTimeout(2000);
try {
response.setHeader(getResponseHeader());
response.setHeaderTime(getResponseHeaderEndTime());
} catch (SocketTimeoutException ste) {
} finally {
socket.setSoTimeout(getSoTimeout());
}
if (response.getHeader() != null
&& !"100".equals(response.getStatus())) {
InputStream content = getResponseContent();
if (content != null)
content = new TimingInputStream(content, response);
response.setContent(content);
return response;
}
if (request.getContent() != null) {
sendRequestContent(request.getContent());
request.setTime(getRequestTime());
}
} else {
if (request.getContent() != null)
sendRequestContent(request.getContent());
request.setTime(System.currentTimeMillis());
}
if (response.getHeader() != null) {
byte[] cont = response.getHeader();
byte[] header = getResponseHeader();
response.setHeaderTime(getResponseHeaderEndTime());
byte[] newHeader = new byte[cont.length + header.length];
System.arraycopy(cont, 0, newHeader, 0, cont.length);
System.arraycopy(header, 0, newHeader, cont.length, header.length);
response.setHeader(newHeader);
} else {
response.setHeader(getResponseHeader());
response.setHeaderTime(getResponseHeaderEndTime());
}
InputStream content = getResponseContent();
if (content != null)
content = new TimingInputStream(content, response);
response.setContent(content);
return response;
}
private static class HeaderByteArrayOutputStream extends
ByteArrayOutputStream {
// we do it here because we have direct access to the buffer
public boolean isEndOfHeader() {
int i = count;
return i > 4 && buf[i - 4] == '\r' && buf[i - 3] == '\n'
&& buf[i - 2] == '\r' && buf[i - 1] == '\n';
}
}
}