/**
*
*/
package io.client.thrift;
import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.net.Socket;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import javax.net.SocketFactory;
/**
 * Creates client-side proxies for service interfaces.
 * <p>
 * Each call on a proxy is serialized with {@code TCompactProtocol}, sent over a
 * socket obtained from a {@link SocketFactory} as a frame prefixed with its
 * 4-byte big-endian length, and the result is decoded from the same socket.
 *
 * @author HouKangxi
 */
public class ClientInterfaceFactory {
    private ClientInterfaceFactory() {
    }

    /**
     * Proxies created so far, keyed by {@link #cacheKey(Class, SocketFactory)}.
     * Keys are built from hash codes only, so a cached value is type-checked
     * before it is handed out.
     */
    private static final ConcurrentHashMap<Long, Object> ifaceCache = new ConcurrentHashMap<Long, Object>();

    /**
     * Returns an interface object that communicates with the server.
     * <p>
     * Callers may implement a custom SocketFactory to configure socket
     * parameters internally (timeouts, SSL, etc.), may return wrapped sockets
     * to implement connection pooling, or may use the built-in pool class
     * {@link io.client.thrift.pool.SocketConnectionPool}.
     * <p>
     * Example:
     * <p>
     * {@code SocketFactory tcpfac = new TcpSocketFactory("localhost", 8080);}
     * <p>
     * {@code SocketFactory pool = new SocketConnectionPool(tcpfac);}
     * <p>
     * {@code SomeIface service = ClientInterfaceFactory.getClientInterface(SomeIface.class, pool); }
     * <p>
     * Proxies are cached per (interface name hash, factory hash). Two factories
     * whose {@code hashCode()} values are equal share the proxy created for
     * whichever was seen first. If a hash collision maps the key to a proxy of a
     * different interface, a fresh (uncached) proxy is returned instead.
     *
     * @param ifaceClass
     *            the interface class to proxy
     * @param factory
     *            socket factory; it must implement {@code createSocket()}, and
     *            must implement {@code hashCode()} so that different factories
     *            can be told apart
     * @return a proxy implementing {@code ifaceClass}
     */
    public static <INTERFACE> INTERFACE getClientInterface(Class<INTERFACE> ifaceClass, SocketFactory factory) {
        final Long key = cacheKey(ifaceClass, factory);
        Object cached = ifaceCache.get(key);
        if (ifaceClass.isInstance(cached)) {
            return ifaceClass.cast(cached);
        }
        INTERFACE created = ifaceClass.cast(Proxy.newProxyInstance(ifaceClass.getClassLoader(),
                new Class<?>[] { ifaceClass }, new Handler(factory)));
        Object previous = ifaceCache.putIfAbsent(key, created);
        if (ifaceClass.isInstance(previous)) {
            // Another thread cached a proxy for this key first; share it.
            return ifaceClass.cast(previous);
        }
        // Either our proxy was cached, or the slot holds a proxy for another
        // interface whose hash collided; in both cases ours is the right type.
        return created;
    }

    /**
     * Packs the interface-name hash into the high 32 bits and the factory hash
     * into the low 32 bits of a single cache key.
     */
    private static Long cacheKey(Class<?> ifaceClass, SocketFactory factory) {
        long high = ifaceClass.getName().hashCode();
        // Mask so a negative factory hash is not sign-extended over the high half.
        long low = factory.hashCode() & 0xFFFFFFFFL;
        return Long.valueOf((high << 32) | low);
    }

    /** Invocation handler that turns each interface call into one request/response exchange. */
    private static final class Handler implements InvocationHandler {
        /** Size in bytes of the big-endian length prefix on every message. */
        private static final int HEADER_LENGTH = 4;

        /** Source of per-call sequence ids, passed to ProtocolIOUtil on write and read. */
        final AtomicInteger seqIdHolder = new AtomicInteger(0);
        final SocketFactory factory;

        Handler(SocketFactory factory) {
            this.factory = factory;
        }

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            // java.lang.reflect.Proxy dispatches hashCode/equals/toString with
            // Object as the declaring class; answer them locally, never remotely.
            if (method.getDeclaringClass() == Object.class) {
                return invokeObjectMethod(proxy, method, args);
            }

            int seqId = seqIdHolder.incrementAndGet();
            ByteArrayOutputStream outbuff = new ByteArrayOutputStream();
            TCompactProtocol protocol = new TCompactProtocol(outbuff, null);
            ProtocolIOUtil.write(method.getName(), seqId, protocol, method.getGenericParameterTypes(), args);

            Socket connection = null;
            boolean success = false;
            try {
                byte[] frame = encodeFrame(outbuff.toByteArray());
                connection = factory.createSocket();
                OutputStream out = connection.getOutputStream();
                out.write(frame);
                out.flush();

                Object rs = null;
                InputStream in = connection.getInputStream();
                // The length prefix is consumed but its value is not used: the
                // protocol decoder reads the message body straight from the stream.
                if (in != null && readHeader(in)) {
                    protocol.transIn = in;
                    rs = ProtocolIOUtil.read(protocol, method.getGenericReturnType(), method.getExceptionTypes(),
                            seqId);
                }
                success = true;
                return rs;
            } finally {
                if (connection != null) {
                    if (success) {
                        // Normal case: close via Socket.close() so that wrapped
                        // sockets (e.g. from a pool) can apply their own logic.
                        connection.close();
                    } else {
                        // Failure case: close the underlying I/O streams directly,
                        // presumably so a broken connection is not handed back for reuse.
                        closeStreamsQuietly(connection);
                    }
                }
            }
        }

        /** Handles the java.lang.Object methods (hashCode, equals, toString) on the proxy. */
        private Object invokeObjectMethod(Object proxy, Method method, Object[] args) {
            String name = method.getName();
            if ("equals".equals(name)) {
                // One handler per proxy, so identity equality agrees with hashCode below.
                return Boolean.valueOf(proxy == args[0]);
            }
            if ("hashCode".equals(name)) {
                return Integer.valueOf(System.identityHashCode(this));
            }
            return Handler.class.getName() + "@" + System.identityHashCode(this);
        }

        /** Prefixes {@code payload} with its length as 4 big-endian bytes. */
        private static byte[] encodeFrame(byte[] payload) {
            final int msgLen = payload.length;
            byte[] frame = new byte[HEADER_LENGTH + msgLen];
            frame[0] = (byte) (msgLen >>> 24);
            frame[1] = (byte) (msgLen >>> 16);
            frame[2] = (byte) (msgLen >>> 8);
            frame[3] = (byte) msgLen;
            System.arraycopy(payload, 0, frame, HEADER_LENGTH, msgLen);
            return frame;
        }

        /**
         * Reads the 4-byte response length prefix, looping over short reads.
         *
         * @return {@code true} when the whole prefix was read; {@code false}
         *         when the stream ended before any prefix byte arrived (the
         *         caller then returns {@code null}, as before)
         * @throws EOFException
         *             if the stream ends part-way through the prefix
         */
        private static boolean readHeader(InputStream in) throws IOException {
            byte[] header = new byte[HEADER_LENGTH];
            int total = 0;
            while (total < HEADER_LENGTH) {
                int n = in.read(header, total, HEADER_LENGTH - total);
                if (n < 0) {
                    if (total == 0) {
                        return false;
                    }
                    throw new EOFException("Connection closed after " + total + " of " + HEADER_LENGTH
                            + " response length bytes");
                }
                total += n;
            }
            return true;
        }

        /** Best-effort close of both socket streams; failures are ignored because an error is already propagating. */
        private static void closeStreamsQuietly(Socket connection) {
            try {
                connection.getOutputStream().close();
            } catch (Exception ignored) {
                // best effort
            }
            try {
                connection.getInputStream().close();
            } catch (Exception ignored) {
                // best effort; the socket may already be closed by the call above
            }
        }
    }
}