/* * IntervalElement.java - This file is part of the Jakstab project. * Copyright 2007-2015 Johannes Kinder <jk@jakstab.org> * * This code is free software; you can redistribute it and/or modify it * under the terms of the GNU General Public License version 2 only, as * published by the Free Software Foundation. * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * * You should have received a copy of the GNU General Public License version * 2 along with this work; if not, see <http://www.gnu.org/licenses/>. */ package org.jakstab.analysis.intervals; import java.util.*; import org.jakstab.analysis.*; import org.jakstab.rtl.BitVectorType; import org.jakstab.rtl.expressions.ExpressionFactory; import org.jakstab.rtl.expressions.RTLNumber; import org.jakstab.util.Characters; import org.jakstab.util.FastSet; import org.jakstab.util.Logger; /** * A reduced signed (twos complement) bitvector interval element with region and stride information. * An IntervalElement represents the numbers * {region:(left + k*stride) | k \in [0 .. (right-left)/stride]} * For intervals of size 1, the stride is 0 by definition. The interval is reduced, i.e. its right * limit is included (the size of the interval is a multiple of its stride). * All values are represented as signed long integers, with left < right. * * @author Johannes Kinder */ public class IntervalElement implements AbstractDomainElement, BitVectorType, Iterable<Long> { private static final Logger logger = Logger.getLogger(IntervalElement.class); private static IntervalElement TOP1 = new IntervalElement(MemoryRegion.GLOBAL, -1, 0, 1, 1); private static IntervalElement TOP8 = new IntervalElement(MemoryRegion.GLOBAL, Byte.MIN_VALUE, Byte.MAX_VALUE, 1, 8); private static IntervalElement TOP16 = new IntervalElement(MemoryRegion.GLOBAL, Short.MIN_VALUE, Short.MAX_VALUE, 1, 16); private static IntervalElement TOP32 = new IntervalElement(MemoryRegion.TOP, Integer.MIN_VALUE, Integer.MAX_VALUE, 1, 32); private static IntervalElement TOP64 = new IntervalElement(MemoryRegion.TOP, Long.MIN_VALUE, Long.MAX_VALUE, 1, 64); // Not quite right, but should not matter private static IntervalElement TOP80 = new IntervalElement(MemoryRegion.TOP, Long.MIN_VALUE, Long.MAX_VALUE, 1, 80); private static final int MAX_CONCRETIZATION_SIZE = 100; public static IntervalElement TRUE = new IntervalElement(ExpressionFactory.TRUE); public static IntervalElement FALSE = new IntervalElement(ExpressionFactory.FALSE); public static IntervalElement getTop(int bitWidth) { switch (bitWidth) { case 1: return TOP1; case 8: return TOP8; case 16: return TOP16; case 32: return TOP32; case 64: return TOP64; case 80: return TOP80; default: throw new RuntimeException("No top interval element with bitwidth " + bitWidth); } } private final long left; private final long right; private final int bitWidth; private final long stride; private final MemoryRegion region; public IntervalElement(RTLNumber number) { this(number, number); } public IntervalElement(RTLNumber startNumber, RTLNumber endNumber) { this(MemoryRegion.GLOBAL, startNumber, endNumber); } public IntervalElement(MemoryRegion region, RTLNumber startNumber, RTLNumber endNumber) { this(region, startNumber.longValue(), endNumber.longValue(), 1, startNumber.getBitWidth()); assert startNumber.getBitWidth() == endNumber.getBitWidth(); } public IntervalElement(MemoryRegion region, RTLNumber number) { this(region, number.longValue(), number.longValue(), 0, number.getBitWidth()); } public IntervalElement(MemoryRegion region, long left, long right, long stride, int bitWidth) { this.region = region; this.left = left; this.right = right; // For single-value intervals, set stride to zero if (right - left == 0) { this.stride = 0; } else { this.stride = stride; } this.bitWidth = bitWidth; assert stride != 0 || right == left : "Stride 0 for interval of size > 1!"; assert stride == 0 || (right - left) % stride == 0 : "Stride " + stride + " does not fit with interval bounds: " + left + ";" + right + "!"; } public long getLeft() { return left; } public long getRight() { return right; } public long size() { if (stride == 0) return 1; return (right - left)/stride + 1; } public MemoryRegion getRegion() { return region; } @Override public int getBitWidth() { return bitWidth; } public long getStride() { return stride; } /** * Widening by extending the interval bounds infinitely into the direction of * the parameter. Takes care to preserve singleton-property of top elements. * * @param towards the target towards which this element should be widened. * @return the result of the widening, which is greater than this and the * parameter interval. */ public IntervalElement widen(IntervalElement towards) { assert towards.bitWidth == bitWidth; IntervalElement result; IntervalElement top = getTop(bitWidth); if (region != towards.region) return top; long newStride = joinStride(towards); if (towards.left < left) { if (towards.right > right || rightOpen()) { result = top; } else { result = new IntervalElement(region, top.getLeft() + (right - top.getLeft()) % newStride, right, newStride, bitWidth); } } else { if (towards.right > right) { if (leftOpen()) { result = top; } else { result = new IntervalElement(region, left, top.getRight() - (top.getRight() - left) % newStride, newStride, bitWidth); } } else { if (newStride > stride) result = new IntervalElement(region, left, right, newStride, bitWidth); else result = this; } } // if (this != result) logger.debug("Widening " + this + " to " + result); return result; } /* * @see org.jakstab.analysis.AbstractValue#concretize() */ @Override public Set<RTLNumber> concretize() { // magic max size for jump tables if (getRegion() != MemoryRegion.GLOBAL || size() > MAX_CONCRETIZATION_SIZE) { return RTLNumber.ALL_NUMBERS; } Set<RTLNumber> result = new FastSet<RTLNumber>(); if (stride == 0) { result.add(ExpressionFactory.createNumber(left, bitWidth)); } else { for (long v = left; v <= right; v+=stride) { result.add(ExpressionFactory.createNumber(v, bitWidth)); } } return result; } /* * @see org.jakstab.analysis.AbstractValue#join(org.jakstab.analysis.LatticeElement) */ @Override public IntervalElement join(LatticeElement lt) { IntervalElement other = (IntervalElement)lt; assert bitWidth == other.bitWidth; if (isTop() || other.isBot()) return this; if (isBot() || other.isTop()) return other; if (this.region != other.region) return getTop(bitWidth); if (this.leftOpen() && other.rightOpen()) return getTop(bitWidth); if (other.leftOpen() && this.rightOpen()) return getTop(bitWidth); long newStride = joinStride(other); long l = Math.min(this.left, other.left); long r = Math.max(this.right, other.right); return new IntervalElement(this.region, l, r, newStride, bitWidth); } private long joinStride(IntervalElement other) { long newStride; // If both intervals were size 1, set stride to difference if (this.stride == 0 && other.stride == 0) newStride = Math.abs(other.left - this.left); else { newStride = gcdStride(stride, other.stride); newStride = gcdStride(newStride, Math.abs(left - other.left)); } return newStride; } /* * @see org.jakstab.analysis.LatticeElement#isBot() */ @Override public boolean isBot() { return false; } /* * @see org.jakstab.analysis.LatticeElement#isTop() */ @Override public boolean isTop() { //return this == getTop(bitWidth); return getTop(bitWidth).equals(this); } /* * @see org.jakstab.analysis.LatticeElement#lessOrEqual(org.jakstab.analysis.LatticeElement) */ @Override public boolean lessOrEqual(LatticeElement l) { if (isBot()) return true; IntervalElement other = (IntervalElement)l; assert bitWidth == other.bitWidth; return other.left <= this.left && other.right >= this.left && other.stride <= this.stride; } /* * @see java.lang.Object#toString() */ @Override public String toString() { if (isTop()) return Characters.TOP; StringBuilder s = new StringBuilder(); if (region != MemoryRegion.GLOBAL) s.append(region).append(":"); s.append(stride); s.append('['); if (leftOpen()) { s.append(Characters.TOP); } else { s.append(left); } if (size() == 1) { s.append(']'); } else { s.append(';'); if (rightOpen()) { s.append(Characters.TOP); } else { s.append(right); } s.append(']'); } return s.toString(); } /* * @see java.lang.Object#hashCode() */ @Override public int hashCode() { final int prime = 31; int result = 1; result = prime * result + bitWidth; result = prime * result + (int)stride; result = prime * result + (int) (left ^ (left >>> 32)); result = prime * result + ((region == null) ? 0 : region.hashCode()); result = prime * result + (int) (right ^ (right >>> 32)); return result; } /* * @see java.lang.Object#equals(java.lang.Object) */ @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null) return false; if (getClass() != obj.getClass()) return false; IntervalElement other = (IntervalElement) obj; if (bitWidth != other.bitWidth) return false; if (stride != other.stride) return false; if (left != other.left) return false; if (region == null) { if (other.region != null) return false; } else if (!region.equals(other.region)) return false; if (right != other.right) return false; return true; } public boolean leftOpen() { return left == getTop(bitWidth).getLeft(); } public boolean rightOpen() { return right == getTop(bitWidth).getRight(); } /** * Addition of two strided interval elements. Code adapted from Gogul Balakrishnan's thesis. * * @param op the interval to add to this interval * @return A new interval that is the sum of this and the given interval */ @Override public IntervalElement plus(AbstractDomainElement other) { IntervalElement op = (IntervalElement)other; assert bitWidth == op.bitWidth; MemoryRegion newRegion = region.join(op.region); if (newRegion.isTop()) return getTop(bitWidth); long l = this.left + op.left; long r = this.right + op.right; long u = this.left & op.left & ~l & ~(this.right & op.right & ~r); long v = ((this.left ^ op.left) | ~(this.left ^ l)) & (~this.right & ~op.right & r); if ((u | v) < 0 || (u | v) > getTop(bitWidth).right) { return getTop(bitWidth); } return new IntervalElement(newRegion, l, r, gcdStride(stride, op.stride), bitWidth); } @Override public IntervalElement negate() { if (left == getTop(bitWidth).left) { // If this interval is just the minimum value, it's negation is the same value again in 2s complement. if (left == right) return this; return getTop(bitWidth); } else { return new IntervalElement(region, -right, -left, stride, bitWidth); } } private static long gcdStride(long s1, long s2) { if (s1 == 0) return s2; else if (s2 == 0) return s1; else return gcd(s1, s2); } private static long gcd(long a, long b) { long r; do { r = a % b; a = b; b = r; } while (r > 0); return a; } /** * Multiplies two interval elements. * 2[0;4] * 1[3;6] = 2[0;24] * 0,2,4 * 3,4,5,6 = 0,6,8,10,12,16,20,24 * * 3[0;9] * 0[6;6] = 18[0;54] * 0,3,6,9 * 5 = 0,18,36,54 * * 3[-3;6] * 2[4;8] = 6[-24;48] * -3,0,3,6 * 4,6,8 = -24,-18,-12,0,12,18,24,36,48 * @param op the other interval element to multiply this element with. * @return a new interval element which is the result of the multiplication. */ @Override public IntervalElement multiply(AbstractDomainElement other) { IntervalElement op = (IntervalElement)other; MemoryRegion newRegion = region.join(op.region); int newBitWidth = bitWidth * 2; IntervalElement top = getTop(newBitWidth); // Cannot multiply a pointer if (newRegion != MemoryRegion.GLOBAL) return top; // Try all combinations of mutltiplying the bounds // yeah, it's a hack, but it works. long[] b = new long[4]; b[0] = getLeft() * op.getLeft(); b[1] = getLeft() * op.getRight(); b[2] = getRight() * op.getLeft(); b[3] = getRight() * op.getRight(); Arrays.sort(b); if (b[0] <= top.getLeft() || b[3] >= top.getRight()) return top; long newStride; if (stride == 0) { if (op.stride == 0) newStride = 0; else { assert left == right; newStride = op.stride * Math.abs(left); } } else if (op.stride == 0) { assert op.left == op.right; newStride = Math.abs(op.left) * stride; } else { newStride = stride * op.stride; } return new IntervalElement(region, b[0], b[3], newStride, newBitWidth); } public IntervalElement bitExtract(int first, int last) { int newBitWidth = last - first + 1; if (region == MemoryRegion.GLOBAL) { // Check if operand is already within bit range. No bit range // extraction is actually performed. IntervalElement top = IntervalElement.getTop(newBitWidth); if (first == 0 && left >= top.getLeft() && right <= top.getRight()) { IntervalElement newElement = new IntervalElement(MemoryRegion.GLOBAL, left, right, stride, newBitWidth); return newElement; } } return IntervalElement.getTop(newBitWidth); } public IntervalElement signExtend(int first, int last) { int newBitWidth = Math.max(bitWidth, last + 1); if (region == MemoryRegion.GLOBAL) { // If region is global, return the value with new bit width return new IntervalElement(region, left, right, stride, newBitWidth); } return IntervalElement.getTop(newBitWidth); } public IntervalElement zeroFill(int first, int last) { int newBitWidth = Math.max(bitWidth, last + 1); if (region == MemoryRegion.GLOBAL && left >= 0 && right < (1 << first)) { // If value is non-negative and does not set any // bits that are zeroed out, it is unmodified return new IntervalElement(region, left, right, stride, newBitWidth); } return IntervalElement.getTop(newBitWidth); } @Override public Iterator<Long> iterator() { return new Iterator<Long>() { long cursor = left; long incr = stride == 0 ? 1 : stride; public void remove() { throw new UnsupportedOperationException(); } public Long next() { if (!hasNext()) throw new NoSuchElementException(); long r = cursor; cursor += incr; return r; } public boolean hasNext() { return cursor <= right; } }; } @Override public boolean hasUniqueConcretization() { return size() == 1; } @Override public AbstractDomainElement readStore(int bitWidth, PartitionedMemory<? extends AbstractDomainElement> store) { if (isTop()) return IntervalElement.getTop(bitWidth); long offset = getLeft(); AbstractDomainElement res = store.get(getRegion(), offset, bitWidth); if (getStride() > 0) { offset+=getStride(); for (;offset <= getRight(); offset += getStride()) { if (res.isTop()) { logger.info("Joined intervals to TOP while reading memory range " + this); return res; } AbstractDomainElement v = store.get(getRegion(), offset, bitWidth); res = res.join(v); } } return res; } @Override public Collection<? extends AbstractDomainElement> readStorePowerSet( int bitWidth, PartitionedMemory<? extends AbstractDomainElement> store) { if (isTop() || size() > MAX_CONCRETIZATION_SIZE) return Collections.singleton(IntervalElement.getTop(bitWidth)); Set<AbstractDomainElement> res = new FastSet<AbstractDomainElement>(); for (long offset : this) { res.add(store.get(getRegion(), offset, bitWidth)); } return res; } @Override public <A extends AbstractDomainElement> void writeStore(int bitWidth, PartitionedMemory<A> store, A value) { if (!getRegion().isSummary() && size() == 1) { // Strong update store.set(getRegion(), getLeft(), bitWidth, value); } else { // Weak update if (getRegion().isTop() || size() > 100) store.setTop(getRegion()); else { for (long i : this) { store.weakUpdate(getRegion(), i, bitWidth, value); } } } } }