/* * Copyright 2002-2007 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.baidu.bjf.remoting.protobuf.code; import java.io.File; import java.lang.reflect.Field; import java.lang.reflect.Modifier; import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.logging.Level; import java.util.logging.Logger; import com.baidu.bjf.remoting.protobuf.Codec; import com.baidu.bjf.remoting.protobuf.FieldType; import com.baidu.bjf.remoting.protobuf.utils.ClassHelper; import com.baidu.bjf.remoting.protobuf.utils.FieldInfo; import com.baidu.bjf.remoting.protobuf.utils.FieldUtils; import com.baidu.bjf.remoting.protobuf.utils.StringUtils; import com.google.protobuf.Descriptors.Descriptor; /** * Code generator utility class. * * @author xiemalin * @since 1.0.0 */ public class CodeGenerator { /** auto proxied suffix class name. */ public static final String DEFAULT_SUFFIX_CLASSNAME = "$$JProtoBufClass"; /** The Constant JAVA_CLASS_FILE_SUFFIX. */ public static final String JAVA_CLASS_FILE_SUFFIX = ".class"; /** Logger for this class. */ private static final Logger LOGGER = Logger.getLogger(CodeGenerator.class.getName()); /** target fields which marked <code> @Protofuf </code> annotation. */ private List<FieldInfo> fields; /** enable debug mode. */ private boolean debug = false; /** * static path for output dynamic compiled class file. */ private File outputPath; /** The relative proxy classes. */ private Set<Class<?>> relativeProxyClasses = new HashSet<Class<?>>(); /** * Gets the relative proxy classes. * * @return the relative proxy classes */ public Set<Class<?>> getRelativeProxyClasses() { return relativeProxyClasses; } /** * Sets the static path for output dynamic compiled class file. * * @param outputPath the new static path for output dynamic compiled class file */ public void setOutputPath(File outputPath) { this.outputPath = outputPath; } /** * Checks if is enable debug mode. * * @return the enable debug mode */ public boolean isDebug() { return debug; } /** * Sets the enable debug mode. * * @param debug the new enable debug mode */ public void setDebug(boolean debug) { this.debug = debug; } /** target class to parse <code>@Protobuf</code> annotation to generate code. */ private Class<?> cls; /** * Constructor method. * * @param fields protobuf mapped fields * @param cls protobuf mapped class */ public CodeGenerator(List<FieldInfo> fields, Class<?> cls) { super(); this.fields = fields; this.cls = cls; } /** * Gets the class name. * * @return the class name */ public String getClassName() { return ClassHelper.getClassName(cls) + DEFAULT_SUFFIX_CLASSNAME; } /** * Gets the package. * * @return the package */ public String getPackage() { return ClassHelper.getPackage(cls); } /** * Gets the full class name. * * @return the full class name */ public String getFullClassName() { if (StringUtils.isEmpty(getPackage())) { return getClassName(); } return getPackage() + ClassHelper.PACKAGE_SEPARATOR + getClassName(); } /** * Gets the code. * * @return the code */ public String getCode() { String className = getClassName(); ClassCode code = new ClassCode(ClassCode.SCOPE_PUBLIC, className); // to implements Codec interface code.addInteface(Codec.class.getName() + "<" + ClassHelper.getInternalName(cls.getName()) + ">"); // package code.setPkg(getPackage()); // import classes genImportCode(code); // define Descriptor field String descriptorClsName = ClassHelper.getInternalName(Descriptor.class.getName()); code.addField(ClassCode.SCOPE_DEFAULT, descriptorClsName, "descriptor", null); // define class code.addMethod(getEncodeMethodCode()); code.addMethod(getDecodeMethodCode()); code.addMethod(getSizeMethodCode()); code.addMethod(getWriteToMethodCode()); code.addMethod(getReadFromMethodCode()); code.addMethod(getGetDescriptorMethodCode()); return code.toCode(); } /** * generate import code. * * @param code the code */ private void genImportCode(ClassCode code) { code.importClass("java.util.*"); code.importClass("java.io.IOException"); code.importClass("java.lang.reflect.*"); code.importClass("com.baidu.bjf.remoting.protobuf.code.*"); code.importClass("com.baidu.bjf.remoting.protobuf.utils.*"); code.importClass("com.baidu.bjf.remoting.protobuf.*"); code.importClass("com.google.protobuf.*"); if (!StringUtils.isEmpty(getPackage())) { code.importClass(ClassHelper.getInternalName(cls.getName())); } } /** * To generate parse google protocol buffer byte array parser code. * * @param code add new generated code to the builder. * @return the parses the bytes method code */ private void getParseBytesMethodCode(MethodCode mc) { StringBuilder code = new StringBuilder(); // define return code.append(ClassHelper.getInternalName(cls.getName())).append(" ret = new "); code.append(ClassHelper.getInternalName(cls.getName())).append("()"); mc.appendLineCode1(code.toString()); code.setLength(0); // 执行初始化,主要针对枚举类型 for (FieldInfo field : fields) { boolean isList = field.isList(); if (field.getFieldType() == FieldType.ENUM) { String clsName = ClassHelper.getInternalName(field.getField().getType().getName()); if (!isList) { String express = "java.lang.Enum.valueOf(" + clsName + ".class, " + clsName + ".values()[0].name())"; // add set get method mc.appendLineCode1(getSetToField("ret", field.getField(), cls, express, isList, field.isMap())); } } } // add parse method code here mc.appendLineCode0("try {"); mc.appendLineCode1(ClassCode.CODE_FORMAT + "boolean done = false"); mc.appendLineCode1(ClassCode.CODE_FORMAT + "Codec codec = null"); mc.appendLineCode0(ClassCode.CODE_FORMAT + "while (!done) {"); mc.appendLineCode1(ClassCode.CODE_FORMAT + "int tag = input.readTag()"); mc.appendLineCode0(ClassCode.CODE_FORMAT + "if (tag == 0) { break;}"); for (FieldInfo field : fields) { boolean isList = field.isList(); if (field.getFieldType() != FieldType.DEFAULT) { code.append("if (tag == ").append(CodedConstant.makeTag(field.getOrder(), field.getFieldType().getInternalFieldType().getWireType())); code.append(") {"); mc.appendLineCode0(code.toString()); } else { code.append("if (tag == CodedConstant.makeTag(").append(field.getOrder()); code.append(",WireFormat.").append(field.getFieldType().getWireFormat()).append(")) {"); mc.appendLineCode0(code.toString()); } code.setLength(0); String t = field.getFieldType().getType(); t = CodedConstant.capitalize(t); boolean listTypeCheck = false; String express; // enumeration type if (field.getFieldType() == FieldType.ENUM) { String clsName = ClassHelper.getInternalName(field.getField().getType().getName()); if (isList) { if (field.getGenericKeyType() != null) { Class cls = field.getGenericKeyType(); clsName = ClassHelper.getInternalName(cls.getName()); } } express = "java.lang.Enum.valueOf(" + clsName + ".class, CodedConstant.getEnumName(" + clsName + ".values()," + "input.read" + t + "()))"; } else { express = "input.read" + t + "()"; } // if List type and element is object message type if (isList && field.getFieldType() == FieldType.OBJECT) { if (field.getGenericKeyType() != null) { Class cls = field.getGenericKeyType(); String name = ClassHelper.getInternalName(cls.getName()); // need // to // parse // nested // class code.append("codec = ProtobufProxy.create(").append(name).append(".class"); if (debug) { code.append(", true"); } else { code.append(", false"); } String spath = "ProtobufProxy.OUTPUT_PATH.get()"; code.append(",").append(spath); code.append(")"); mc.appendLineCode1(code.toString()); code.setLength(0); mc.appendLineCode1("int length = input.readRawVarint32()"); mc.appendLineCode1("final int oldLimit = input.pushLimit(length)"); listTypeCheck = true; express = "(" + name + ") codec.readFrom(input)"; } } else if (field.getFieldType() == FieldType.OBJECT) { // if object // message // type Class cls = field.getField().getType(); String name = ClassHelper.getInternalName(cls.getName()); // need // to // parse // nested // class code.append("codec = ProtobufProxy.create(").append(name).append(".class"); if (debug) { code.append(", true"); } else { code.append(", false"); } String spath = "ProtobufProxy.OUTPUT_PATH.get()"; code.append(",").append(spath); code.append(")"); mc.appendLineCode1(code.toString()); code.setLength(0); mc.appendLineCode1("int length = input.readRawVarint32()"); mc.appendLineCode1("final int oldLimit = input.pushLimit(length)"); listTypeCheck = true; express = "(" + name + ") codec.readFrom(input)"; } if (field.getFieldType() == FieldType.BYTES) { express += ".toByteArray()"; } mc.appendLineCode1(getSetToField("ret", field.getField(), cls, express, isList, field.isMap())); if (listTypeCheck) { mc.appendLineCode1("input.checkLastTagWas(0)"); mc.appendLineCode1("input.popLimit(oldLimit)"); } mc.appendLineCode1("continue"); mc.appendLineCode0("}"); } mc.appendLineCode1("input.skipField(tag)"); mc.appendLineCode0("}"); mc.appendLineCode0("} catch (com.google.protobuf.InvalidProtocolBufferException e) {"); mc.appendLineCode1("throw e"); mc.appendLineCode0("} catch (java.io.IOException e) {"); mc.appendLineCode1("throw e"); mc.appendLineCode0("}"); for (FieldInfo field : fields) { if (field.isRequired()) { mc.appendLineCode0(CodedConstant.getRetRequiredCheck(getAccessByField("ret", field.getField(), cls), field.getField())); } } mc.appendLineCode1("return ret"); } /** * Gets the decode method code. * * @return the decode method code */ private MethodCode getDecodeMethodCode() { MethodCode mc = new MethodCode(); mc.setName("decode"); mc.setScope(ClassCode.SCOPE_PUBLIC); mc.setReturnType(ClassHelper.getInternalName(cls.getName())); mc.addParameter("byte[]", "bb"); mc.addException("IOException"); // add method code mc.appendLineCode1("CodedInputStream input = CodedInputStream.newInstance(bb, 0, bb.length)"); getParseBytesMethodCode(mc); return mc; } /** * Gets the gets the descriptor method code. * * @return the gets the descriptor method code */ private MethodCode getGetDescriptorMethodCode() { String descriptorClsName = ClassHelper.getInternalName(Descriptor.class.getName()); MethodCode mc = new MethodCode(); mc.setName("getDescriptor"); mc.setReturnType(descriptorClsName); mc.setScope(ClassCode.SCOPE_PUBLIC); mc.addException("IOException"); String methodSource = CodeTemplate.descriptorMethodSource(cls); mc.appendLineCode0(methodSource); return mc; } /** * Gets the read from method code. * * @return the read from method code */ private MethodCode getReadFromMethodCode() { MethodCode mc = new MethodCode(); mc.setName("readFrom"); mc.setReturnType(ClassHelper.getInternalName(cls.getName())); mc.setScope(ClassCode.SCOPE_PUBLIC); mc.addParameter("CodedInputStream", "input"); mc.addException("IOException"); getParseBytesMethodCode(mc); return mc; } /** * Check {@link FieldType} is validate to class type of {@link Field}. * * @param type the type * @param field the field */ private void checkType(FieldType type, Field field) { Class<?> cls = field.getType(); if (type == FieldType.OBJECT || type == FieldType.ENUM) { return; } String javaType = type.getJavaType(); if (Integer.class.getSimpleName().equals(javaType)) { if (cls.getSimpleName().equals("int") || Integer.class.getSimpleName().equals(cls.getSimpleName())) { return; } throw new IllegalArgumentException(getMismatchTypeErroMessage(type, field)); } if (!javaType.equalsIgnoreCase(cls.getSimpleName())) { throw new IllegalArgumentException(getMismatchTypeErroMessage(type, field)); } } /** * get error message info by type not matched. * * @param type the type * @param field the field * @return error message for mismatch type */ private String getMismatchTypeErroMessage(FieldType type, Field field) { return "Type mismatch. @Protobuf required type '" + type.getJavaType() + "' but field type is '" + field.getType().getSimpleName() + "' of field name '" + field.getName() + "' on class " + field.getDeclaringClass().getName(); } /** * Gets the encode method code. * * @return the encode method code */ private MethodCode getEncodeMethodCode() { MethodCode mc = new MethodCode(); mc.setName("encode"); mc.setScope(ClassCode.SCOPE_PUBLIC); mc.setReturnType("byte[]"); mc.addParameter(ClassHelper.getInternalName(cls.getName()), "t"); mc.addException("IOException"); // add method code mc.appendLineCode1("int size = 0"); Set<Integer> orders = new HashSet<Integer>(); // encode method for (FieldInfo field : fields) { boolean isList = field.isList(); // check type if (!isList) { checkType(field.getFieldType(), field.getField()); } if (orders.contains(field.getOrder())) { throw new IllegalArgumentException("Field order '" + field.getOrder() + "' on field" + field.getField().getName() + " already exsit."); } // define field String checkParameterLine = CodedConstant.getMappedTypeDefined(field.getOrder(), field.getFieldType(), getAccessByField("t", field.getField(), cls), isList); mc.appendLineCode0(checkParameterLine); // compute size StringBuilder code = new StringBuilder(); code.append("if (!CodedConstant.isNull(").append(getAccessByField("t", field.getField(), cls)).append("))") .append("{"); mc.appendLineCode0(code.toString()); code.setLength(0); // clear old code code.append("size += "); code.append(CodedConstant.getMappedTypeSize(field, field.getOrder(), field.getFieldType(), isList, debug, outputPath)); mc.appendLineCode0(code.toString()); mc.appendLineCode0("}"); if (field.isRequired()) { mc.appendLineCode0(CodedConstant.getRequiredCheck(field.getOrder(), field.getField())); } } mc.appendLineCode1("final byte[] result = new byte[size]"); mc.appendLineCode1("final CodedOutputStream output = CodedOutputStream.newInstance(result)"); // call writeTo method mc.appendLineCode1("writeTo(t, output)"); mc.appendLineCode1("return result"); return mc; } /** * Gets the write to method code. * * @return the write to method code */ private MethodCode getWriteToMethodCode() { MethodCode mc = new MethodCode(); mc.setName("writeTo"); mc.setReturnType("void"); mc.setScope(ClassCode.SCOPE_PUBLIC); mc.addParameter(ClassHelper.getInternalName(cls.getName()), "t"); mc.addParameter("CodedOutputStream", "output"); mc.addException("IOException"); Set<Integer> orders = new HashSet<Integer>(); for (FieldInfo field : fields) { boolean isList = field.isList(); // check type if (!isList) { checkType(field.getFieldType(), field.getField()); } if (orders.contains(field.getOrder())) { throw new IllegalArgumentException("Field order '" + field.getOrder() + "' on field" + field.getField().getName() + " already exsit."); } // define field mc.appendLineCode0(CodedConstant.getMappedTypeDefined(field.getOrder(), field.getFieldType(), getAccessByField("t", field.getField(), cls), isList)); if (field.isRequired()) { mc.appendLineCode0(CodedConstant.getRequiredCheck(field.getOrder(), field.getField())); } } for (FieldInfo field : fields) { boolean isList = field.isList(); // set write to byte mc.appendLineCode0( CodedConstant.getMappedWriteCode(field, "output", field.getOrder(), field.getFieldType(), isList)); } return mc; } /** * Gets the size method code. * * @return the size method code */ private MethodCode getSizeMethodCode() { MethodCode mc = new MethodCode(); mc.setName("size"); mc.setScope(ClassCode.SCOPE_PUBLIC); mc.setReturnType("int"); mc.addParameter(ClassHelper.getInternalName(cls.getName()), "t"); mc.addException("IOException"); // add method code mc.appendLineCode1("int size = 0"); Set<Integer> orders = new HashSet<Integer>(); // encode method for (FieldInfo field : fields) { boolean isList = field.isList(); // check type if (!isList) { checkType(field.getFieldType(), field.getField()); } if (orders.contains(field.getOrder())) { throw new IllegalArgumentException("Field order '" + field.getOrder() + "' on field" + field.getField().getName() + " already exsit."); } // define field mc.appendLineCode0(CodedConstant.getMappedTypeDefined(field.getOrder(), field.getFieldType(), getAccessByField("t", field.getField(), cls), isList)); // compute size StringBuilder code = new StringBuilder(); code.append("if (!CodedConstant.isNull(").append(getAccessByField("t", field.getField(), cls)).append("))") .append("{"); mc.appendLineCode0(code.toString()); code.setLength(0); code.append("size += "); code.append(CodedConstant.getMappedTypeSize(field, field.getOrder(), field.getFieldType(), isList, debug, outputPath)); code.append("}"); mc.appendLineCode0(code.toString()); if (field.isRequired()) { mc.appendLineCode0(CodedConstant.getRequiredCheck(field.getOrder(), field.getField())); } } mc.appendLineCode1("return size"); return mc; } /** * get field access code. * * @param target target instance name * @param field java field instance * @param cls mapped class * @return full field access java code */ protected String getAccessByField(String target, Field field, Class<?> cls) { if (field.getModifiers() == Modifier.PUBLIC) { return target + ClassHelper.PACKAGE_SEPARATOR + field.getName(); } // check if has getter method String getter; if ("boolean".equalsIgnoreCase(field.getType().getName())) { getter = "is" + CodedConstant.capitalize(field.getName()); } else { getter = "get" + CodedConstant.capitalize(field.getName()); } // check method exist try { cls.getMethod(getter, new Class<?>[0]); return target + ClassHelper.PACKAGE_SEPARATOR + getter + "()"; } catch (Exception e) { LOGGER.log(Level.FINE, e.getMessage(), e); } String type = field.getType().getName(); if ("[B".equals(type) || "[Ljava.lang.Byte;".equals(type)) { type = "byte[]"; } // use reflection to get value String code = "(" + FieldUtils.toObjectType(type) + ") "; code += "FieldUtils.getField(" + target + ", \"" + field.getName() + "\")"; return code; } /** * generate access {@link Field} value source code. support public field access, getter method access and reflection * access. * * @param target the target * @param field the field * @param cls the cls * @param express the express * @param isList the is list * @param isMap the is map * @return the sets the to field */ protected String getSetToField(String target, Field field, Class<?> cls, String express, boolean isList, boolean isMap) { StringBuilder ret = new StringBuilder(); if (isList || isMap) { ret.append("if ((").append(getAccessByField(target, field, cls)).append(") == null) {") .append(ClassCode.LINE_BREAK); } // if field of public modifier we can access directly if (Modifier.isPublic(field.getModifiers())) { if (isList) { // should initialize list ret.append(target).append(ClassHelper.PACKAGE_SEPARATOR).append(field.getName()) .append("= new ArrayList()").append(ClassCode.JAVA_LINE_BREAK).append("}") .append(ClassCode.LINE_BREAK); if (express != null) { ret.append(target).append(ClassHelper.PACKAGE_SEPARATOR).append(field.getName()).append(".add(") .append(express).append(")"); } return ret.toString(); } else if (isMap) { ret.append(target).append(ClassHelper.PACKAGE_SEPARATOR).append(field.getName()) .append("= new HashMap()").append(ClassCode.JAVA_LINE_BREAK).append("}") .append(ClassCode.LINE_BREAK); return ret.append(express).toString(); } return target + ClassHelper.PACKAGE_SEPARATOR + field.getName() + "=" + express + ClassCode.LINE_BREAK; } String setter = "set" + CodedConstant.capitalize(field.getName()); // check method exist try { cls.getMethod(setter, new Class<?>[] { field.getType() }); if (isList) { ret.append("List __list = new ArrayList()").append(ClassCode.JAVA_LINE_BREAK); ret.append(target).append(ClassHelper.PACKAGE_SEPARATOR).append(setter).append("(__list)") .append(ClassCode.JAVA_LINE_BREAK).append("}").append(ClassCode.LINE_BREAK); if (express != null) { ret.append("(").append(getAccessByField(target, field, cls)).append(").add(").append(express) .append(")"); } return ret.toString(); } else if (isMap) { ret.append("Map __map = new HashMap()").append(ClassCode.JAVA_LINE_BREAK); ret.append(target).append(ClassHelper.PACKAGE_SEPARATOR).append(setter).append("(__map)") .append(ClassCode.JAVA_LINE_BREAK).append("}").append(ClassCode.LINE_BREAK); return ret + express; } return target + ClassHelper.PACKAGE_SEPARATOR + setter + "(" + express + ")\n"; } catch (Exception e) { LOGGER.log(Level.FINE, e.getMessage(), e); } if (isList) { ret.append("List __list = new ArrayList()").append(ClassCode.JAVA_LINE_BREAK); ret.append("FieldUtils.setField(").append(target).append(", \"").append(field.getName()) .append("\", __list)").append(ClassCode.JAVA_LINE_BREAK).append("}").append(ClassCode.LINE_BREAK); if (express != null) { ret.append("(").append(getAccessByField(target, field, cls)).append(").add(").append(express) .append(")"); } return ret.toString(); } else if (isMap) { ret.append("Map __map = new HashMap()").append(ClassCode.JAVA_LINE_BREAK); ret.append("FieldUtils.setField(").append(target).append(", \"").append(field.getName()) .append("\", __map)").append(ClassCode.JAVA_LINE_BREAK).append("}").append(ClassCode.LINE_BREAK); return ret + express; } // use reflection to get value String code = ""; if (express != null) { code = "FieldUtils.setField(" + target + ", \"" + field.getName() + "\", " + express + ")" + ClassCode.LINE_BREAK; } return code; } }