/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.nettythrift.protocol;
import java.io.ByteArrayOutputStream;
import java.io.UnsupportedEncodingException;
import java.nio.ByteBuffer;
import java.util.LinkedList;
import java.util.Map;
import java.util.Stack;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.thrift.TBase;
import org.apache.thrift.TException;
import org.apache.thrift.TFieldIdEnum;
import org.apache.thrift.meta_data.FieldMetaData;
import org.apache.thrift.meta_data.ListMetaData;
import org.apache.thrift.meta_data.MapMetaData;
import org.apache.thrift.meta_data.SetMetaData;
import org.apache.thrift.meta_data.StructMetaData;
import org.apache.thrift.protocol.TField;
import org.apache.thrift.protocol.TList;
import org.apache.thrift.protocol.TMap;
import org.apache.thrift.protocol.TMessage;
import org.apache.thrift.protocol.TMessageType;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolException;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.protocol.TSet;
import org.apache.thrift.protocol.TStruct;
import org.apache.thrift.protocol.TType;
import org.apache.thrift.transport.TTransport;
import io.nettythrift.utils.json.ArrayJson;
import io.nettythrift.utils.json.BaseArray;
import io.nettythrift.utils.json.JSONArray;
/**
* JSON protocol implementation for thrift.
*
* This protocol is read-write. It should not be confused with the
* TJSONProtocol.
* <p>
*
* Changes: <br/>
* Changed from write-only to read-write - by Houkx
*/
@SuppressWarnings("rawtypes")
public class TSimpleJSONProtocol extends TProtocol {
/**
 * Protocol factory that creates {@link TSimpleJSONProtocol} instances
 * sharing the same service interface class and server/client flag.
 */
@SuppressWarnings("serial")
public static class Factory implements TProtocolFactory {
    private final Class<?> ifaceClass;
    private final boolean isServer;

    /** Creates a factory with no interface class and {@code isServer = true}. */
    public Factory() {
        this(null, true);
    }

    /**
     * Creates a factory for the given interface class with {@code isServer = true}.
     *
     * @param ifaceClass service interface class handed to every protocol created
     */
    public Factory(Class<?> ifaceClass) {
        this(ifaceClass, true);
    }

    /**
     * @param ifaceClass service interface class handed to every protocol created
     * @param isServer   flag handed to every protocol created
     */
    public Factory(Class<?> ifaceClass, boolean isServer) {
        this.isServer = isServer;
        this.ifaceClass = ifaceClass;
    }

    /**
     * Builds a new protocol on top of the given transport.
     *
     * @param trans transport the protocol reads from and writes to
     * @return a fresh {@link TSimpleJSONProtocol}
     */
    public TProtocol getProtocol(TTransport trans) {
        TProtocol protocol = new TSimpleJSONProtocol(trans, ifaceClass, isServer);
        return protocol;
    }
}
// Single-byte JSON punctuation tokens written directly to the transport.
private static final byte[] COMMA = new byte[] { ',' };
private static final byte[] COLON = new byte[] { ':' };
private static final byte[] LBRACE = new byte[] { '{' };
private static final byte[] RBRACE = new byte[] { '}' };
private static final byte[] LBRACKET = new byte[] { '[' };
private static final byte[] RBRACKET = new byte[] { ']' };
// Quote character placed around every string emitted by writeString.
private static final char QUOTE = '"';
// Shared placeholder returned by readStructBegin.
private static final TStruct ANONYMOUS_STRUCT = new TStruct();
// Shared placeholder returned by readFieldBegin when the current struct yields no field.
private static final TField ANONYMOUS_FIELD = new TField();
// private static final TMessage EMPTY_MESSAGE = new TMessage();
// private static final TSet EMPTY_SET = new TSet();
// private static final TList EMPTY_LIST = new TList();
// private static final TMap EMPTY_MAP = new TMap();
// Container type names used in the message built by assertContextIsNotMapKey.
private static final String LIST = "list";
private static final String SET = "set";
private static final String MAP = "map";
/**
 * Base write context. It emits no separator and never reports the current
 * value as a map key; subclasses override both hooks.
 */
protected class Context {
    /**
     * Hook invoked before each value is written. The base context writes
     * nothing.
     */
    protected void write() throws TException {
        // nothing to emit at the base level
    }

    /**
     * @return whether the value about to be written is a key in a map;
     *         always {@code false} for the base context
     */
    protected boolean isMapKey() {
        return false;
    }
}
/**
 * Write context for JSON arrays: emits a comma before every element except
 * the first one.
 */
protected class ListContext extends Context {
    /** True until the first element has passed through {@link #write()}. */
    protected boolean first_ = true;

    protected void write() throws TException {
        if (!first_) {
            trans_.write(COMMA);
            return;
        }
        // First element: no separator, just remember that we have started.
        first_ = false;
    }
}
/**
 * Write context for JSON objects: nothing is emitted before the first
 * token, then separators alternate between a colon and a comma.
 */
protected class StructContext extends Context {
    /** True until the first token has passed through {@link #write()}. */
    protected boolean first_ = true;
    /** True when the next separator to emit is a colon, false for a comma. */
    protected boolean colon_ = true;

    protected void write() throws TException {
        if (first_) {
            first_ = false;
            colon_ = true;
            return;
        }
        byte[] separator = COMMA;
        if (colon_) {
            separator = COLON;
        }
        trans_.write(separator);
        colon_ = !colon_;
    }
}
/**
 * Write context for maps. Uses the separator logic of
 * {@link StructContext} and additionally tracks whether the next value is
 * a key or a value.
 */
protected class MapContext extends StructContext {
    /** Flipped on every {@link #write()} call; starts out as {@code true}. */
    protected boolean isKey = true;

    @Override
    protected void write() throws TException {
        super.write();
        isKey ^= true;
    }

    /**
     * Map keys are coerced to JSON strings regardless of their type, so
     * writers consult this flag before emitting a value.
     */
    protected boolean isMapKey() {
        return isKey;
    }
}
/**
 * Root context used while no container is open; its write() emits nothing.
 */
protected final Context BASE_CONTEXT = new Context();
/**
 * Stack of the enclosing contexts saved by pushWriteContext and restored
 * by popWriteContext.
 */
protected Stack<Context> writeContextStack_ = new Stack<Context>();
/**
 * Context currently in effect; starts as BASE_CONTEXT.
 */
protected Context writeContext_ = BASE_CONTEXT;
/**
 * Makes {@code c} the active write context, saving the previously active
 * one on the context stack.
 *
 * @param c the context to activate
 */
protected void pushWriteContext(Context c) {
    Context enclosing = writeContext_;
    writeContextStack_.push(enclosing);
    writeContext_ = c;
}
/**
 * Restores the most recently saved write context as the active one.
 */
protected void popWriteContext() {
    Context restored = writeContextStack_.pop();
    writeContext_ = restored;
}
/**
 * Guards against maps whose keys are containers.
 *
 * @param invalidKeyType name of the container type about to be written,
 *                       used in the exception message
 * @throws CollectionMapKeyException if the active context expects a map key
 */
protected void assertContextIsNotMapKey(String invalidKeyType) throws CollectionMapKeyException {
    if (!writeContext_.isMapKey()) {
        return;
    }
    String message = "Cannot serialize a map with keys that are of type " + invalidKeyType;
    throw new CollectionMapKeyException(message);
}
// Struct class used by readMessageBegin to interpret the message payload;
// set explicitly via setArgsTBaseClass or guessed from the method name.
private Class argsTBaseClass;
// Service interface class used when guessing the payload struct class; may be null.
private final Class<?> ifaceClass;
// Selects the "args" (true) or "result" (false) suffix when guessing the payload struct class.
private final boolean isServer;
/**
 * Creates a protocol over the given transport with no interface class and
 * {@code isServer = true}.
 *
 * @param trans transport to read from and write to
 */
public TSimpleJSONProtocol(TTransport trans) {
    this(trans, null, true);
}
/**
 * Creates a protocol over the given transport.
 *
 * @param trans      transport to read from and write to
 * @param ifaceClass service interface class used to locate payload struct
 *                   classes; may be {@code null}
 * @param isServer   {@code true} to look up {@code <method>_args} classes,
 *                   {@code false} to look up {@code <method>_result} classes
 */
public TSimpleJSONProtocol(TTransport trans, Class<?> ifaceClass, boolean isServer) {
    super(trans);
    this.ifaceClass = ifaceClass;
    this.isServer = isServer;
}
/**
 * @return the struct class currently used to interpret message payloads,
 *         or {@code null} if none has been set or guessed yet
 */
public Class getArgsTBaseClass() {
    return this.argsTBaseClass;
}
/**
 * Sets the struct class used to interpret message payloads, which makes
 * readMessageBegin skip guessing it from the method name.
 *
 * @param argsTBaseClass the payload struct class
 */
public void setArgsTBaseClass(Class argsTBaseClass) {
    this.argsTBaseClass = argsTBaseClass;
}
/**
 * Opens the message as a JSON array and writes the message name, type and
 * sequence id as its first three elements.
 *
 * @param message the message header to write
 */
public void writeMessageBegin(TMessage message) throws TException {
    trans_.write(LBRACKET);
    ListContext envelope = new ListContext();
    pushWriteContext(envelope);
    writeString(message.name);
    writeByte(message.type);
    writeI32(message.seqid);
}
/**
 * Leaves the message context and closes the JSON array opened by
 * writeMessageBegin.
 */
public void writeMessageEnd() throws TException {
    popWriteContext();
    trans_.write(RBRACKET);
}
/**
 * Emits the pending separator of the enclosing context, opens a JSON
 * object and enters a new struct context.
 *
 * @param struct ignored; no struct metadata is written
 */
public void writeStructBegin(TStruct struct) throws TException {
    writeContext_.write();
    trans_.write(LBRACE);
    StructContext structContext = new StructContext();
    pushWriteContext(structContext);
}
/**
 * Leaves the struct context and closes the JSON object.
 */
public void writeStructEnd() throws TException {
    popWriteContext();
    trans_.write(RBRACE);
}
/**
 * Writes the field key as a JSON string: the numeric field id when
 * {@code useFieldId} is set, otherwise the field name. The field type is
 * not written.
 *
 * @param field the field header
 */
public void writeFieldBegin(TField field) throws TException {
    String key;
    if (useFieldId) {
        key = String.valueOf(field.id);
    } else {
        key = field.name;
    }
    writeString(key);
}
/**
 * No-op: nothing is written at the end of a field.
 */
public void writeFieldEnd() {
    // intentionally empty
}
/**
 * No-op: this protocol writes no explicit field-stop marker.
 */
public void writeFieldStop() {
    // intentionally empty
}
/**
 * Opens a map as a JSON object and enters a map context. Key type, value
 * type and size are not written.
 *
 * @param map ignored
 * @throws CollectionMapKeyException if the map would itself be a map key
 */
public void writeMapBegin(TMap map) throws TException {
    assertContextIsNotMapKey(MAP);
    writeContext_.write();
    trans_.write(LBRACE);
    MapContext mapContext = new MapContext();
    pushWriteContext(mapContext);
}
/**
 * Leaves the map context and closes the JSON object.
 */
public void writeMapEnd() throws TException {
    popWriteContext();
    trans_.write(RBRACE);
}
/**
 * Opens a list as a JSON array and enters a list context. Element type
 * and size are not written.
 *
 * @param list ignored
 * @throws CollectionMapKeyException if the list would be a map key
 */
public void writeListBegin(TList list) throws TException {
    assertContextIsNotMapKey(LIST);
    writeContext_.write();
    trans_.write(LBRACKET);
    ListContext listContext = new ListContext();
    pushWriteContext(listContext);
}
/**
 * Leaves the list context and closes the JSON array.
 */
public void writeListEnd() throws TException {
    popWriteContext();
    trans_.write(RBRACKET);
}
/**
 * Opens a set as a JSON array and enters a list context. Element type and
 * size are not written.
 *
 * @param set ignored
 * @throws CollectionMapKeyException if the set would be a map key
 */
public void writeSetBegin(TSet set) throws TException {
    assertContextIsNotMapKey(SET);
    writeContext_.write();
    trans_.write(LBRACKET);
    ListContext setContext = new ListContext();
    pushWriteContext(setContext);
}
/**
 * Leaves the set context and closes the JSON array.
 */
public void writeSetEnd() throws TException {
    popWriteContext();
    trans_.write(RBRACKET);
}
/**
 * Writes a boolean as the number 1 (true) or 0 (false).
 *
 * @param b the value to write
 */
public void writeBool(boolean b) throws TException {
    byte flag = 0;
    if (b) {
        flag = 1;
    }
    writeByte(flag);
}
/**
 * Writes a byte through the 32-bit integer writer.
 *
 * @param b the value to write
 */
public void writeByte(byte b) throws TException {
    int widened = b;
    writeI32(widened);
}
/**
 * Writes a 16-bit integer through the 32-bit integer writer.
 *
 * @param i16 the value to write
 */
public void writeI16(short i16) throws TException {
    int widened = i16;
    writeI32(widened);
}
/**
 * Writes a 32-bit integer in decimal form. When the active context
 * expects a map key the number is written as a quoted JSON string.
 *
 * @param i32 the value to write
 */
public void writeI32(int i32) throws TException {
    String digits = Integer.toString(i32);
    if (writeContext_.isMapKey()) {
        // writeString emits the separator itself.
        writeString(digits);
        return;
    }
    writeContext_.write();
    _writeStringData(digits);
}
/**
 * Writes the UTF-8 bytes of {@code s} to the transport as-is, without
 * quoting, escaping or emitting any separator.
 *
 * @param s text to write
 * @throws TException if UTF-8 is unavailable or the transport write fails
 */
public void _writeStringData(String s) throws TException {
    byte[] encoded;
    try {
        encoded = s.getBytes("UTF-8");
    } catch (UnsupportedEncodingException uex) {
        throw new TException("JVM DOES NOT SUPPORT UTF-8");
    }
    trans_.write(encoded);
}
/**
 * Writes a 64-bit integer in decimal form. When the active context
 * expects a map key the number is written as a quoted JSON string.
 *
 * @param i64 the value to write
 */
public void writeI64(long i64) throws TException {
    String digits = Long.toString(i64);
    if (writeContext_.isMapKey()) {
        // writeString emits the separator itself.
        writeString(digits);
        return;
    }
    writeContext_.write();
    _writeStringData(digits);
}
/**
 * Writes a double using {@link Double#toString(double)}. When the active
 * context expects a map key the number is written as a quoted JSON string.
 *
 * @param dub the value to write
 */
public void writeDouble(double dub) throws TException {
    String text = Double.toString(dub);
    if (writeContext_.isMapKey()) {
        // writeString emits the separator itself.
        writeString(text);
        return;
    }
    writeContext_.write();
    _writeStringData(text);
}
/**
 * Writes {@code str} as a quoted JSON string after emitting the pending
 * separator of the active context. Quotes and backslashes are
 * backslash-escaped, the common control characters use their short escape
 * form, and any other character below U+0020 is written as a four-digit
 * {@code \\u} escape.
 *
 * @param str the text to write
 */
public void writeString(String str) throws TException {
    writeContext_.write();
    final int size = str.length();
    StringBuilder json = new StringBuilder(size + 16);
    json.append(QUOTE);
    for (int idx = 0; idx < size; idx++) {
        final char ch = str.charAt(idx);
        if (ch == '"' || ch == '\\') {
            json.append('\\').append(ch);
        } else if (ch == '\b') {
            json.append('\\').append('b');
        } else if (ch == '\f') {
            json.append('\\').append('f');
        } else if (ch == '\n') {
            json.append('\\').append('n');
        } else if (ch == '\r') {
            json.append('\\').append('r');
        } else if (ch == '\t') {
            json.append('\\').append('t');
        } else if (ch < ' ') {
            // Remaining control characters: zero-padded hexadecimal escape.
            String hex = Integer.toHexString(ch);
            json.append('\\').append('u');
            for (int pad = hex.length(); pad < 4; pad++) {
                json.append('0');
            }
            json.append(hex);
        } else {
            json.append(ch);
        }
    }
    json.append(QUOTE);
    _writeStringData(json.toString());
}
/**
 * Writes the remaining bytes of {@code bin} as a JSON string by decoding
 * them as UTF-8.
 * <p>
 * The bytes are copied through a duplicate of the buffer, so the caller's
 * position and limit are left untouched and buffers without an accessible
 * backing array (direct or read-only) are supported. The number of bytes
 * taken is exactly {@code bin.remaining()}; the array offset of a sliced
 * buffer does not shorten it.
 * <p>
 * TODO(mcslee): arbitrary binary data is not valid UTF-8 in general, so
 * this encoding is lossy for non-text payloads.
 *
 * @param bin buffer whose remaining bytes are written
 * @throws TException if UTF-8 is unavailable or the write fails
 */
public void writeBinary(ByteBuffer bin) throws TException {
    byte[] bytes = new byte[bin.remaining()];
    bin.duplicate().get(bytes);
    try {
        writeString(new String(bytes, "UTF-8"));
    } catch (UnsupportedEncodingException uex) {
        throw new TException("JVM DOES NOT SUPPORT UTF-8", uex);
    }
}
// Payload of the last message parsed by readMessageBegin (element 3 of the
// message array); readStructBegin pushes it when no enclosing array is on the stack.
private BaseArray msgStruct;
/**
 * Reads one complete JSON message from the transport and parses its
 * header.
 * <p>
 * The message must be a JSON array of the form
 * {@code [name, type, seqid, payload?]}. When a payload is present it is
 * wrapped for later struct reading using {@code argsTBaseClass}, which is
 * guessed from the method name if it has not been set.
 *
 * @return the parsed message header; a header of type
 *         {@link TMessageType#EXCEPTION} if a payload is present but no
 *         struct class could be determined for it
 * @throws TProtocolException if the data is empty or is not enclosed in
 *                            square brackets
 * @throws TException         if UTF-8 is unavailable or the transport fails
 */
public TMessage readMessageBegin() throws TException {
    byte[] buf = new byte[256];
    ByteArrayOutputStream out = new ByteArrayOutputStream(1024);
    while (true) {
        int readLen = trans_.read(buf, 0, buf.length);
        // A non-positive count means no more data; a negative value must
        // never reach out.write(), which would reject it.
        if (readLen <= 0) {
            break;
        }
        out.write(buf, 0, readLen);
        // A short read is taken as the end of the message.
        if (readLen < buf.length) {
            break;
        }
    }
    String sb;
    try {
        sb = new String(out.toByteArray(), "UTF-8");
    } catch (UnsupportedEncodingException e1) {
        throw new TException("JVM DOES NOT SUPPORT UTF-8", e1);
    }
    // TODO: validate the JSON format more thoroughly
    if (sb.length() == 0 || sb.charAt(0) != '[' || sb.charAt(sb.length() - 1) != ']') {
        throw new TProtocolException(TProtocolException.INVALID_DATA, "bad format!");
    }
    JSONArray jsonArray = new JSONArray(sb);
    TMessage msg = new TMessage(jsonArray.getString(0), (byte) jsonArray.getInt(1), jsonArray.getInt(2));
    if (jsonArray.length() > 3) {
        if (argsTBaseClass == null) {
            try {
                argsTBaseClass = guessTBaseClassByMethodName(msg.name);
            } catch (Exception e) {
                // Best effort: a failed guess is reported below as an
                // EXCEPTION message instead of aborting the read.
                e.printStackTrace();
            }
        }
        if (argsTBaseClass == null) {
            // Unknown method: no struct class is available for the payload.
            return new TMessage(msg.name, TMessageType.EXCEPTION, msg.seqid);
        }
        @SuppressWarnings("unchecked")
        StructMetaData meta = new StructMetaData(TType.STRUCT, argsTBaseClass);
        msgStruct = new BaseArray(meta, (ArrayJson) jsonArray.get(3));
    }
    return msg;
}
// Process-wide cache of payload struct classes resolved by guessTBaseClassByMethodName.
private static ConcurrentHashMap<String, Class<?>> tBaseclassCache = new ConcurrentHashMap<String, Class<?>>();
/**
 * Resolves the generated struct class that carries the payload of the
 * given method: {@code <name>_args} on the server side,
 * {@code <name>_result} on the client side.
 * <p>
 * Lookup order:
 * <ol>
 * <li>the cache, keyed by interface class and struct name so that
 * services sharing a method name do not collide;</li>
 * <li>a class nested in the enclosing class of {@code ifaceClass};</li>
 * <li>a class nested in the enclosing class of any interface implemented
 * by {@code ifaceClass};</li>
 * <li>the struct classes registered in {@link FieldMetaData}, matched
 * against the name derived in step 2.</li>
 * </ol>
 * A missing {@code ifaceClass}, or one without an enclosing class, skips
 * the steps that depend on it.
 *
 * @param name the method name taken from the message header
 * @return the resolved class, or {@code null} if none was found
 * @throws Exception if the reflective access to {@link FieldMetaData} fails
 */
private Class guessTBaseClassByMethodName(String name) throws Exception {
    String classSimpleName = String.format("%s_%s", name, isServer ? "args" : "result");
    String cacheKey = (ifaceClass == null ? "" : ifaceClass.getName()) + '#' + classSimpleName;
    Class<?> result = tBaseclassCache.get(cacheKey);
    if (result != null) {
        return result;
    }
    String className = null;
    if (ifaceClass != null) {
        className = nestedClassName(ifaceClass, classSimpleName);
        if (className != null) {
            try {
                result = Class.forName(className, false, ifaceClass.getClassLoader());
                tBaseclassCache.putIfAbsent(cacheKey, result);
                return result;
            } catch (Exception ignored) {
                // Not nested next to ifaceClass; try its interfaces below.
            }
        }
        for (Class<?> c : ifaceClass.getInterfaces()) {
            String cname = nestedClassName(c, classSimpleName);
            if (cname == null) {
                continue;
            }
            try {
                result = Class.forName(cname);
                tBaseclassCache.putIfAbsent(cacheKey, result);
                return result;
            } catch (Exception ignored) {
                // This interface does not lead to the struct; try the next one.
            }
        }
    }
    if (className == null) {
        // No candidate name could be derived, so nothing can match below.
        return null;
    }
    // Last resort: scan the struct classes already registered with Thrift.
    java.lang.reflect.Field f = FieldMetaData.class.getDeclaredField("structMap");
    f.setAccessible(true);
    @SuppressWarnings("unchecked")
    Map<Class<? extends TBase>, Map<? extends TFieldIdEnum, FieldMetaData>> structMap = (Map) f.get(null);
    for (Class c : structMap.keySet()) {
        if (c.getName().equals(className)) {
            tBaseclassCache.putIfAbsent(cacheKey, c);
            return c;
        }
    }
    return null;
}

/**
 * Builds the binary name of a class called {@code simpleName} nested in
 * the enclosing class of {@code owner}.
 *
 * @return the candidate name, or {@code null} if {@code owner} has no
 *         enclosing class
 */
private static String nestedClassName(Class<?> owner, String simpleName) {
    Class<?> enclosing = owner.getEnclosingClass();
    if (enclosing == null) {
        return null;
    }
    return String.format("%s$%s", enclosing.getName(), simpleName);
}
/**
 * No-op: the whole message was already consumed by readMessageBegin.
 */
public void readMessageEnd() {
    // intentionally empty
}
// Stack of the arrays currently being read; the read*Begin methods push and the read*End methods pop.
private LinkedList<BaseArray> structStack = new LinkedList<BaseArray>();
/**
 * Asks the array on top of the stack for its next field.
 *
 * @return the next field, or the shared anonymous field when the array
 *         returns {@code null}
 */
public TField readFieldBegin() throws TException {
    BaseArray current = structStack.peek();
    TField next = current.newField();
    if (next == null) {
        return ANONYMOUS_FIELD;
    }
    return next;
}
/**
 * No-op: nothing needs to be consumed at the end of a field.
 */
public void readFieldEnd() {
    // intentionally empty
}
/**
 * Enters a struct. The array to read from is obtained from the array on
 * top of the stack, or is the message payload when the stack yields none.
 *
 * @return the shared anonymous struct
 */
public TStruct readStructBegin() {
    BaseArray parent = structStack.peek();
    BaseArray entering = (parent == null) ? msgStruct : parent.getArray();
    structStack.push(entering);
    return ANONYMOUS_STRUCT;
}
/**
 * Enters a list: pushes the nested array obtained from the array on top
 * of the stack and describes it using its list metadata.
 *
 * @return list header with the element type and the array length
 */
public TList readListBegin() throws TException {
    BaseArray listArray = structStack.peek().getArray();
    structStack.push(listArray);
    ListMetaData listMeta = (ListMetaData) listArray.getMetaData();
    return new TList(listMeta.elemMetaData.type, listArray.length());
}
/**
 * Enters a set: pushes the nested array obtained from the array on top of
 * the stack and describes it using its set metadata.
 *
 * @return set header with the element type and the array length
 */
public TSet readSetBegin() throws TException {
    BaseArray setArray = structStack.peek().getArray();
    structStack.push(setArray);
    SetMetaData setMeta = (SetMetaData) setArray.getMetaData();
    return new TSet(setMeta.elemMetaData.type, setArray.length());
}
/**
 * Enters a map: pushes the nested array obtained from the array on top of
 * the stack and describes it using its map metadata.
 *
 * @return map header with the key type, value type and the array length
 */
public TMap readMapBegin() throws TException {
    BaseArray mapArray = structStack.peek().getArray();
    structStack.push(mapArray);
    MapMetaData mapMeta = (MapMetaData) mapArray.getMetaData();
    return new TMap(mapMeta.keyMetaData.type, mapMeta.valueMetaData.type, mapArray.length());
}
// When true, writeFieldBegin writes field ids instead of field names; switched on by
// readStructEnd once a struct that was read reports useId().
private boolean useFieldId;
/**
 * Leaves the current struct. If the struct reports {@code useId()}, later
 * writes switch to field ids as keys; the switch is never reverted.
 */
public void readStructEnd() {
    BaseArray finished = structStack.pop();
    if (useFieldId) {
        return;
    }
    if (finished.useId()) {
        useFieldId = true;
    }
}
/**
 * Leaves the current list by dropping its array from the stack.
 */
public void readListEnd() {
    structStack.pop();
}
/**
 * Leaves the current map by dropping its array from the stack.
 */
public void readMapEnd() {
    structStack.pop();
}
/**
 * Leaves the current set by dropping its array from the stack.
 */
public void readSetEnd() {
    structStack.pop();
}
/**
 * Reads a 32-bit integer by delegating to {@code getInt()} of the array
 * on top of the stack.
 */
public int readI32() throws TException {
    return structStack.peek().getInt();
}
/**
 * Reads a boolean by delegating to {@code getBoolean()} of the array on
 * top of the stack.
 */
public boolean readBool() throws TException {
    return structStack.peek().getBoolean();
}
/**
 * Reads a byte: the value is read as a 32-bit integer and narrowed.
 */
public byte readByte() throws TException {
    int wide = readI32();
    return (byte) wide;
}
/**
 * Reads a 16-bit integer: the value is read as a 32-bit integer and
 * narrowed.
 */
public short readI16() throws TException {
    int wide = readI32();
    return (short) wide;
}
/**
 * Reads a 64-bit integer by delegating to {@code getLong()} of the array
 * on top of the stack.
 */
public long readI64() throws TException {
    return structStack.peek().getLong();
}
/**
 * Reads a double by delegating to {@code getDouble()} of the array on top
 * of the stack.
 */
public double readDouble() throws TException {
    return structStack.peek().getDouble();
}
/**
 * Reads a string by delegating to {@code getString()} of the array on top
 * of the stack.
 */
public String readString() throws TException {
    return structStack.peek().getString();
}
// public String readStringBody(int size) throws TException {
// // TODO(mcslee): implement
// return "";
// }
/**
 * Reads a binary value: the next string is read and encoded as UTF-8,
 * mirroring the charset used by writeBinary instead of the platform
 * default charset.
 *
 * @return a buffer wrapping the UTF-8 bytes of the string read
 * @throws TException if UTF-8 is unavailable or reading the string fails
 */
public ByteBuffer readBinary() throws TException {
    try {
        return ByteBuffer.wrap(readString().getBytes("UTF-8"));
    } catch (UnsupportedEncodingException uex) {
        throw new TException("JVM DOES NOT SUPPORT UTF-8", uex);
    }
}
/**
 * Thrown when a container (list, set or map) is about to be written as a
 * map key.
 */
@SuppressWarnings("serial")
public static class CollectionMapKeyException extends TException {
    /**
     * @param message description of the offending key type
     */
    public CollectionMapKeyException(String message) {
        super(message);
    }
}
}