/* * Nocturne * Copyright (c) 2015-2016, Lapis <https://github.com/LapisBlue> * * The MIT License * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ package blue.lapis.nocturne.processor.transform; import static blue.lapis.nocturne.util.Constants.CLASS_FORMAT_CONSTANT_POOL_OFFSET; import static blue.lapis.nocturne.util.Constants.CLASS_PATH_SEPARATOR_CHAR; import static blue.lapis.nocturne.util.Constants.Processing.CLASS_PREFIX; import static blue.lapis.nocturne.util.helper.ByteHelper.asUshort; import static blue.lapis.nocturne.util.helper.ByteHelper.getBytes; import static blue.lapis.nocturne.util.helper.ByteHelper.readBytes; import static blue.lapis.nocturne.util.helper.StringHelper.getProcessedDescriptor; import static blue.lapis.nocturne.util.helper.StringHelper.getProcessedName; import static blue.lapis.nocturne.util.helper.StringHelper.getUnprocessedName; import blue.lapis.nocturne.Main; import blue.lapis.nocturne.processor.ClassProcessor; import blue.lapis.nocturne.processor.constantpool.model.ConstantPool; import blue.lapis.nocturne.processor.constantpool.model.ImmutableConstantPool; import blue.lapis.nocturne.processor.constantpool.model.structure.ClassStructure; import blue.lapis.nocturne.processor.constantpool.model.structure.ConstantStructure; import blue.lapis.nocturne.processor.constantpool.model.structure.FieldrefStructure; import blue.lapis.nocturne.processor.constantpool.model.structure.IgnoredStructure; import blue.lapis.nocturne.processor.constantpool.model.structure.MethodrefStructure; import blue.lapis.nocturne.processor.constantpool.model.structure.NameAndTypeStructure; import blue.lapis.nocturne.processor.constantpool.model.structure.RefStructure; import blue.lapis.nocturne.processor.constantpool.model.structure.StructureType; import blue.lapis.nocturne.processor.constantpool.model.structure.Utf8Structure; import blue.lapis.nocturne.processor.index.model.IndexedClass; import blue.lapis.nocturne.util.MemberType; import blue.lapis.nocturne.util.tuple.Pair; import com.google.common.collect.ImmutableList; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.IntStream; /** * Manages interpretation and transformation of constant pool, given the raw * bytecode of a class. */ public class ClassTransformer extends ClassProcessor { private ImmutableConstantPool constantPool; private boolean isPoolProcessed; private ConstantPool processedPool; private final List<String> syntheticFields = new ArrayList<>(); private final List<String> syntheticMethods = new ArrayList<>(); private final Map<Integer, Integer> processedFieldNameMap = new HashMap<>(); private final Map<Integer, Integer> processedFieldDescriptorMap = new HashMap<>(); private final Map<Integer, Integer> processedMethodNameMap = new HashMap<>(); private final Map<Integer, Integer> processedMethodDescriptorMap = new HashMap<>(); private static final ImmutableList<String> STUPID_PARAM_NAMES = ImmutableList.<String>builder() .add("\u2603").build(); private static final ImmutableList<String> IGNORED_METHODS = ImmutableList.<String>builder() .add("<init>").add("<clinit>").build(); public ClassTransformer(String className, byte[] bytes) { super(className, bytes); assert IndexedClass.INDEXED_CLASSES.containsKey(getClassName()); constantPool = IndexedClass.INDEXED_CLASSES.get(getClassName()).getConstantPool(); processedPool = new ConstantPool(constantPool.getContents(), constantPool.length()); } /** * Processes the class and returns the new bytecode. * * @return The processed bytecode */ public byte[] process() throws IOException { ByteBuffer buffer = ByteBuffer.wrap(bytes); byte[] header = processClassHeader(buffer); buffer.get(new byte[constantPool.length()]); // skip constant pool byte[] intermediate = processIntermediateBytes(buffer); byte[] fields = processFieldBytes(buffer); byte[] methods = processMethodBytes(buffer); byte[] remainder = processRemainder(buffer); // next call MUST come after field and method processing byte[] poolBytes = getProcessedPool().getBytes(); ByteBuffer bb = ByteBuffer.allocate(bytes.length + (poolBytes.length - constantPool.length())); bb.put(header); bb.put(poolBytes); bb.put(intermediate); bb.put(fields); bb.put(methods); bb.put(remainder); return bb.array(); } /** * Processes the header of the class (the first 8 bytes). * * @param buffer The {@link ByteBuffer} to read from * @return The processed class header */ public byte[] processClassHeader(ByteBuffer buffer) { byte[] value = new byte[CLASS_FORMAT_CONSTANT_POOL_OFFSET]; buffer.get(value); return value; } /** * Processes the intermediate bytes between the constant pool and the member * definitions. * * @param buffer The buffer to read from * @return The intermediate bytes */ public byte[] processIntermediateBytes(ByteBuffer buffer) { int initialPos = buffer.position(); // mark the initial position of the buffer // Okay, so here's what's happening here: // - The first 6 bytes aren't relevant at all // - The next 2 bytes are the number of interfaces the class implements // - The remaining bytes are pointers to class structures, one for each interface, each being 2 bytes // So, we need to process 8 bytes plus [2 times the interface count]. Hopefully this comment makes sense. final int irrelevantBytes = 6; // magic number buffer.get(new byte[irrelevantBytes]); // skip the header int interfaceCount = asUshort(buffer.getShort()); // read the interface count byte[] finalArray = new byte[irrelevantBytes + 2 + interfaceCount * 2]; // allocate an appropriately-sized array buffer.position(initialPos); // rewind the buffer to the initial position buffer.get(finalArray); // put the bytes into the allocated array return finalArray; } /** * Processes field definitions. * * @param buffer The buffer to read from * @return The new field definition bytes */ public byte[] processFieldBytes(ByteBuffer buffer) throws IOException { return processMemberBytes(buffer, false); } /** * Processes method definitions. * * @param buffer The buffer to read from * @return The new method definition bytes */ public byte[] processMethodBytes(ByteBuffer buffer) throws IOException { return processMemberBytes(buffer, true); } /** * Processes member definitions. * * @param buffer The buffer to read from * @param isMethod Whether the member is a method (a value of {@link false} * for this parameter is taken to mean the member is a field) * @return The new member definition bytes */ public byte[] processMemberBytes(ByteBuffer buffer, boolean isMethod) throws IOException { ByteArrayOutputStream os = new ByteArrayOutputStream(); int count = asUshort(buffer.getShort()); os.write(getBytes((short) count)); for (int m = 0; m < count; m++) { final int memberStart = buffer.position(); short access = buffer.getShort(); os.write(getBytes(access)); boolean isSynthetic = (access & 0x1000) != 0; buffer.get(new byte[4]); ByteArrayOutputStream attrOs = new ByteArrayOutputStream(); int attrCount = asUshort(buffer.getShort()); attrOs.write(getBytes((short) attrCount)); for (int i = 0; i < attrCount; i++) { Pair<byte[], Boolean> attr = processAttribute(buffer); attrOs.write(attr.first()); isSynthetic = attr.second(); } final byte[] attrArr = attrOs.toByteArray(); final int memberEnd = buffer.position(); buffer.position(memberStart + 2); int nameIndex = asUshort(buffer.getShort()); int descriptorIndex = asUshort(buffer.getShort()); if (isSynthetic) { (isMethod ? syntheticMethods : syntheticFields).add(getString(nameIndex)); } if (!isSynthetic) { if (!isMethod || !IGNORED_METHODS.contains(getString(nameIndex))) { Map<Integer, Integer> map = isMethod ? processedMethodNameMap : processedFieldNameMap; if (map.containsKey(nameIndex)) { nameIndex = map.get(nameIndex); } else { String procName = getProcessedName( getClassName() + CLASS_PATH_SEPARATOR_CHAR + getString(nameIndex), getString(descriptorIndex), isMethod ? MemberType.METHOD : MemberType.FIELD ); Utf8Structure nameStruct = new Utf8Structure(procName); processedPool.add(nameStruct); nameIndex = processedPool.size(); } } } os.write(getBytes((short) nameIndex)); Map<Integer, Integer> map = isMethod ? processedMethodDescriptorMap : processedFieldDescriptorMap; if (map.containsKey(descriptorIndex)) { descriptorIndex = map.get(descriptorIndex); } else { String procDesc = getProcessedDescriptor( isMethod ? MemberType.METHOD : MemberType.FIELD, getString(descriptorIndex) ); if (!procDesc.equals(getString(descriptorIndex))) { Utf8Structure descStruct = new Utf8Structure(procDesc); processedPool.add(descStruct); descriptorIndex = processedPool.size(); } } os.write(getBytes((short) descriptorIndex)); os.write(attrArr); buffer.position(memberEnd); } return os.toByteArray(); } /** * Returns any bytes remaining in the class file after the given offset. * * @param buffer The buffer to read * @return The remainder of the class file */ public byte[] processRemainder(ByteBuffer buffer) { return ByteBuffer.allocate(buffer.capacity() - buffer.position()).put(buffer).array(); } private ConstantPool getProcessedPool() { if (!isPoolProcessed) { IntStream.range(1, processedPool.size() + 1).forEach(this::handleMember); isPoolProcessed = true; } return processedPool; } private void handleMember(int index) { ConstantStructure cs = processedPool.get(index); if (!(cs instanceof IgnoredStructure)) { if (cs.getType() == StructureType.CLASS) { handleClassMember(cs, index, processedPool); } else if ( cs.getType() == StructureType.FIELDREF || cs.getType() == StructureType.INTERFACE_METHODREF || cs.getType() == StructureType.METHODREF) { handleNonClassMember(cs, index, processedPool); } } } private void handleClassMember(ConstantStructure cs, int index, ConstantPool pool) { String name = getString(((ClassStructure) cs).getNameIndex()); if (!Main.getLoadedJar().getClass(name).isPresent()) { return; } String newName = getProcessedName(name, null, MemberType.CLASS); byte[] strBytes = newName.getBytes(StandardCharsets.UTF_8); ByteBuffer strBuffer = ByteBuffer.allocate(strBytes.length + 3); strBuffer.put(StructureType.UTF_8.getTag()); strBuffer.putShort((short) strBytes.length); strBuffer.put(strBytes); pool.add(new Utf8Structure(strBuffer.array())); ByteBuffer classBuffer = ByteBuffer.allocate(StructureType.CLASS.getLength() + 1); classBuffer.put(StructureType.CLASS.getTag()); classBuffer.putShort((short) pool.size()); pool.set(index, new ClassStructure(classBuffer.array())); } private void handleNonClassMember(ConstantStructure cs, int index, ConstantPool pool) { MemberType memberType; switch (cs.getType()) { case FIELDREF: { memberType = MemberType.FIELD; break; } case INTERFACE_METHODREF: // fall through case METHODREF: { memberType = MemberType.METHOD; break; } default: { throw new AssertionError(); } } String className = getClassNameFromStruct((RefStructure) cs); if (className.startsWith(CLASS_PREFIX)) { className = getUnprocessedName(className); } NameAndType nat = getNameAndType((RefStructure) cs); int natIndex = ((RefStructure) cs).getNameAndTypeIndex(); NameAndTypeStructure natStruct = (NameAndTypeStructure) constantPool.get(natIndex); int nameIndex = natStruct.getNameIndex(); int typeIndex = natStruct.getTypeIndex(); boolean ignored = false; if (IGNORED_METHODS.contains(nat.getName())) { // don't process ignored methods ignored = true; } String desc = nat.getType(); boolean isSynthetic = (memberType == MemberType.FIELD ? syntheticFields : syntheticMethods).contains(nat.getName()); if (Main.getLoadedJar().getClass(className).isPresent() && !isSynthetic && !ignored) { String newName = getProcessedName(className + CLASS_PATH_SEPARATOR_CHAR + nat.getName(), desc, memberType); byte[] newNameBytes = newName.getBytes(StandardCharsets.UTF_8); ByteBuffer nameBuffer = ByteBuffer.allocate(newNameBytes.length + 3); nameBuffer.put(StructureType.UTF_8.getTag()); nameBuffer.putShort((short) newNameBytes.length); nameBuffer.put(newNameBytes); pool.add(new Utf8Structure(nameBuffer.array())); Map<Integer, Integer> map = memberType == MemberType.FIELD ? processedFieldNameMap : processedMethodNameMap; map.put(nameIndex, pool.size()); nameIndex = pool.size(); } String processedDesc = getProcessedDescriptor( cs.getType() == StructureType.FIELDREF ? MemberType.FIELD : MemberType.METHOD, desc ); if (!processedDesc.equals(desc)) { byte[] newTypeBytes = processedDesc.getBytes(StandardCharsets.UTF_8); ByteBuffer typeBuffer = ByteBuffer.allocate(newTypeBytes.length + 3); typeBuffer.put(StructureType.UTF_8.getTag()); typeBuffer.putShort((short) newTypeBytes.length); typeBuffer.put(newTypeBytes); pool.add(new Utf8Structure(typeBuffer.array())); Map<Integer, Integer> map = memberType == MemberType.FIELD ? processedFieldDescriptorMap : processedMethodDescriptorMap; map.put(typeIndex, pool.size()); typeIndex = pool.size(); } ByteBuffer buffer = ByteBuffer.allocate(StructureType.NAME_AND_TYPE.getLength() + 1); buffer.put(StructureType.NAME_AND_TYPE.getTag()); buffer.putShort((short) nameIndex); buffer.putShort((short) typeIndex); pool.add(new NameAndTypeStructure(buffer.array())); StructureType st = memberType == MemberType.FIELD ? StructureType.FIELDREF : StructureType.METHODREF; ByteBuffer mBuffer = ByteBuffer.allocate(st.getLength() + 1); mBuffer.put(st.getTag()); mBuffer.putShort((short) ((RefStructure) cs).getClassIndex()); mBuffer.putShort((short) pool.size()); pool.set(index, memberType == MemberType.FIELD ? new FieldrefStructure(mBuffer.array()) : new MethodrefStructure(mBuffer.array())); } @SuppressWarnings("fallthrough") private Pair<byte[], Boolean> processAttribute(ByteBuffer buffer) throws IOException { ByteArrayOutputStream os = new ByteArrayOutputStream(); boolean isSynthetic = false; int attrNameIndex = asUshort(buffer.getShort()); os.write(getBytes((short) attrNameIndex)); String attrName = getString(attrNameIndex); int attrLength = buffer.getInt(); switch (attrName) { case "Code": { // note: we're now at max_stack ByteArrayOutputStream bufferOs = new ByteArrayOutputStream(); bufferOs.write(readBytes(buffer, 4)); // skip max_stack and max_locals // skip the actual code (also unimportant to us) int codeLength = buffer.getInt(); // read code_length bufferOs.write(getBytes(codeLength)); bufferOs.write(readBytes(buffer, codeLength)); // read code // skip the exception table int exceptionTableLength = asUshort(buffer.getShort()); // read exception_table_length bufferOs.write(getBytes((short) exceptionTableLength)); bufferOs.write(readBytes(buffer, exceptionTableLength * 8)); // exception_table (each entry is 8 bytes) // now we get to the good stuff // note: we're now at attribute_count ByteArrayOutputStream subOs = new ByteArrayOutputStream(); // since the length can change int attrCount = asUshort(buffer.getShort()); // read attributes_count int actualAttrCount = 0; for (int a = 0; a < attrCount; a++) { // now we're in a sub-attribute int subAttrNameIndex = asUshort(buffer.getShort()); // read attribute_name_index String subAttrName = getString(subAttrNameIndex); int subAttrLength = buffer.getInt(); // read attribute_length if (subAttrName.equals("LocalVariableTable")) { attrLength -= subAttrLength + 6; readBytes(buffer, subAttrLength); // read and discard attribute body } else { actualAttrCount++; subOs.write(getBytes((short) subAttrNameIndex)); // write attribute_name_index subOs.write(getBytes(subAttrLength)); // write attribute_length subOs.write(readBytes(buffer, subAttrLength)); // read and write attribute body } } bufferOs.write(getBytes((short) actualAttrCount)); bufferOs.write(subOs.toByteArray()); os.write(getBytes(attrLength)); os.write(bufferOs.toByteArray()); break; } case "Synthetic": { isSynthetic = true; } default: { os.write(getBytes(attrLength)); os.write(readBytes(buffer, attrLength)); break; } } return new Pair<>(os.toByteArray(), isSynthetic); } private NameAndType getNameAndType(RefStructure rs) { int natStructIndex = rs.getNameAndTypeIndex(); assert natStructIndex <= constantPool.size(); ConstantStructure natStruct = constantPool.get(natStructIndex); assert natStruct instanceof NameAndTypeStructure; int nameIndex = ((NameAndTypeStructure) natStruct).getNameIndex(); int typeIndex = ((NameAndTypeStructure) natStruct).getTypeIndex(); return new NameAndType(getString(nameIndex), getString(typeIndex)); } private String getString(int strIndex) { assert strIndex <= processedPool.size(); ConstantStructure cs = processedPool.get(strIndex); assert cs instanceof Utf8Structure; return ((Utf8Structure) cs).asString(); } private String getClassNameFromStruct(RefStructure rs) { int classIndex = rs.getClassIndex(); ConstantStructure classStruct = processedPool.get(classIndex); assert classStruct instanceof ClassStructure; return getString(((ClassStructure) classStruct).getNameIndex()); } private class NameAndType { private final String name; private final String type; NameAndType(String name, String type) { this.name = name; this.type = type; } public String getName() { return name; } public String getType() { return type; } } }