/*
* Copyright (c) 2016, Kasra Faghihi, All rights reserved.
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library.
*/
package com.offbynull.coroutines.instrumenter.generators;
import com.offbynull.coroutines.instrumenter.asm.VariableTable;
import static com.offbynull.coroutines.instrumenter.generators.GenericGenerators.call;
import static com.offbynull.coroutines.instrumenter.generators.GenericGenerators.construct;
import static com.offbynull.coroutines.instrumenter.generators.GenericGenerators.forEach;
import static com.offbynull.coroutines.instrumenter.generators.GenericGenerators.ifIntegersEqual;
import static com.offbynull.coroutines.instrumenter.generators.GenericGenerators.ifObjectsEqual;
import static com.offbynull.coroutines.instrumenter.generators.GenericGenerators.loadVar;
import static com.offbynull.coroutines.instrumenter.generators.GenericGenerators.loadStringConst;
import static com.offbynull.coroutines.instrumenter.generators.GenericGenerators.merge;
import static com.offbynull.coroutines.instrumenter.generators.GenericGenerators.returnValue;
import static com.offbynull.coroutines.instrumenter.generators.GenericGenerators.saveVar;
import static com.offbynull.coroutines.instrumenter.generators.GenericGenerators.tableSwitch;
import static com.offbynull.coroutines.instrumenter.asm.SearchUtils.findMethodsWithName;
import static com.offbynull.coroutines.instrumenter.testhelpers.TestUtils.readZipResourcesAsClassNodes;
import com.offbynull.coroutines.instrumenter.asm.VariableTable.Variable;
import java.lang.reflect.InvocationTargetException;
import java.net.URLClassLoader;
import org.apache.commons.lang3.reflect.MethodUtils;
import org.junit.Before;
import org.junit.Test;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.MethodNode;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import static com.offbynull.coroutines.instrumenter.generators.GenericGenerators.throwRuntimeException;
import static com.offbynull.coroutines.instrumenter.testhelpers.TestUtils.createJarAndLoad;
public final class GenericGeneratorsTest {
private static final String STUB_CLASSNAME = "SimpleStub";
private static final String STUB_FILENAME = STUB_CLASSNAME + ".class";
private static final String ZIP_RESOURCE_PATH = STUB_CLASSNAME + ".zip";
private static final String STUB_METHOD_NAME = "fillMeIn";
private ClassNode classNode;
private MethodNode methodNode;
@Before
public void setUp() throws Exception {
// Load class, get method
classNode = readZipResourcesAsClassNodes(ZIP_RESOURCE_PATH).get(STUB_FILENAME);
methodNode = findMethodsWithName(classNode.methods, STUB_METHOD_NAME).get(0);
}
@Test
public void mustCreateAndRunNestedSwitchStatements() throws Exception {
// Augment signature
methodNode.desc = Type.getMethodDescriptor(Type.getType(String.class), new Type[] { Type.INT_TYPE, Type.INT_TYPE });
// Initialize variable table
VariableTable varTable = new VariableTable(classNode, methodNode);
Variable intVar1 = varTable.getArgument(1);
Variable intVar2 = varTable.getArgument(2);
// Update method logic
/**
* switch(arg1) {
* case 0:
* throw new RuntimeException("0");
* case 1:
* throw new RuntimeException("1");
* case 2:
* switch(arg2) {
* case 0:
* throw new RuntimeException("0");
* case 1:
* throw new RuntimeException("1");
* case 2:
* return "OK!";
* default:
* throw new RuntimeException("innerdefault")
* }
* default:
* throw new RuntimeException("default");
* }
*/
methodNode.instructions
= tableSwitch(loadVar(intVar1),
throwRuntimeException("default"),
0,
throwRuntimeException("0"),
throwRuntimeException("1"),
tableSwitch(loadVar(intVar2),
throwRuntimeException("innerdefault"),
0,
throwRuntimeException("inner0"),
throwRuntimeException("inner1"),
GenericGenerators.returnValue(Type.getType(String.class), loadStringConst("OK!"))
)
);
// Write to JAR file + load up in classloader -- then execute tests
try (URLClassLoader cl = createJarAndLoad(classNode)) {
Object obj = cl.loadClass(STUB_CLASSNAME).newInstance();
assertEquals("OK!", MethodUtils.invokeMethod(obj, STUB_METHOD_NAME, 2, 2));
try {
MethodUtils.invokeMethod(obj, STUB_METHOD_NAME, 0, 0);
fail();
} catch (InvocationTargetException ex) {
assertEquals("0", ex.getCause().getMessage());
}
try {
MethodUtils.invokeMethod(obj, STUB_METHOD_NAME, 2, 10);
fail();
} catch (InvocationTargetException ex) {
assertEquals("innerdefault", ex.getCause().getMessage());
}
try {
MethodUtils.invokeMethod(obj, STUB_METHOD_NAME, 10, 0);
fail();
} catch (InvocationTargetException ex) {
assertEquals("default", ex.getCause().getMessage());
}
}
}
@Test
public void mustCreateAndRunIfIntStatements() throws Exception {
// Augment signature
methodNode.desc = Type.getMethodDescriptor(Type.getType(String.class), new Type[] { Type.INT_TYPE, Type.INT_TYPE });
// Initialize variable table
VariableTable varTable = new VariableTable(classNode, methodNode);
Variable intVar1 = varTable.getArgument(1);
Variable intVar2 = varTable.getArgument(2);
// Update method logic
/**
* if (arg1 == arg2) {
* return "match";
* }
* return "nomatch";
*/
methodNode.instructions
= merge(ifIntegersEqual(loadVar(intVar1),
loadVar(intVar2),
returnValue(Type.getType(String.class), loadStringConst("match"))),
returnValue(Type.getType(String.class), loadStringConst("nomatch"))
);
// Write to JAR file + load up in classloader -- then execute tests
try (URLClassLoader cl = createJarAndLoad(classNode)) {
Object obj = cl.loadClass(STUB_CLASSNAME).newInstance();
assertEquals("match", MethodUtils.invokeMethod(obj, STUB_METHOD_NAME, 2, 2));
assertEquals("nomatch", MethodUtils.invokeMethod(obj, STUB_METHOD_NAME, -2, 2));
assertEquals("match", MethodUtils.invokeMethod(obj, STUB_METHOD_NAME, -2, -2));
}
}
@Test
public void mustCreateAndRunIfObjectStatements() throws Exception {
// Augment signature
methodNode.desc = Type.getMethodDescriptor(
Type.getType(String.class),
new Type[] { Type.getType(Object.class), Type.getType(Object.class) });
// Initialize variable table
VariableTable varTable = new VariableTable(classNode, methodNode);
Variable intVar1 = varTable.getArgument(1);
Variable intVar2 = varTable.getArgument(2);
// Update method logic
/**
* if (arg1 == arg2) {
* return "match";
* }
* return "nomatch";
*/
methodNode.instructions
= merge(
ifObjectsEqual(
loadVar(intVar1),
loadVar(intVar2),
returnValue(Type.getType(String.class), loadStringConst("match"))),
returnValue(Type.getType(String.class), loadStringConst("nomatch"))
);
Object testObj1 = "test1";
Object testObj2 = "test2";
// Write to JAR file + load up in classloader -- then execute tests
try (URLClassLoader cl = createJarAndLoad(classNode)) {
Object obj = cl.loadClass(STUB_CLASSNAME).newInstance();
assertEquals("match", MethodUtils.invokeMethod(obj, STUB_METHOD_NAME, testObj1, testObj1));
assertEquals("nomatch", MethodUtils.invokeMethod(obj, STUB_METHOD_NAME, testObj1, testObj2));
assertEquals("match", MethodUtils.invokeMethod(obj, STUB_METHOD_NAME, testObj2, testObj2));
}
}
@Test
public void mustCreateAndRunForEachStatement() throws Exception {
// Augment signature
methodNode.desc = Type.getMethodDescriptor(Type.getType(String.class), new Type[] {
Type.getType(Object[].class),
Type.getType(Object.class)
});
methodNode.maxLocals += 2; // We've added 2 parameters to the method, and we need to upgrade maxLocals or else varTable will give
// us bad indexes for variables we grab with acquireExtra(). This is because VariableTable uses maxLocals
// to determine at what point to start adding extra local variables.
// Initialize variable table
VariableTable varTable = new VariableTable(classNode, methodNode);
Variable objectArrVar = varTable.getArgument(1);
Variable searchObjVar = varTable.getArgument(2);
Variable counterVar = varTable.acquireExtra(Type.INT_TYPE);
Variable arrayLenVar = varTable.acquireExtra(Type.INT_TYPE);
Variable tempObjectVar = varTable.acquireExtra(Object.class);
// Update method logic
/**
* for (Object[] o : arg1) {
* if (o == arg2) {
* return "match";
* }
* }
* return "nomatch";
*/
methodNode.instructions
= merge(
forEach(counterVar, arrayLenVar,
loadVar(objectArrVar),
merge(
saveVar(tempObjectVar),
ifObjectsEqual(loadVar(tempObjectVar), loadVar(searchObjVar),
returnValue(Type.getType(String.class), loadStringConst("match")))
)
),
returnValue(Type.getType(String.class), loadStringConst("nomatch"))
);
// Write to JAR file + load up in classloader -- then execute tests
try (URLClassLoader cl = createJarAndLoad(classNode)) {
Object obj = cl.loadClass(STUB_CLASSNAME).newInstance();
Object o1 = new Object();
Object o2 = new Object();
Object o3 = new Object();
assertEquals("match", MethodUtils.invokeMethod(obj, STUB_METHOD_NAME, (Object) new Object[] { o1, o2, o3 }, o1));
assertEquals("match", MethodUtils.invokeMethod(obj, STUB_METHOD_NAME, (Object) new Object[] { o1, o2, o3 }, o2));
assertEquals("match", MethodUtils.invokeMethod(obj, STUB_METHOD_NAME, (Object) new Object[] { o1, o2, o3 }, o3));
assertEquals("nomatch", MethodUtils.invokeMethod(obj, STUB_METHOD_NAME, (Object) new Object[] { o1, o2, o3 }, null));
assertEquals("nomatch", MethodUtils.invokeMethod(obj, STUB_METHOD_NAME, (Object) new Object[] { o1, o2, o3 }, new Object()));
}
}
@Test
public void mustConstructAndCall() throws Exception {
// Augment signature
methodNode.desc = Type.getMethodDescriptor(Type.getType(String.class), new Type[] { });
// Initialize variable table
VariableTable varTable = new VariableTable(classNode, methodNode);
Variable sbVar = varTable.acquireExtra(StringBuilder.class);
Variable retVar = varTable.acquireExtra(String.class);
// Update method logic
/**
* return new StringBuilder().append("hi!").toString()
*/
methodNode.instructions
= merge(
construct(StringBuilder.class.getConstructor()),
saveVar(sbVar),
call(StringBuilder.class.getMethod("append", String.class), loadVar(sbVar), loadStringConst("hi!")),
call(StringBuilder.class.getMethod("toString"), loadVar(sbVar)),
saveVar(retVar),
returnValue(Type.getType(String.class), loadVar(retVar))
);
// Write to JAR file + load up in classloader -- then execute tests
try (URLClassLoader cl = createJarAndLoad(classNode)) {
Object obj = cl.loadClass(STUB_CLASSNAME).newInstance();
assertEquals("hi!", MethodUtils.invokeMethod(obj, STUB_METHOD_NAME));
}
}
}