package com.feedly.cassandra.entity.enhance;
import static org.objectweb.asm.Opcodes.*;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.AbstractInsnNode;
import org.objectweb.asm.tree.AnnotationNode;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.FieldInsnNode;
import org.objectweb.asm.tree.FieldNode;
import org.objectweb.asm.tree.InsnList;
import org.objectweb.asm.tree.InsnNode;
import org.objectweb.asm.tree.IntInsnNode;
import org.objectweb.asm.tree.LdcInsnNode;
import org.objectweb.asm.tree.MethodInsnNode;
import org.objectweb.asm.tree.MethodNode;
import org.objectweb.asm.tree.TypeInsnNode;
import org.objectweb.asm.tree.VarInsnNode;
import com.feedly.cassandra.anno.Column;
import com.feedly.cassandra.anno.ColumnFamily;
import com.feedly.cassandra.anno.EmbeddedEntity;
import com.feedly.cassandra.anno.UnmappedColumnHandler;
import com.feedly.cassandra.entity.PropertyMetadataFactory;
/**
 * Bytecode transformer that enhances classes annotated with {@link ColumnFamily} or
 * {@link EmbeddedEntity} so that they implement {@link IEnhancedEntity} and track which
 * properties have been modified.
 * <p>
 * For each such class this transformer:
 * <ul>
 * <li>adds the private fields {@code __modifiedFields} ({@link BitSet}) and
 * {@code __unmappedModified} ({@code boolean}), plus public get/set accessors for them;</li>
 * <li>assigns a fresh {@code BitSet} to {@code __modifiedFields} before every {@code return}
 * of every constructor;</li>
 * <li>makes the setter and getter of the {@link UnmappedColumnHandler} field set
 * {@code __unmappedModified = true} before every {@code return};</li>
 * <li>makes the setter of every {@link Column} field set that field's bit in
 * {@code __modifiedFields} before every {@code return}. The getter is instrumented the same
 * way when the field type is neither primitive nor "simple" according to
 * {@link PropertyMetadataFactory}. A field's bit index is its position in the
 * lexicographically sorted list of {@code @Column} field names.</li>
 * </ul>
 * Accessors are matched purely by JavaBean-style name ({@code getX} / {@code setX}); static
 * methods are never instrumented.
 */
public class EntityTransformer extends ClassTransformer
{
    public EntityTransformer(ClassTransformer ct)
    {
        super(ct);
    }

    /**
     * Enhances {@code cn} in place if it is an entity class that has not been enhanced yet.
     *
     * @param cn the class to inspect and possibly modify
     * @return {@code true} if this transformer modified the class; otherwise the result of
     *         {@code super.transform(cn)}, which is only consulted when this transformer made
     *         no change
     * @throws RuntimeException wrapping a {@link ClassNotFoundException} if the type of a
     *         {@link Column} field cannot be loaded
     */
    @Override
    public boolean transform(ClassNode cn)
    {
        boolean rv = addInterface(cn);
        if(rv)
        {
            addFields(cn);
            implementInterface(cn);
            modifyConstructor(cn);
            modifyUnmappedHandler(cn);
            try
            {
                modifyAccessors(cn);
            }
            catch(ClassNotFoundException cnfe)
            {
                throw new RuntimeException("entity classes must be valid.", cnfe);
            }
        }
        return rv || super.transform(cn);
    }

    /**
     * Adds {@link IEnhancedEntity} to the class's interfaces when the class carries a
     * {@link ColumnFamily} or {@link EmbeddedEntity} annotation and does not already
     * implement it.
     *
     * @return {@code true} if the interface was added, i.e. the class should be enhanced
     */
    private boolean addInterface(ClassNode cn)
    {
        String iface = Type.getInternalName(IEnhancedEntity.class);
        if(cn.interfaces.contains(iface))
            return false; // already enhanced

        boolean isEntity = hasAnnotation(cn.visibleAnnotations, Type.getDescriptor(ColumnFamily.class))
                || hasAnnotation(cn.visibleAnnotations, Type.getDescriptor(EmbeddedEntity.class));
        if(!isEntity)
            return false;

        cn.interfaces.add(iface);
        return true;
    }

    /**
     * Generates the {@link IEnhancedEntity} methods:
     * <pre>
     * public BitSet getModifiedFields();
     * public void setModifiedFields(BitSet b);
     *
     * public boolean getUnmappedFieldsModified();
     * public void setUnmappedFieldsModified(boolean b);
     * </pre>
     */
    private void implementInterface(ClassNode cn)
    {
        addAccessors(cn, "__modifiedFields", "ModifiedFields", BitSet.class);
        addAccessors(cn, "__unmappedModified", "UnmappedFieldsModified", boolean.class);
    }

    /**
     * Adds a public trivial getter {@code get<methodName>} and setter {@code set<methodName>}
     * for the instance field {@code propName}.
     *
     * @param cn the class receiving the methods (and already declaring the field)
     * @param propName the field name
     * @param methodName the suffix appended to "get"/"set"
     * @param type the field type; any non-void type is supported
     */
    private void addAccessors(ClassNode cn, String propName, String methodName, Class<?> type)
    {
        Type t = Type.getType(type);
        String fieldDesc = t.getDescriptor();

        MethodNode getter = new MethodNode(ACC_PUBLIC, "get" + methodName, "()" + fieldDesc, null, null);
        getter.instructions.add(new VarInsnNode(ALOAD, 0));
        getter.instructions.add(new FieldInsnNode(GETFIELD, cn.name, propName, fieldDesc));
        // getOpcode picks IRETURN/LRETURN/FRETURN/DRETURN/ARETURN to match the type.
        getter.instructions.add(new InsnNode(t.getOpcode(IRETURN)));
        getter.maxLocals = 1;
        getter.maxStack = t.getSize();
        cn.methods.add(getter);

        MethodNode setter = new MethodNode(ACC_PUBLIC, "set" + methodName, "(" + fieldDesc + ")V", null, null);
        setter.instructions.add(new VarInsnNode(ALOAD, 0));
        setter.instructions.add(new VarInsnNode(t.getOpcode(ILOAD), 1));
        setter.instructions.add(new FieldInsnNode(PUTFIELD, cn.name, propName, fieldDesc));
        setter.instructions.add(new InsnNode(RETURN));
        setter.maxLocals = 1 + t.getSize();
        setter.maxStack = 1 + t.getSize();
        cn.methods.add(setter);
    }

    /** Declares the private tracking fields {@code __modifiedFields} and {@code __unmappedModified}. */
    private void addFields(ClassNode cn)
    {
        int acc = ACC_PRIVATE;
        cn.fields.add(new FieldNode(acc, "__modifiedFields", Type.getDescriptor(BitSet.class), null, null));
        cn.fields.add(new FieldNode(acc, "__unmappedModified", Type.getDescriptor(boolean.class), null, null));
    }

    /**
     * Inserts {@code this.__modifiedFields = new BitSet();} before every {@code return} of every
     * constructor that has code.
     * <p>
     * The BitSet is therefore assigned only after the constructor body has run, so the
     * constructed object starts with no modified bits.
     * <p>
     * NOTE(review): if a constructor body calls an instrumented setter (or non-simple getter),
     * the injected {@code __modifiedFields.set(..)} dereferences the still-null field and throws
     * an NPE. Confirm constructors never call instrumented accessors, or move this
     * initialization to just after the {@code super()}/{@code this()} call. Also, a constructor
     * delegating via {@code this(...)} replaces the BitSet the delegate already assigned.
     */
    private void modifyConstructor(ClassNode cn)
    {
        String bitSetName = Type.getInternalName(BitSet.class);
        String bitSetDesc = Type.getDescriptor(BitSet.class);
        for(MethodNode mn : cn.methods)
        {
            if(!"<init>".equals(mn.name) || mn.instructions.size() == 0)
                continue;

            for(AbstractInsnNode ret : findReturns(mn.instructions))
            {
                InsnList il = new InsnList();
                il.add(new VarInsnNode(ALOAD, 0));
                il.add(new TypeInsnNode(NEW, bitSetName));
                il.add(new InsnNode(DUP));
                il.add(new MethodInsnNode(INVOKESPECIAL, bitSetName, "<init>", "()V"));
                il.add(new FieldInsnNode(PUTFIELD, cn.name, "__modifiedFields", bitSetDesc));
                mn.instructions.insertBefore(ret, il);
            }
            // The injected sequence needs 3 stack slots (this, new BitSet, dup); any constructor
            // already needs at least 1, so +2 always suffices.
            mn.maxStack += 2;
        }
    }

    /**
     * Instruments the getter and setter of the field annotated with {@link UnmappedColumnHandler}
     * so that they set {@code __unmappedModified = true}. If several fields carry the annotation,
     * the last one in declaration order is used.
     */
    private void modifyUnmappedHandler(ClassNode cn)
    {
        String desc = Type.getDescriptor(UnmappedColumnHandler.class);
        String fieldName = null;
        for(FieldNode field : cn.fields)
        {
            if(hasAnnotation(field.visibleAnnotations, desc))
                fieldName = field.name;
        }
        if(fieldName == null)
            return;

        String setterName = accessorName(fieldName, true);
        String getterName = accessorName(fieldName, false);
        for(MethodNode mn : cn.methods)
        {
            if(!isInstrumentable(mn))
                continue;
            boolean isGetter = mn.name.equals(getterName);
            if(isGetter || mn.name.equals(setterName))
            {
                insertUnmappedModBitInsns(cn, mn);
                if(isGetter)
                    mn.maxStack++; // the returned value sits below the injected operands
            }
        }
    }

    /** Inserts {@code this.__unmappedModified = true;} before every return instruction of {@code mn}. */
    private void insertUnmappedModBitInsns(ClassNode cn, MethodNode mn)
    {
        String booleanDesc = Type.getDescriptor(boolean.class);
        for(AbstractInsnNode ret : findReturns(mn.instructions))
        {
            InsnList il = new InsnList();
            il.add(new VarInsnNode(ALOAD, 0));
            il.add(new InsnNode(ICONST_1));
            il.add(new FieldInsnNode(PUTFIELD, cn.name, "__unmappedModified", booleanDesc));
            mn.instructions.insertBefore(ret, il);
        }
        // Only local 0 (this) is used, so maxLocals is unchanged.
        mn.maxStack++;
    }

    /**
     * Instruments the accessors of every {@link Column} field so that they set the field's bit in
     * {@code __modifiedFields}. Setters are always instrumented. Getters are instrumented only
     * when the field type is neither primitive nor simple (presumably because callers may mutate
     * the returned object).
     *
     * @return the number of {@code @Column} fields found
     * @throws ClassNotFoundException if a column field's type cannot be loaded
     */
    private int modifyAccessors(ClassNode cn) throws ClassNotFoundException
    {
        String desc = Type.getDescriptor(Column.class);
        List<FieldNode> columnFields = new ArrayList<FieldNode>();
        List<String> fieldNames = new ArrayList<String>();
        for(FieldNode field : cn.fields)
        {
            if(hasAnnotation(field.visibleAnnotations, desc))
            {
                columnFields.add(field);
                fieldNames.add(field.name);
            }
        }
        // A field's bit index is its position in the sorted name list.
        Collections.sort(fieldNames);

        for(FieldNode field : columnFields)
        {
            int bitPos = fieldNames.indexOf(field.name);
            String setterName = accessorName(field.name, true);
            String getterName = accessorName(field.name, false);
            // Depends only on the field, so compute once rather than once per method.
            boolean isSimple = isSimplePropertyType(Type.getType(field.desc));
            for(MethodNode mn : cn.methods)
            {
                if(!isInstrumentable(mn))
                    continue;
                boolean isGetter = mn.name.equals(getterName);
                if(mn.name.equals(setterName) || (isGetter && !isSimple))
                {
                    insertModBitInsns(cn, mn, bitPos);
                    if(isGetter)
                        mn.maxStack++; // the returned value sits below the injected operands
                }
            }
        }
        return fieldNames.size();
    }

    /**
     * Returns whether {@code fieldType} is primitive or "simple" according to
     * {@link PropertyMetadataFactory}.
     *
     * @throws ClassNotFoundException if the type is not primitive and cannot be loaded
     */
    private static boolean isSimplePropertyType(Type fieldType) throws ClassNotFoundException
    {
        String className = fieldType.getClassName();
        if(PropertyMetadataFactory.isPrimitiveType(className))
            return true;
        // Class.forName cannot resolve getClassName()'s "T[]" form; arrays must use the
        // descriptor form, e.g. "[B" or "[Ljava.lang.String;".
        String loadableName = fieldType.getSort() == Type.ARRAY
                ? fieldType.getDescriptor().replace('/', '.')
                : className;
        // Same loader Class.forName(String) would use; initialize=false avoids running the field
        // type's static initializers during bytecode transformation.
        Class<?> c = Class.forName(loadableName, false, EntityTransformer.class.getClassLoader());
        return PropertyMetadataFactory.isSimpleType(c);
    }

    /** Inserts {@code this.__modifiedFields.set(bitPos);} before every return instruction of {@code mn}. */
    private void insertModBitInsns(ClassNode cn, MethodNode mn, int bitPos)
    {
        String bitSetDesc = Type.getDescriptor(BitSet.class);
        String bitSetName = Type.getInternalName(BitSet.class);
        for(AbstractInsnNode ret : findReturns(mn.instructions))
        {
            InsnList il = new InsnList();
            il.add(new VarInsnNode(ALOAD, 0));
            il.add(new FieldInsnNode(GETFIELD, cn.name, "__modifiedFields", bitSetDesc));
            il.add(pushInt(bitPos));
            il.add(new MethodInsnNode(INVOKEVIRTUAL, bitSetName, "set", "(I)V"));
            mn.instructions.insertBefore(ret, il);
        }
        // Only local 0 (this) is used, so maxLocals is unchanged.
        mn.maxStack++;
    }

    /**
     * Returns the most compact instruction that pushes the int constant {@code value}.
     * BIPUSH only accepts a signed byte and SIPUSH a signed short; larger values need LDC.
     */
    private static AbstractInsnNode pushInt(int value)
    {
        if(value >= -1 && value <= 5)
            return new InsnNode(ICONST_0 + value); // ICONST_M1 .. ICONST_5
        if(value >= Byte.MIN_VALUE && value <= Byte.MAX_VALUE)
            return new IntInsnNode(BIPUSH, value);
        if(value >= Short.MIN_VALUE && value <= Short.MAX_VALUE)
            return new IntInsnNode(SIPUSH, value);
        return new LdcInsnNode(Integer.valueOf(value));
    }

    /**
     * Collects every return instruction (IRETURN .. RETURN) of {@code insns}. The snapshot lets
     * callers insert code without mutating the list while iterating it.
     */
    private static List<AbstractInsnNode> findReturns(InsnList insns)
    {
        List<AbstractInsnNode> returns = new ArrayList<AbstractInsnNode>();
        Iterator<AbstractInsnNode> it = insns.iterator();
        while(it.hasNext())
        {
            AbstractInsnNode in = it.next();
            int op = in.getOpcode();
            if(op >= IRETURN && op <= RETURN)
                returns.add(in);
        }
        return returns;
    }

    /**
     * Only instance methods can be instrumented: the injected code loads {@code this} from local 0.
     */
    private static boolean isInstrumentable(MethodNode mn)
    {
        return (mn.access & ACC_STATIC) == 0;
    }

    /** Returns {@code true} if {@code annotations} (which may be null) contains an annotation with descriptor {@code desc}. */
    private static boolean hasAnnotation(List<AnnotationNode> annotations, String desc)
    {
        if(annotations == null)
            return false;
        for(AnnotationNode anno : annotations)
        {
            if(anno.desc.equals(desc))
                return true;
        }
        return false;
    }

    /**
     * Builds the JavaBean accessor name for {@code fieldName}: "set" or "get" followed by the
     * field name with its first character upper-cased. The "is" prefix is not considered.
     */
    private String accessorName(String fieldName, boolean setter)
    {
        StringBuilder b = new StringBuilder(setter ? "set" : "get");
        b.append(Character.toUpperCase(fieldName.charAt(0)));
        if(fieldName.length() > 1)
            b.append(fieldName.substring(1));
        return b.toString();
    }
}