/*
* This file is part of the OWASP Proxy, a free intercepting proxy library.
* Copyright (C) 2008-2010 Rogan Dawes <rogan@dawes.za.net>
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to:
* The Free Software Foundation, Inc.,
* 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*
*/
package org.owasp.proxy.ajp;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import org.owasp.proxy.daemon.ConnectionHandler;
import org.owasp.proxy.http.MessageFormatException;
import org.owasp.proxy.http.MutableResponseHeader;
import org.owasp.proxy.http.NamedValue;
import org.owasp.proxy.http.StreamingResponse;
import org.owasp.proxy.io.BufferedInputStream;
import org.owasp.proxy.io.EofNotifyingInputStream;
/**
 * {@link ConnectionHandler} that reads AJP messages from a socket, answers
 * CPING requests with a CPONG reply, and hands FORWARD_REQUEST messages to an
 * {@link AJPRequestHandler}, writing the handler's response back to the socket
 * as AJP SEND_HEADERS / SEND_BODY_CHUNK / END_RESPONSE messages.
 */
public class AJPConnectionHandler implements ConnectionHandler {

    /** Pre-encoded CPONG reply, written in answer to each CPING request. */
    private static final byte[] PONG;

    static {
        AJPMessage msg = new AJPMessage(16);
        msg.reset();
        msg.appendByte(AJPConstants.JK_AJP13_CPONG_REPLY);
        msg.endServerMessage();
        PONG = msg.toByteArray();
    }

    /** Receives every translated forward request and produces the response. */
    private final AJPRequestHandler handler;

    /** Value passed to {@link Socket#setSoTimeout(int)} (milliseconds). */
    private int timeout = 30000;

    /**
     * @param handler
     *            the handler invoked for each forwarded request
     */
    public AJPConnectionHandler(AJPRequestHandler handler) {
        this.handler = handler;
    }

    /**
     * @return the socket read timeout, in milliseconds
     */
    public int getTimeout() {
        return timeout;
    }

    /**
     * @param timeout
     *            the socket read timeout to set, in milliseconds; applied to
     *            connections handled after this call
     */
    public void setTimeout(int timeout) {
        this.timeout = timeout;
    }

    /**
     * Reads AJP messages from the socket in a loop and dispatches them by
     * type. The loop has no normal exit: the method only ends when an
     * exception propagates out of it (presumably an {@link IOException} from
     * reading once the peer closes the connection or the read times out).
     *
     * @param socket
     *            the accepted connection
     * @throws IOException
     *             if reading from or writing to the socket fails
     * @throws IllegalStateException
     *             if a new message is about to be read while the previous
     *             request/response cycle has not returned to the READY state
     *
     * @see org.owasp.proxy.daemon.ConnectionHandler#handleConnection(java.net.Socket)
     */
    public void handleConnection(Socket socket) throws IOException {
        socket.setSoTimeout(timeout);
        InputStream in = socket.getInputStream();
        OutputStream out = socket.getOutputStream();
        AJPMessage ajpRequest = new AJPMessage(8192);
        final StateHolder holder = new StateHolder();
        do {
            if (!holder.state.equals(State.READY)) {
                throw new IllegalStateException(
                        "Trying to read a new request in state " + holder.state);
            }
            ajpRequest.readMessage(in);
            int type = ajpRequest.peekByte();
            switch (type) {
            case AJPConstants.JK_AJP13_CPING_REQUEST:
                out.write(PONG);
                out.flush();
                // Must not fall through: a CPING is not a forward request and
                // translating it as one would throw and drop the connection.
                break;
            case AJPConstants.JK_AJP13_FORWARD_REQUEST:
                doForwardRequest(socket, ajpRequest, holder);
                break;
            default:
                // Any other message type is silently ignored.
                break;
            }
        } while (true);
    }

    /**
     * Handles one FORWARD_REQUEST: translates the AJP message into an
     * {@link AJPRequest}, attaches a body stream when a Content-Length header
     * is present, invokes the handler, and writes the response back as AJP
     * messages.
     *
     * @param socket
     *            the connection the request arrived on
     * @param ajpRequest
     *            the already-read FORWARD_REQUEST message
     * @param holder
     *            shared protocol state; set back to READY once the response
     *            has been written completely
     * @throws IOException
     *             if socket I/O fails
     * @throws IllegalArgumentException
     *             if the Content-Length header is negative, too large, or not
     *             a number
     * @throws IllegalStateException
     *             if the handler returns without having consumed the whole
     *             request body
     */
    private void doForwardRequest(Socket socket, AJPMessage ajpRequest,
            final StateHolder holder) throws IOException {
        InputStream in = socket.getInputStream();
        OutputStream out = socket.getOutputStream();
        InetAddress source = socket.getInetAddress();
        try {
            AJPRequest request = new AJPRequest();
            translate(ajpRequest, request);

            String cl = request.getHeader("Content-Length");
            if (cl != null) {
                long len = Long.parseLong(cl);
                if (len >= 0 && len < Integer.MAX_VALUE) {
                    InputStream requestBody = new AJPInputStream(in, out,
                            (int) len, ajpRequest);
                    // Advance the state once the handler has drained the body.
                    requestBody = new EofNotifyingInputStream(requestBody) {
                        protected void eof() {
                            holder.state = State.RESPONSE_HEADER;
                        }
                    };
                    request.setContent(requestBody);
                    holder.state = State.REQUEST_CONTENT;
                } else {
                    throw new IllegalArgumentException(
                            "Invalid Content-Length: " + cl);
                }
            } else {
                // No declared body: go straight to the response phase.
                holder.state = State.RESPONSE_HEADER;
            }

            StreamingResponse response = handler.handleRequest(source, request);
            if (holder.state != State.RESPONSE_HEADER) {
                throw new IllegalStateException(
                        "handler did not read all the request content");
            }

            AJPMessage ajpResponse = new AJPMessage(8192);
            translate(response, ajpResponse);
            ajpResponse.write(out);

            InputStream responseBody = response.getContent();
            if (responseBody != null) {
                int wrote;
                do {
                    ajpResponse.reset();
                    ajpResponse
                            .appendByte(AJPConstants.JK_AJP13_SEND_BODY_CHUNK);
                    wrote = ajpResponse.appendBytes(responseBody,
                            AJPConstants.MAX_SEND_SIZE);
                    // NOTE(review): a return of 0 ends the loop without
                    // sending; a negative return would still send this chunk
                    // once before the loop ends - confirm against
                    // AJPMessage.appendBytes.
                    if (wrote == 0) {
                        break;
                    }
                    ajpResponse.endServerMessage();
                    ajpResponse.write(out);
                } while (wrote > 0);
            }

            ajpResponse.reset();
            ajpResponse.appendByte(AJPConstants.JK_AJP13_END_RESPONSE);
            ajpResponse.endServerMessage();
            ajpResponse.write(out);
            holder.state = State.READY;
        } catch (MessageFormatException mfe) {
            // The state is left untouched here; if it is not READY, the next
            // iteration of the loop in handleConnection throws.
            mfe.printStackTrace();
            return;
        }
    }

    /**
     * Decodes a FORWARD_REQUEST message into the given request object,
     * reading the fields in the order they appear in the message.
     *
     * @param ajp
     *            the message to decode
     * @param request
     *            the request to populate; its header is cleared first
     * @throws MessageFormatException
     *             if a message or request operation reports a format error
     * @throws IllegalStateException
     *             if the message is not a FORWARD_REQUEST
     * @throws RuntimeException
     *             if the method code does not map to a known request method
     */
    private static void translate(AJPMessage ajp, AJPRequest request)
            throws MessageFormatException {
        request.setHeader(null);
        if (ajp.getByte() != AJPConstants.JK_AJP13_FORWARD_REQUEST) {
            throw new IllegalStateException(
                    "Can't translate this message into a request");
        }
        request.setMethod(AJPConstants.getRequestMethod(ajp.getByte()));
        if (request.getMethod() == null) {
            throw new RuntimeException("Unsupported request method");
        }
        request.setVersion(ajp.getString());
        request.setResource(ajp.getString());
        request.setRemoteAddress(ajp.getString());
        request.setRemoteHost(ajp.getString());
        String target = ajp.getString();
        int port = ajp.getInt();
        request.setTarget(new InetSocketAddress(target, port));
        request.setSsl(ajp.getBoolean());
        getHeaders(ajp, request);
        getRequestAttributes(ajp, request);
    }

    /**
     * Reads the header count followed by that many name/value pairs and adds
     * each pair to the request.
     *
     * @param ajp
     *            the message positioned at the header count
     * @param request
     *            the request receiving the headers
     * @throws MessageFormatException
     *             if a message or request operation reports a format error
     */
    private static void getHeaders(AJPMessage ajp, AJPRequest request)
            throws MessageFormatException {
        int len = ajp.getInt();
        for (int i = 0; i < len; i++) {
            byte coded = ajp.peekByte();
            String name;
            if (coded == (byte) 0xA0) {
                // A leading 0xA0 byte marks a coded header name: the integer
                // is looked up instead of reading the name as a string.
                name = AJPConstants.getRequestHeader(ajp.getInt());
            } else {
                name = ajp.getString();
            }
            String value = ajp.getString();
            request.addHeader(name, value);
        }
    }

    /**
     * Reads attribute entries until the SC_A_ARE_DONE marker and stores each
     * recognised attribute on the request. An unrecognised attribute code is
     * reported on standard output and its value is read as a string.
     *
     * @param ajp
     *            the message positioned at the first attribute code
     * @param request
     *            the request receiving the attributes
     * @throws MessageFormatException
     *             if a message or request operation reports a format error
     */
    private static void getRequestAttributes(AJPMessage ajp, AJPRequest request)
            throws MessageFormatException {
        byte attr = ajp.getByte();
        for (; attr != AJPConstants.SC_A_ARE_DONE; attr = ajp.getByte()) {
            switch (attr) {
            case AJPConstants.SC_A_CONTEXT:
                request.setContext(ajp.getString());
                break;
            case AJPConstants.SC_A_SERVLET_PATH:
                request.setServletPath(ajp.getString());
                break;
            case AJPConstants.SC_A_REMOTE_USER:
                request.setRemoteUser(ajp.getString());
                break;
            case AJPConstants.SC_A_AUTH_TYPE:
                request.setAuthType(ajp.getString());
                break;
            case AJPConstants.SC_A_QUERY_STRING:
                // The query string is appended to the resource already set.
                request.setResource(request.getResource() + "?"
                        + ajp.getString());
                break;
            case AJPConstants.SC_A_JVM_ROUTE:
                request.setRoute(ajp.getString());
                break;
            case AJPConstants.SC_A_SSL_CERT:
                request.setSslCert(ajp.getString());
                break;
            case AJPConstants.SC_A_SSL_CIPHER:
                request.setSslCipher(ajp.getString());
                break;
            case AJPConstants.SC_A_SECRET:
                request.setSecret(ajp.getString());
                break;
            case AJPConstants.SC_A_SSL_SESSION:
                request.setSslSession(ajp.getString());
                break;
            case AJPConstants.SC_A_REQ_ATTRIBUTE:
                // Two strings: attribute name, then attribute value.
                request.getRequestAttributes().put(ajp.getString(),
                        ajp.getString());
                break;
            case AJPConstants.SC_A_SSL_KEY_SIZE:
                request.setSslKeySize(ajp.getString());
                break;
            case AJPConstants.SC_A_STORED_METHOD:
                request.setStoredMethod(ajp.getString());
                break;
            default:
                System.out.println("Unexpected request attribute: " + attr
                        + ": value was '" + ajp.getString() + "'");
            }
        }
    }

    /**
     * Encodes the response status line and headers as a SEND_HEADERS message.
     *
     * @param response
     *            the response header to encode
     * @param message
     *            the message to fill; it is reset first
     * @throws MessageFormatException
     *             if a message or response operation reports a format error
     * @throws NumberFormatException
     *             if the response status is not an integer
     */
    private static void translate(MutableResponseHeader response,
            AJPMessage message) throws MessageFormatException {
        message.reset();
        message.appendByte(AJPConstants.JK_AJP13_SEND_HEADERS);
        message.appendInt(Integer.parseInt(response.getStatus()));
        message.appendString(response.getReason());
        NamedValue[] headers = response.getHeaders();
        if (headers == null) {
            headers = new NamedValue[0];
        }
        message.appendInt(headers.length);
        for (int i = 0; i < headers.length; i++) {
            int code = AJPConstants
                    .getResponseHeaderIndex(headers[i].getName());
            if (code > 0) {
                // Header has a positive index: send the code, not the name.
                message.appendInt(code);
            } else {
                message.appendString(headers[i].getName());
            }
            message.appendString(headers[i].getValue());
        }
        message.endServerMessage();
    }

    /**
     * Stream over a request body of a known length that arrives as a series
     * of AJP messages. Each refill reads one message and, while bytes are
     * still outstanding, writes a GET_BODY_CHUNK message asking for more.
     * The fields {@code in}, {@code buff}, {@code start} and {@code end} are
     * not declared here and presumably come from the BufferedInputStream
     * superclass.
     */
    private static class AJPInputStream extends BufferedInputStream {

        /** Stream on which GET_BODY_CHUNK messages are written. */
        private OutputStream out;

        /** Declared body length, and the number of body bytes read so far. */
        private int len, read = 0;

        /** Message used for incoming body data, and for outgoing chunk requests. */
        private AJPMessage request, response;

        /**
         * @param in
         *            the stream the body messages are read from
         * @param out
         *            the stream on which further chunks are requested
         * @param len
         *            the total number of body bytes expected
         * @param request
         *            the message object reused to read each body message
         * @throws IOException
         *             declared by the constructor; may originate in the
         *             superclass constructor
         */
        public AJPInputStream(InputStream in, OutputStream out, int len,
                AJPMessage request) throws IOException {
            super(in, 8192);
            this.out = out;
            this.len = len;
            this.request = request;
            this.response = new AJPMessage(16);
        }

        /**
         * Loads the next body message into the buffer, or sets the buffer to
         * {@code null} once the declared length has been read.
         *
         * @throws IOException
         *             if reading the message or requesting more data fails
         * @throws IllegalStateException
         *             if more bytes arrive than the declared length
         */
        protected void fillBuffer() throws IOException {
            if (read >= len) {
                buff = null;
                return;
            }
            request.readMessage(in);
            int size = request.peekInt();
            if (buff.length < size) {
                buff = new byte[size];
            }
            size = request.getBytes(buff);
            read += size;
            if (read > len) {
                throw new IllegalStateException(
                        "Request body packet sizes mismatch! Expected " + len
                                + ", got " + read);
            }
            start = 0;
            end = size;
            if (read < len) { // ask for more
                response.reset();
                response.appendByte(AJPConstants.JK_AJP13_GET_BODY_CHUNK);
                response.appendInt(Math.min(len - read,
                        AJPConstants.MAX_READ_SIZE));
                response.endServerMessage();
                response.write(out);
            }
        }
    }

    /**
     * Mutable holder for the connection state, so that the anonymous stream
     * wrapper in doForwardRequest can update it through a final reference.
     */
    private static class StateHolder {
        public State state = State.READY;
    }

    /** Phases of one request/response cycle on a connection. */
    private enum State {
        READY, REQUEST_HEADER, REQUEST_CONTENT, RESPONSE_HEADER, RESPONSE_CONTENT
    }
}