/* * ExpressionFactory.java - This file is part of the Jakstab project. * Copyright 2007-2015 Johannes Kinder <jk@jakstab.org> * * This code is free software; you can redistribute it and/or modify it * under the terms of the GNU General Public License version 2 only, as * published by the Free Software Foundation. * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * * You should have received a copy of the GNU General Public License version * 2 along with this work; if not, see <http://www.gnu.org/licenses/>. */ package org.jakstab.rtl.expressions; import java.util.*; import org.jakstab.util.Logger; import org.jakstab.asm.*; import org.jakstab.asm.x86.*; import com.google.common.collect.HashMultimap; import com.google.common.collect.SetMultimap; /** * Factory class for all RTL expressions. It is a singleton that holds the references to variables * and commonly used constants. * * @author Johannes Kinder */ public final class ExpressionFactory { private static final Logger logger = Logger.getLogger(ExpressionFactory.class); // This should be a multiple of 64 for use as bitset size public static final int DEFAULT_VARIABLE_COUNT = 128; // Initialize constants // Setting TRUE to -1 yields TRUE = ~FALSE, which makes life easier public static final RTLNumber TRUE = new RTLNumber(-1, 1); public static final RTLNumber FALSE = new RTLNumber(0, 1); public static final RTLVariable pc; public static final RTLVariable SKIP; public static final RTLVariable REPEAT; private static int uniqueVariableCount = 0; private static final Map<String, RTLVariable> variableInstances; private static ArrayList<RTLVariable> variableArray; private static final RTLNondet[] nondetArray; private static final Map<RTLVariable, RTLBitRange> sharedRegisterMap; private static final SetMultimap<RTLVariable, RTLVariable> coveredRegs; private static final SetMultimap<RTLVariable, RTLVariable> coveredBy; static { uniqueVariableCount = 0; variableInstances = new HashMap<String, RTLVariable>(DEFAULT_VARIABLE_COUNT); variableArray = new ArrayList<RTLVariable>(DEFAULT_VARIABLE_COUNT); nondetArray = new RTLNondet[128]; sharedRegisterMap = new HashMap<RTLVariable, RTLBitRange>(); coveredRegs = HashMultimap.create(); coveredBy = HashMultimap.create(); pc = createVariable("%pc", 32); SKIP = createVariable("%SKIP", 1); REPEAT = createVariable("%RPT", 1); } private ExpressionFactory() { } public static RTLBitRange createBitRange(RTLExpression operand, RTLExpression firstBit, RTLExpression lastBit) { return new RTLBitRange(operand, firstBit, lastBit); } public static RTLConditionalExpression createConditionalExpression( RTLExpression condition, RTLExpression trueExpression, RTLExpression falseExpression) { // Invert not-expressions if (condition instanceof RTLOperation && ((RTLOperation)condition).getOperator().equals(Operator.NOT)) { //logger.debug("Inverting negated conditional expression."); return createConditionalExpression(((RTLOperation)condition).getOperands()[0], falseExpression, trueExpression); } return new RTLConditionalExpression(condition, trueExpression, falseExpression); } public static RTLExpression createImplication(RTLExpression a, RTLExpression b) { return createOr(createNot(a), b); } public static RTLMemoryLocation createMemoryLocation(RTLExpression address, int bitWidth) { return createMemoryLocation(0, null, address, bitWidth); } public static RTLMemoryLocation createMemoryLocation(int memoryState, RTLExpression address, int bitWidth) { return createMemoryLocation(memoryState, null, address, bitWidth); } public static RTLMemoryLocation createMemoryLocation(RTLExpression segmentRegister, RTLExpression address, int bitWidth) { return createMemoryLocation(0, segmentRegister, address, bitWidth); } public static RTLMemoryLocation createMemoryLocation(int memoryState, RTLExpression segmentRegister, RTLExpression address, int bitWidth) { assert bitWidth > 0 : "Trying to create memory location of unknown width with address " + address + "!"; return new RTLMemoryLocation(memoryState, segmentRegister, address, bitWidth); } public static RTLNumber createNumber(Number value) { int bitWidth = 64; if (value instanceof Long) bitWidth = 64; else if (value instanceof Integer) bitWidth = 32; else if (value instanceof Short) bitWidth = 16; else if (value instanceof Byte) bitWidth = 8; return new RTLNumber(value.longValue(), bitWidth); } public static RTLNumber createNumber(long value, int bitWidth) { if (bitWidth == 1) { if (value == 0) return FALSE; else return TRUE; } return new RTLNumber(value, bitWidth); } public static RTLNumber createNumber(AbsoluteAddress addr) { return new RTLNumber(addr.getValue(), addr.getBitWidth()); } /** * Generic creation method that calls more specific methods depending on the * type of the assembly operand passed as parameter. * * @param iOp an operand of an assembly instruction * @return a translation of the operand into an RTLExpression */ public static RTLExpression createOperand(Operand iOp) { RTLExpression opAsExpr = null; if (iOp instanceof Immediate) { opAsExpr = createNumber(((Immediate)iOp).getNumber()); } else if (iOp instanceof Register) { opAsExpr = createRegister((Register)iOp); } else if (iOp instanceof MemoryOperand) { opAsExpr = createMemoryLocation((MemoryOperand)iOp); } else if (iOp instanceof AbsoluteAddress) { opAsExpr = createAddress((AbsoluteAddress)iOp); } else if (iOp instanceof PCRelativeAddress) { opAsExpr = createAddress((PCRelativeAddress)iOp); } else { if (iOp == null) logger.warn("Null operand in RTL conversion!"); else logger.warn("Unsupported operand type: " + iOp.getClass().getSimpleName()); } return opAsExpr; } public static RTLExpression createPlus(RTLExpression... operands) { return createOperation(Operator.PLUS, operands); } public static RTLExpression createPlus(RTLExpression op1, long op2) { return createPlus(op1, createNumber(op2, op1.getBitWidth())); } public static RTLExpression createMinus(RTLExpression op1, RTLExpression op2) { return createPlus(op1, createNeg(op2)); } public static RTLExpression createMultiply(RTLExpression... operands) { return createOperation(Operator.MUL, operands); } public static RTLExpression createFloatMultiply(RTLExpression... operands) { return createOperation(Operator.FMUL, operands); } public static RTLExpression createFloatDivide(RTLExpression op1, RTLExpression op2) { return createOperation(Operator.FDIV, op1, op2); } public static RTLExpression createDivide(RTLExpression op1, RTLExpression op2) { return createOperation(Operator.DIV, op1, op2); } public static RTLExpression createPowerOf(RTLExpression op1, RTLExpression op2) { return createOperation(Operator.POWER_OF, op1, op2); } public static RTLExpression createModulo(RTLExpression... operands) { return createOperation(Operator.MOD, operands); } public static RTLExpression createEqual(RTLExpression op1, RTLExpression op2) { return createOperation(Operator.EQUAL, op1, op2); } public static RTLExpression createNotEqual(RTLExpression op1, RTLExpression op2) { return createNot(createEqual(op1, op2)); } public static RTLExpression createLessThan(RTLExpression op1, RTLExpression op2) { return createOperation(Operator.LESS, op1, op2); } public static RTLExpression createLessOrEqual(RTLExpression op1, RTLExpression op2) { return createOperation(Operator.LESS_OR_EQUAL, op1, op2); } public static RTLExpression createGreaterThan(RTLExpression op1, RTLExpression op2) { return createLessThan(op2, op1); } public static RTLExpression createGreaterOrEqual(RTLExpression op1, RTLExpression op2) { //return createLessOrEqual(op2, op1); return createNot(createLessThan(op1, op2)); } public static RTLExpression createUnsignedLessThan(RTLExpression op1, RTLExpression op2) { return createOperation(Operator.UNSIGNED_LESS, op1, op2); } public static RTLExpression createUnsignedLessOrEqual(RTLExpression op1, RTLExpression op2) { return createOperation(Operator.UNSIGNED_LESS_OR_EQUAL, op1, op2); } public static RTLExpression createUnsignedGreaterThan(RTLExpression op1, RTLExpression op2) { return createUnsignedLessThan(op2, op1); } public static RTLExpression createUnsignedGreaterOrEqual(RTLExpression op1, RTLExpression op2) { return createNot(createUnsignedLessThan(op1, op2)); } public static RTLExpression createAnd(RTLExpression op1, RTLExpression op2) { return createOperation(Operator.AND, op1, op2); } public static RTLExpression createOr(RTLExpression op1, RTLExpression op2) { return createOperation(Operator.OR, op1, op2); } public static RTLExpression createXor(RTLExpression op1, RTLExpression op2) { return createOperation(Operator.XOR, op1, op2); } public static RTLExpression createNot(RTLExpression op) { return createOperation(Operator.NOT, op); } public static RTLExpression createNeg(RTLExpression op) { return createOperation(Operator.NEG, op); } public static RTLExpression createShiftRight(RTLExpression op1, RTLExpression op2) { return createOperation(Operator.SHR, op1, op2); } public static RTLExpression createShiftArithmeticRight(RTLExpression op1, RTLExpression op2) { return createOperation(Operator.SAR, op1, op2); } public static RTLExpression createShiftLeft(RTLExpression op1, RTLExpression op2) { return createOperation(Operator.SHL, op1, op2); } public static RTLExpression createRotateLeft(RTLExpression op1, RTLExpression op2) { return createOperation(Operator.ROL, op1, op2); } public static RTLExpression createRotateRight(RTLExpression op1, RTLExpression op2) { return createOperation(Operator.ROR, op1, op2); } public static RTLExpression createRotateLeftWithCarry(RTLExpression op1, RTLExpression op2) { return createOperation(Operator.ROLC, op1, op2); } public static RTLExpression createRotateRightWithCarry(RTLExpression op1, RTLExpression op2) { return createOperation(Operator.RORC, op1, op2); } public static RTLExpression createCast(RTLExpression op, RTLNumber bitWidth) { return createOperation(Operator.CAST, op, bitWidth); } public static RTLExpression createSignExtend(int from, int to, RTLExpression op) { return createSignExtend(createNumber(from, 8), createNumber(to, 8), op); } public static RTLExpression createSignExtend(RTLExpression from, RTLExpression to, RTLExpression op) { return createOperation(Operator.SIGN_EXTEND, from, to, op); } public static RTLExpression createZeroFill(int from, int to, RTLExpression op) { return createZeroFill(createNumber(from, 8), createNumber(to, 8), op); } public static RTLExpression createZeroFill(RTLExpression from, RTLExpression to, RTLExpression op) { return createOperation(Operator.ZERO_FILL, from, to, op); } public static RTLExpression createFloatResize(RTLExpression toBits, RTLExpression fromBits, RTLExpression op) { //fromBits = createNumber(((RTLNumber)fromBits).intValue(), 8); //toBits = createNumber(((RTLNumber)toBits).intValue(), 8); return createOperation(Operator.FSIZE, toBits, fromBits, op); } public static RTLExpression createOperation(Operator operator, RTLExpression... operands) { switch (operator) { // Handle nested operators for commutative associative operations case PLUS: case AND: case OR: for (int i=0; i<operands.length; i++) { // Combine associative operands of same type if (operands[i] instanceof RTLOperation) { RTLOperation subOp = ((RTLOperation)operands[i]); if (subOp.getOperator() == operator) { RTLExpression[] newOps = new RTLExpression[operands.length - 1 + subOp.getOperandCount()]; System.arraycopy(operands, 0, newOps, 0, i); System.arraycopy(operands, i + 1, newOps, i, operands.length - i - 1); System.arraycopy(subOp.getOperands(), 0, newOps, operands.length - 1, subOp.getOperandCount()); return createOperation(operator, newOps); } } } break; // Cancel double negation/not case NOT: case NEG: if (operands[0] instanceof RTLOperation) { RTLOperation nestedOp = (RTLOperation)operands[0]; if (nestedOp.getOperator().equals(operator)) { return nestedOp.getOperands()[0]; } if (operator == Operator.NEG && nestedOp.getOperator() == Operator.PLUS) { RTLExpression[] newNestedOperands = new RTLExpression[nestedOp.getOperandCount()]; for (int i=0; i<nestedOp.getOperandCount(); i++) { newNestedOperands[i] = createNeg(nestedOp.getOperands()[i]); } return createOperation(nestedOp.getOperator(), newNestedOperands); } } break; default: // nothing } return new RTLOperation(operator, operands); } public static RTLExpression createSpecialExpression( String operation, RTLExpression... operands) { // Load effective address (lea) if (operation.equals("addr")) { if (operands[0] instanceof RTLMemoryLocation) { return ((RTLMemoryLocation)operands[0]).getAddress(); } } else if (operation.equals("nondet")) { return nondet(((RTLNumber)operands[0]).intValue()); } return new RTLSpecialExpression(operation, operands); } public static RTLVariable createVariable(String name, int bitWidth) { // Remove leading %-signs on registers to avoid confusion of users if (name.charAt(0) == '%') name = name.substring(1); RTLVariable var; if (variableInstances.containsKey(name)) { var = variableInstances.get(name); if (!(bitWidth == RTLVariable.UNKNOWN_BITWIDTH || var.getBitWidth() == bitWidth)) { if (!(name.startsWith("reg") || name.startsWith("modrm") || name.startsWith("i") || name.startsWith("sti"))) { logger.error(var + " exists with width " + var.getBitWidth() + "! Cannot make it width " + bitWidth + "!"); assert(false); } } } else { var = new RTLVariable(uniqueVariableCount, name, bitWidth); variableArray.add(var); assert(variableArray.get(var.getIndex()) == var) : "Something's wrong with variable caching!"; variableInstances.put(name, var); uniqueVariableCount++; assert uniqueVariableCount < DEFAULT_VARIABLE_COUNT : "Too many variables!"; } return var; } private static void addCoveringRegister(RTLVariable var, RTLVariable parent) { for (RTLVariable ancestor : coveringRegisters(parent)) { addCoveringRegister(var, ancestor); } coveredBy.put(var, parent); } private static void addCoveredRegister(RTLVariable var, RTLVariable child) { for (RTLVariable ancestor : coveringRegisters(var)) { addCoveredRegister(ancestor, child); } coveredRegs.put(var, child); } public static RTLVariable createSharedRegisterVariable(String name, String parentName, int startBit, int endBit) { RTLVariable var = createVariable(name, endBit - startBit + 1); RTLVariable parent = createVariable(parentName); RTLBitRange expr = createBitRange(parent, createNumber(startBit, 8), createNumber(endBit, 8)); sharedRegisterMap.put(var, expr); addCoveringRegister(var, parent); addCoveredRegister(parent, var); return var; } public static RTLBitRange getRegisterAsParent(RTLVariable var) { return sharedRegisterMap.get(var); } public static Set<RTLVariable> coveredRegisters(RTLVariable var) { return coveredRegs.get(var); } public static Set<RTLVariable> coveringRegisters(RTLVariable var) { return coveredBy.get(var); } public static Writable createRegisterVariable(String name, int bitWidth) { // Use explicit 16 and 8 bit registers now //if (sharedRegisterMap.containsKey(name)) return sharedRegisterMap.get(name); return createVariable(name, bitWidth); } public static RTLVariable createVariable(String name) { return createVariable(name, RTLVariable.UNKNOWN_BITWIDTH); } public static RTLVariable getVariable(int index) { return variableArray.get(index); } public static int getVariableCount() { return uniqueVariableCount; } /** * Returns an expression representing a nondeterministic value of the * given bit width. In Yices translation, each occurrence of a nondeterministic * expression is converted to a fresh variable. * * @param bitWidth * @return a nondeterministic RTLExpression. */ public static RTLExpression nondet(int bitWidth) { if (nondetArray[bitWidth - 1] == null) nondetArray[bitWidth - 1] = new RTLNondet(bitWidth); return nondetArray[bitWidth - 1]; } private static RTLExpression createAddress(AbsoluteAddress asmAddress) { RTLExpression addressExpression; addressExpression = createNumber(asmAddress.getValue(), asmAddress.getBitWidth()); return addressExpression; } private static RTLExpression createAddress(PCRelativeAddress asmAddress) { RTLExpression addressExpression; addressExpression = createNumber(asmAddress.getDisplacement(), asmAddress.getBitWidth()); return addressExpression; } private static RTLMemoryLocation createMemoryLocation(MemoryOperand asmMemOp) { RTLExpression segmentRegister = null; if (asmMemOp instanceof X86MemoryOperand) { X86SegmentRegister segReg = ((X86MemoryOperand)asmMemOp).getSegmentRegister(); if (segReg != null) segmentRegister = createRegister(segReg); } RTLExpression addressExpr = null; if (asmMemOp.getBase() != null) addressExpr = createRegister(asmMemOp.getBase()); if (asmMemOp.getIndex() != null) { RTLExpression indexScale = createRegister(asmMemOp.getIndex()); if (asmMemOp.getScale() > 1) { indexScale = createBitRange( createOperation(Operator.MUL, indexScale, createNumber(asmMemOp.getScale(), 32) ), createNumber(0,8), createNumber(31,8)); } if (addressExpr == null) addressExpr = indexScale; else addressExpr = createOperation(Operator.PLUS, addressExpr, indexScale); } if (asmMemOp.getDisplacement() != 0 || addressExpr == null) { RTLExpression disp = createNumber(asmMemOp.getDisplacement(), 32); if (addressExpr == null) addressExpr = disp; else addressExpr = createOperation(Operator.PLUS, addressExpr, disp); } assert (addressExpr != null) : "Address expression is null!"; int bitWidth = asmMemOp.getDataType().bits(); return createMemoryLocation(segmentRegister, addressExpr, bitWidth); } private static RTLExpression createRegister(Register asmRegister) { return createRegisterVariable(asmRegister.toString(), RTLVariable.UNKNOWN_BITWIDTH); } }