/* JPC: An x86 PC Hardware Emulator for a pure Java Virtual Machine Copyright (C) 2012-2013 Ian Preston This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License version 2 as published by the Free Software Foundation. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. Details (including contact information) can be found at: jpc.sourceforge.net or the developer website sourceforge.net/projects/jpc/ End of licence header */ package tools; import java.io.*; import java.util.*; import java.net.*; import java.lang.reflect.*; import javax.xml.parsers.*; import org.w3c.dom.*; import org.xml.sax.*; import org.xml.sax.helpers.DefaultHandler; public class Fuzzer { static String newJar = "JPCApplication.jar"; static String oldJar = "OldJPCApplication.jar"; public static final boolean compareFlags = true; public static void main(String[] args) throws Exception { URL[] urls = new URL[]{new File(newJar).toURL()}; ClassLoader cl1 = new URLClassLoader(urls, tools.Fuzzer.class.getClassLoader()); Class opts = cl1.loadClass("org.jpc.j2se.Option"); Method parse = opts.getMethod("parse", String[].class); args = (String[])parse.invoke(opts, (Object)args); PCHandle pc1 = new PCHandle(cl1, true, args); URL[] urls2 = new URL[]{new File(oldJar).toURL()}; ClassLoader cl2 = new URLClassLoader(urls2, tools.Fuzzer.class.getClassLoader()); PCHandle pc2 = new PCHandle(cl2, false, args); // will succeed //byte[] add_ah_al = new byte[] {(byte)0, (byte)0xc4}; //int[] input = new int[16]; //executeCase(add_ah_al, input, pc1, pc2, false, true); // will fail //byte[] imul_bx = new byte[] {(byte)0xf7, (byte)0xeb}; //int[] input2 = new int[16]; //input2[0] = 0x50; //input2[1] = 0x19; //input2[9] = 0x46; //executeCase(imul_bx, input2, pc1, pc2, false, true); SAXParserFactory factory = SAXParserFactory.newInstance(); SAXParser saxParser = factory.newSAXParser(); DefaultHandler rmhandler = new TestParser("rm", pc1, pc2, false, true); System.out.println("Starting Real Mode fuzzing..."); //saxParser.parse("tests/rm.tests", rmhandler); // set PCs to protected mode pc1.setPM(true); pc2.setPM(true); DefaultHandler pmhandler = new TestParser("pm", pc1, pc2, true, true); System.out.println("Starting Protected Mode fuzzing..."); saxParser.parse("tests/pm.tests", pmhandler); } public static boolean executeCase(String opclass, String disam, byte[] code, int[] initialState, PCHandle pc1, PCHandle pc2, boolean mem, boolean flags, BufferedWriter log) throws Exception { pc1.setState(initialState); pc2.setState(initialState); // load code at eip pc1.setCode(code); pc2.setCode(code); try { pc1.executeBlock(); } catch (InvocationTargetException e) { e.printStackTrace(); return false; } pc2.executeBlock(); doCompare(mem, flags, pc1, pc2, initialState, opclass, disam, code, log); return true; } public static void doCompare(boolean mem, boolean compareFlags, PCHandle newpc, PCHandle oldpc, int[] input, String opclass, String disam, byte[] code, BufferedWriter log) throws Exception { compareStates(input, opclass, disam, code, newpc.getState(), oldpc.getState(), compareFlags); if (!mem) return; byte[] data1 = new byte[4096]; byte[] data2 = new byte[4096]; for (int i=0; i < 1024*1024; i++) { Integer l1 = newpc.savePage(new Integer(i), data1); Integer l2 = oldpc.savePage(new Integer(i), data2); if (l2 > 0) if (!comparePage(i, data1, data2, log)) printAllStates(code, input, newpc.getState(), oldpc.getState(), opclass, disam); } } public static boolean comparePage(int index, byte[] fast, byte[] old, BufferedWriter log) throws IOException { if (fast.length != old.length) throw new IllegalStateException(String.format("different page data lengths %d != %d", fast.length, old.length)); for (int i=0; i < fast.length; i++) if (fast[i] != old[i]) { log.write(String.format("Difference in memory state: %08x=> %02x - %02x\n", index*4096+i, fast[i], old[i])); return false; } return true; } public static String[] names = EmulatorControl.names; public static void printState(int[] state, BufferedWriter out) throws IOException { StringBuilder builder = new StringBuilder(4096); Formatter formatter=new Formatter(builder); arrayImpl(names, state, formatter, 0, 10); arrayImpl(names, state, formatter, 10, 17); arrayImpl(names, state, formatter, 17, 24); arrayImpl(names, state, formatter, 24, 30); arrayImpl(names, state, formatter, 30, 37); arrayImpl(names, state, formatter, 37, 45); arrayImpl(names, state, formatter, 45, names.length); doubleImpl(names, state, formatter, 37, 37 + 16); out.flush(); out.write(builder.toString()); out.newLine(); } public static void printState(int[] state) { StringBuilder builder = new StringBuilder(4096); Formatter formatter=new Formatter(builder); arrayImpl(names, state, formatter, 0, 10); arrayImpl(names, state, formatter, 10, 17); arrayImpl(names, state, formatter, 17, 24); arrayImpl(names, state, formatter, 24, 30); arrayImpl(names, state, formatter, 30, 37); arrayImpl(names, state, formatter, 37, 45); arrayImpl(names, state, formatter, 45, names.length); doubleImpl(names, state, formatter, 37, 37+16); System.out.flush(); System.out.println(builder); } public static void printAllStates(byte[] code, int[] input, int[] fast, int[] old, String opclass, String disam) { System.out.print("**" + disam + " == " + opclass + " =: "); for (int i=0; i < code.length; i++) System.out.printf("%02x ", code[i]); System.out.println(); System.out.println("Input state:"); printState(input); System.out.println("New JPC state:"); printState(fast); System.out.println("Old JPC state:"); printState(old); } public static void doubleImpl(String[] names, int[] vals, Formatter f, int start, int end) { for (int i=start; i < end; i+=2) f.format("[%8s] ", "ST"+(i-start)/2); f.format("\n"); for (int i=start; i < end; i+=2) f.format("[%f] ", Double.longBitsToDouble((vals[i]&0xffffffffL) << 32 | (vals[i+1]&0xffffffffL))); f.format("\n"); } public static void arrayImpl(String[] names, int[] vals, Formatter f, int start, int end) { for (int i=start; i < end; i++) f.format("[%8s] ", names[i]); f.format("\n"); for (int i=start; i < end; i++) f.format("[%8X] ", vals[i]); f.format("\n"); } public static void compareStates(int[] input, String opclass, String disam, byte[] code, int[] fast, int[] old, boolean compareFlags) throws Exception { if (old.length != fast.length) throw new IllegalArgumentException("old state length = "+old.length+", new state length = "+fast.length); StringBuilder b = new StringBuilder(); for (int i=0; i < fast.length; i++) if (i != 9) { if (fast[i] != old[i]) { b.append(String.format("Difference: %d=%s %08x - %08x\n", i, names[i], fast[i], old[i])); //continueExecution(); } } else { if (compareFlags && ((fast[i] & FLAG_MASK) != (old[i] & FLAG_MASK))) { b.append(String.format("Difference: %d=%s %08x - %08x\n", i, names[i], fast[i], old[i])); //continueExecution(); } } if (b.length() > 0) { printAllStates(code, input, fast, old, opclass, disam); System.out.println(b.toString()); } } public static void continueExecution() { System.out.println("Ignore difference? (y/n)"); String line = null; try { line = new BufferedReader(new InputStreamReader(System.in)).readLine(); } catch (IOException f) { f.printStackTrace(); System.exit(0); } if (line.equals("y")) {} else System.exit(0); } public static final int FLAG_MASK = -1;//~0x10; public static final int gdtBase = 0xfb632; public static byte[] gdt = new byte[] { (byte)0x0, (byte)0x0, (byte)0x0, (byte)0x0, (byte)0x0, (byte)0x0, (byte)0x0, (byte)0x0, (byte)0x0, (byte)0x0, (byte)0x0, (byte)0x0, (byte)0x0, (byte)0x0, (byte)0x0, (byte)0x0, (byte)0xff, (byte)0xff, (byte)0x0, (byte)0x0, (byte)0x0, (byte)0x9b, (byte)0xcf, (byte)0x0, (byte)0xff, (byte)0xff, (byte)0x0, (byte)0x0, (byte)0x0, (byte)0x93, (byte)0xcf, (byte)0x0, (byte)0xff, (byte)0xff, (byte)0x0, (byte)0x0, (byte)0x0f, (byte)0x9b, (byte)0x0, (byte)0x0, (byte)0xff, (byte)0xff, (byte)0x0, (byte)0x0, (byte)0x0, (byte)0x93, (byte)0x0, (byte)0x0 }; public static byte[] lgdt = new byte[]{(byte)0x2e, (byte)0x0f, (byte)0x01, (byte)0x16, (byte)0x2c, (byte)0xb6}; public static int testEip = 0; public static int testCS = 0; public static Calendar start = Calendar.getInstance(); public static class PCHandle { final Object pc; final Method state, setState, executeBlock, savePage, setCode; public PCHandle(ClassLoader cl1, boolean isNew, String[] args) throws Exception { Class c1 = cl1.loadClass("org.jpc.emulator.PC"); Constructor ctor = c1.getConstructor(String[].class, Calendar.class); pc = ctor.newInstance((Object)args, start); Method m1 = c1.getMethod("hello"); m1.invoke(pc); state = c1.getMethod("getState"); setState = c1.getMethod("setState", int[].class); executeBlock = c1.getMethod("executeBlock"); savePage = c1.getMethod("getPhysicalPage", Integer.class, byte[].class); setCode = c1.getMethod("setCode", byte[].class); } public void setPM(boolean pm) throws Exception { byte[] setcr0 = new byte[] {(byte)0x0f, (byte)0x22, (byte)0xc0}; if (pm) { // setup gdt data /*int[] regs0 = new int[16]; regs0[10] = 0xf000;// set cs regs0[8] = 0xb632; // cheat and set eip to where the gdt will be to load gdt setState.invoke(pc, regs0); setCode(gdt); // refers to BIOS - data is already there*/ // lgdt int[] regs0 = new int[16]; regs0[10] = 0xf000;// set cs regs0[8] = 0xb599; // set eip to point to lgdt setState.invoke(pc, regs0); //setCode(lgdt); executeBlock(); // relies on single instruction length block // set cr0 executeBlock(); executeBlock(); executeBlock(); // far jump executeBlock(); // load other segments executeBlock();// mov eax, Iz executeBlock();//ds executeBlock();//es executeBlock();//ss // load test eip int[] newregs = getState(); testEip = newregs[8]; testCS = newregs[10]; // load cr0 /*int[] regs = new int[16]; regs[0] = 0x60000011; // new cr0 value is in eax regs0[10] = 0xf000;// set cs setState.invoke(pc, regs); setCode(setcr0); executeBlock();*/ } else { int[] regs = new int[16]; regs[0] = 0x60000010; setState.invoke(pc, regs); setCode(setcr0); executeBlock(); } } public void setCode(byte[] code) throws Exception { setCode.invoke(pc, (Object)code); } public int[] getState() throws Exception { return (int[]) state.invoke(pc); } public void setState(int[] s) throws Exception { setState.invoke(pc, s); } public void executeBlock() throws Exception { executeBlock.invoke(pc); } public Integer savePage(Integer page, byte[] buf) throws Exception { return (Integer)savePage.invoke(pc, page, (Object) buf); } } public static class TestParser extends DefaultHandler { final String mode; final PCHandle pc1, pc2; final boolean mem, flags; final Set<String> unimplemented = new HashSet<String>(); byte[] currentCode; String currentClass; String currentDisam; enum Type {None, Class, Code, Disam, Input} Type type; int opcodeCount=0, testCount=0; BufferedWriter log; public TestParser(String mode, PCHandle pc1, PCHandle pc2, boolean mem, boolean flags) { this.mode = mode; this.pc1 = pc1; this.pc2 = pc2; this.mem = mem; this.flags = flags; try { this.log = new BufferedWriter(new FileWriter("Fuzz_" + mode + (mem ? "mem" : ""))); } catch (IOException e) { e.printStackTrace(); } } public void startElement(String uri, String localName,String qName, Attributes attributes) throws SAXException { if (qName.equals("class")) type = Type.Class; else if (qName.equals("disam")) type = Type.Disam; else if (qName.equals("code")) type = Type.Code; else if (qName.equals("input")) type = Type.Input; } public void characters(char ch[], int start, int length) throws SAXException { if (type == Type.Class) currentClass = new String(ch, start, length); else if (type == Type.Disam) { currentDisam = new String(ch, start, length); System.out.println("Starting fuzz of "+currentDisam); } else if (type == Type.Code) { String[] codeArr = new String(ch, start, length).trim().split(" "); currentCode = new byte[codeArr.length]; for (int i=0; i < codeArr.length; i++) currentCode[i] = (byte)Integer.parseInt(codeArr[i], 16); } else if (type == Type.Input) { if (unimplemented.contains(currentClass)) return; String[] inputArr = new String(ch, start, length).trim().split(" "); int[] input = new int[names.length]; for (int i=0; i < inputArr.length; i++) input[i] = Integer.parseInt(inputArr[i], 16); // set eip // eip will be 0 which is fine //input[8] = testEip; input[10] = testCS; input[11] = 0x18; // ds input[12] = 0x18; // es input[15] = 0x18; // ss // now do the test case try { if (!executeCase(currentClass, currentDisam, currentCode, input, pc1, pc2, mem, flags, log)) unimplemented.add(currentClass); } catch (Exception e) {e.printStackTrace();} testCount++; if (testCount % 10000 == 0) System.out.printf("Completed %d test cases from %d opcodes in %s\n", testCount, opcodeCount, mode); } } public void endElement(String uri, String localName, String qName) throws SAXException { type = Type.None; } } }