/*
 * VpcTrackingAnalysis.java - This file is part of the Jakstab project.
 * Copyright 2007-2015 Johannes Kinder <jk@jakstab.org>
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, see <http://www.gnu.org/licenses/>.
 */
package org.jakstab.analysis.explicit;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.jakstab.AnalysisProperties;
import org.jakstab.JOption;
import org.jakstab.Program;
import org.jakstab.analysis.AbstractState;
import org.jakstab.analysis.CPAOperators;
import org.jakstab.analysis.ConfigurableProgramAnalysis;
import org.jakstab.analysis.MemoryReference;
import org.jakstab.analysis.MemoryRegion;
import org.jakstab.analysis.PartitionedMemory;
import org.jakstab.analysis.Precision;
import org.jakstab.analysis.ReachedSet;
import org.jakstab.analysis.ValueContainer;
import org.jakstab.asm.AbsoluteAddress;
import org.jakstab.cfa.CFAEdge;
import org.jakstab.cfa.Location;
import org.jakstab.cfa.StateTransformer;
import org.jakstab.rtl.expressions.RTLVariable;
import org.jakstab.rtl.statements.DefaultStatementVisitor;
import org.jakstab.rtl.statements.RTLAssume;
import org.jakstab.rtl.statements.RTLGoto;
import org.jakstab.rtl.statements.RTLStatement;
import org.jakstab.ssl.Architecture;
import org.jakstab.util.Logger;
import org.jakstab.util.Pair;
import org.jakstab.util.MapMap.EntryIterator;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.SetMultimap;

/**
 * VPC-sensitive version of Bounded Address Tracking.
 * <p>
 * The analysis keeps one {@link ExplicitPrecision} per value of a "virtual
 * program counter" (VPC). The VPC is a register or memory cell that is chosen
 * automatically in {@link #prec} once enough distinct values have been
 * observed for it. A minimal procedure analysis is inlined in {@link #post}
 * so that a VPC can be shared by all locations of one procedure.
 */
public class VpcTrackingAnalysis implements ConfigurableProgramAnalysis {

    private final static Logger logger = Logger.getLogger(VpcTrackingAnalysis.class);

    public static void register(AnalysisProperties p) {
        p.setShortHand('v');
        p.setName("BAT-VPC");
        p.setDescription("VPC-sensitive version of Bounded Address Tracking.");
        p.setExplicit(true);
    }

    // NOTE(review): this option is declared here but never read inside this
    // class; the VPC is picked from observed value counts instead.
    public static JOption<String> vpcName = JOption.create("vpc", "r", "esi", "Register to be used as virtual program counter.");

    /** If set, VPCs are stored per procedure head instead of per location. */
    private static final boolean PROC_SENSITIVE_VPC = true;

    /**
     * Register name that is never promoted to VPC. The original code
     * hard-coded this exclusion without stating a rationale.
     */
    private static final String VPC_EXCLUDED_REGISTER = "eax";

    /** Chosen VPC per location (or per procedure head if procedure sensitive). */
    private final Map<Location, ValueContainer> vpcMap;
    /** Maps the address of a location to the head location of its procedure. */
    private final Map<AbsoluteAddress, Location> procedureMap;
    private final Architecture arch;

    public VpcTrackingAnalysis() {
        vpcMap = new HashMap<Location, ValueContainer>();
        procedureMap = new HashMap<AbsoluteAddress, Location>();
        arch = Program.getProgram().getArchitecture();
    }

    /**
     * Returns {@code s1} if it already covers {@code s2}, otherwise defers to
     * {@link CPAOperators#mergeSep}.
     */
    @Override
    public AbstractState merge(AbstractState s1, AbstractState s2, Precision precision) {
        // Reduces states, but makes it harder to reconstruct the trace that lead to a certain state
        if (s2.lessOrEqual(s1)) {
            return s1;
        }
        return CPAOperators.mergeSep(s1, s2, precision);
    }

    @Override
    public boolean stop(AbstractState s, ReachedSet reached, Precision precision) {
        return CPAOperators.stopSep(s, reached, precision);
    }

    /**
     * Updates the procedure map for the given edge and then computes the
     * abstract successors of {@code state} under the explicit precision that
     * belongs to the current VPC value of the edge's target.
     */
    @Override
    public Set<AbstractState> post(AbstractState state, final CFAEdge cfaEdge, Precision precision) {
        if (state.isBot()) {
            return Collections.singleton(state);
        }

        BasedNumberValuation b = (BasedNumberValuation)state;
        VpcPrecision vprec = (VpcPrecision)precision;
        final RTLStatement stmt = (RTLStatement)cfaEdge.getTransformer();

        /* Do procedure analysis - for now, this is inlined here, should be made its own analysis */
        stmt.accept(new DefaultStatementVisitor<Void>() {

            /** Propagates the source's procedure to target unless target already has one. */
            private void copyOldProcToTarget(Location target) {
                if (getProcedure(target) != null) {
                    return;
                }
                Location oldProc = getProcedure(cfaEdge.getSource());
                if (oldProc != null) {
                    setProcedure(target, oldProc);
                }
            }

            @Override
            public Void visit(RTLAssume assume) {
                if (assume.isCall()) {
                    // start new procedure unless we're calling into the middle of an existing one
                    if (getProcedure(cfaEdge.getTarget()) == null) {
                        setProcedure(cfaEdge.getTarget(), cfaEdge.getTarget());
                    }
                    // Fall through edge in current proc
                    RTLGoto gotoStmt = assume.getSource();
                    if (gotoStmt.getNextLabel() != null) {
                        copyOldProcToTarget(gotoStmt.getNextLabel());
                    }
                } else if (assume.isReturn()) {
                    // do nothing
                } else {
                    // stay in same procedure
                    copyOldProcToTarget(cfaEdge.getTarget());
                }
                return null;
            }

            @Override
            protected Void visitDefault(RTLStatement other) {
                // includes call-return
                copyOldProcToTarget(cfaEdge.getTarget());
                return null;
            }
        });

        // Note: cfaEdge.getTarget().equals(vprec.getLocation()) will not hold if
        // --basicblocks is enabled - precision is reused there for several locs

        BasedNumberElement vpcValue = getVpcValue(b, getVpc(cfaEdge.getTarget()));
        ExplicitPrecision eprec = vprec.getPrecision(vpcValue);
        return b.abstractPost(stmt, eprec);
    }

    @Override
    public AbstractState strengthen(AbstractState s, Iterable<AbstractState> otherStates,
            CFAEdge cfaEdge, Precision precision) {
        return s;
    }

    /**
     * Returns the procedure head recorded for the address of {@code location},
     * or {@code null} if none has been recorded yet.
     */
    public Location getProcedure(Location location) {
        return procedureMap.get(location.getAddress());
    }

    private void setProcedure(Location location, Location procHead) {
        procedureMap.put(location.getAddress(), procHead);
    }

    /**
     * Returns the VPC chosen for {@code location}, or {@code null} if the
     * location's address belongs to no module or no VPC has been chosen.
     * With procedure sensitivity the lookup key is the location's procedure
     * head, which is {@code null} when no procedure is known.
     */
    public ValueContainer getVpc(Location location) {
        // No VPC for harness code
        if (Program.getProgram().getModule(location.getAddress()) == null) {
            return null;
        }
        Location key = PROC_SENSITIVE_VPC ? getProcedure(location) : location;
        return vpcMap.get(key);
    }

    private void setVpc(Location location, ValueContainer vpc) {
        // NOTE(review): when no procedure is known for the location the VPC is
        // stored under the null key and is then visible to every location
        // without a procedure - confirm that this sharing is intended.
        Location key = PROC_SENSITIVE_VPC ? getProcedure(location) : location;
        vpcMap.put(key, vpc);
    }

    /** Returns the value of {@code vpc} in {@code s}, or 32-bit top if there is no VPC. */
    private BasedNumberElement getVpcValue(BasedNumberValuation s, ValueContainer vpc) {
        if (vpc == null) {
            return BasedNumberElement.getTop(32);
        }
        return s.getValue(vpc);
    }

    /**
     * Precision adjustment. In order, this (1) tries to choose a VPC if the
     * location has none, (2) widens variables and memory cells whose number of
     * recorded values exceeds their threshold, and (3) records all values of
     * the resulting state in the explicit precision.
     *
     * @return the possibly widened state (a fresh copy if anything was
     *         widened, otherwise {@code s} itself) paired with the unchanged
     *         {@code precision} object
     */
    @Override
    public Pair<AbstractState, Precision> prec(AbstractState s, Precision precision, ReachedSet reached) {
        if (s.isBot()) {
            return Pair.create(s, precision);
        }

        // This method uses the fact that there is only 1 precision per location
        VpcPrecision vprec = (VpcPrecision)precision;
        BasedNumberValuation original = (BasedNumberValuation)s;
        Location loc = vprec.getLocation();

        ValueContainer vpc = getVpc(loc);
        ExplicitPrecision eprec = vprec.getPrecision(getVpcValue(original, vpc));

        // If we don't have a VPC yet, first try to determine one from value counts.
        if (vpc == null) {
            eprec = chooseVpc(loc, vprec, eprec, original, reached);
        }

        BasedNumberValuation widenedState = original;
        // Only check value counts if we have at least enough states to reach it
        int minThreshold = Math.min(BoundedAddressTracking.varThreshold.getValue(),
                BoundedAddressTracking.heapThreshold.getValue());
        if (reached.size() > minThreshold) {
            widenedState = widenVariables(eprec, original);
            widenedState = widenStore(eprec, original, widenedState);
        }

        recordValues(eprec, widenedState);

        // If it was changed, widenedState is now a new state
        return Pair.create((AbstractState)widenedState, precision);
    }

    /**
     * Collects registers and memory cells with at least two recorded values as
     * VPC candidates and, if the best candidate reaches the variable
     * threshold, installs it as VPC for {@code loc}.
     *
     * @return the explicit precision to continue with: reloaded for the new
     *         VPC value if a VPC was chosen, otherwise {@code eprec}
     */
    private ExplicitPrecision chooseVpc(Location loc, VpcPrecision vprec, ExplicitPrecision eprec,
            BasedNumberValuation state, ReachedSet reached) {
        Multimap<Integer, ValueContainer> candidates = HashMultimap.create();

        // No support for heap-based VPCs at the moment (merging heap contents speeds up convergence),
        // so only the variable threshold is used here and not the minimum of both thresholds.
        int vpcThreshold = BoundedAddressTracking.varThreshold.getValue();

        // Only check value counts if we have at least enough states to reach it
        if (reached.size() > vpcThreshold) {
            // Check value counts for variables
            for (RTLVariable v : eprec.varMap.keySet()) {
                Set<BasedNumberElement> existingValues = eprec.varMap.get(v);
                // Check first whether we should promote this var to VPC
                if (!arch.isRegister(v) || existingValues.size() < 2) {
                    continue;
                }
                if (v.getName().equals(VPC_EXCLUDED_REGISTER)) {
                    continue;
                }
                candidates.put(existingValues.size(), v);
            }

            // Check value counts for store
            PartitionedMemory<BasedNumberElement> sStore = state.getStore();
            for (EntryIterator<MemoryRegion, Long, BasedNumberElement> entryIt = sStore.entryIterator();
                    entryIt.hasEntry(); entryIt.next()) {
                MemoryRegion region = entryIt.getLeftKey();
                Long offset = entryIt.getRightKey();
                SetMultimap<Long, BasedNumberElement> memoryMap = eprec.regionMaps.get(region);
                if (memoryMap == null) {
                    continue;
                }
                Set<BasedNumberElement> existingValues = memoryMap.get(offset);
                if (existingValues.size() >= 2) {
                    int bitWidth = existingValues.iterator().next().getBitWidth();
                    candidates.put(existingValues.size(), new MemoryReference(region, offset, bitWidth));
                }
            }
        }

        if (candidates.isEmpty()) {
            return eprec;
        }

        // Highest value count first
        List<Integer> counts = new ArrayList<Integer>(candidates.keySet());
        Collections.sort(counts, Collections.reverseOrder());
        Integer bestCount = counts.get(0);
        if (bestCount < vpcThreshold) {
            return eprec;
        }

        logger.verbose("Value threshold reached, choosing VPC for location " + loc + ". Candidates:");
        for (Integer c : counts) {
            logger.verbose(" " + c + ": " + candidates.get(c));
        }
        setVpc(loc, candidates.get(bestCount).iterator().next());
        // Read the VPC back through getVpc, which may still yield null (e.g. for harness code)
        ValueContainer chosen = getVpc(loc);
        logger.verbose(loc + ": Set VPC to " + chosen);

        // Reload explicit precision for new VPC
        return vprec.getPrecision(getVpcValue(state, chosen));
    }

    /**
     * Lowers the precision of every tracked variable whose recorded value
     * count exceeds its threshold, and widens its value accordingly.
     * Lowering precisions and widening the value in this state, too, avoids
     * values accumulating at join points (where they are not intercepted by
     * the precision-aware setValue).
     *
     * @return {@code original} if nothing was widened, else a widened copy
     */
    private BasedNumberValuation widenVariables(ExplicitPrecision eprec, BasedNumberValuation original) {
        BasedNumberValuation result = original;
        for (RTLVariable v : eprec.varMap.keySet()) {
            Set<BasedNumberElement> existingValues = eprec.varMap.get(v);
            int threshold = eprec.getThreshold(v);
            if (existingValues.size() <= threshold) {
                continue;
            }
            if (countRegions(existingValues) > threshold) {
                // Too many regions as well: give up on the variable entirely
                eprec.stopTracking(v);
                result = writableCopy(original, result);
                result.setValue(v, BasedNumberElement.getTop(v.getBitWidth()));
            } else {
                // Keep the region, forget the concrete number
                eprec.trackRegionOnly(v);
                result = writableCopy(original, result);
                logger.debug("Only tracking region of " + v + ", values were " + existingValues);
                result.setValue(v, new BasedNumberElement(
                        result.getValue(v).getRegion(),
                        NumberElement.getTop(v.getBitWidth())));
            }
        }
        return result;
    }

    /**
     * Store counterpart of {@link #widenVariables}: iterates over the store
     * of {@code original} and widens cells in (a copy of) {@code current}.
     * Unlike for variables, tracking of a cell is only stopped completely if
     * its values span more than five times the threshold in regions.
     *
     * @return {@code current} if it already is a copy or nothing was widened,
     *         else a widened copy of {@code original}
     */
    private BasedNumberValuation widenStore(ExplicitPrecision eprec, BasedNumberValuation original,
            BasedNumberValuation current) {
        BasedNumberValuation result = current;
        PartitionedMemory<BasedNumberElement> sStore = original.getStore();
        for (EntryIterator<MemoryRegion, Long, BasedNumberElement> entryIt = sStore.entryIterator();
                entryIt.hasEntry(); entryIt.next()) {
            MemoryRegion region = entryIt.getLeftKey();
            Long offset = entryIt.getRightKey();
            BasedNumberElement value = entryIt.getValue();
            SetMultimap<Long, BasedNumberElement> memoryMap = eprec.regionMaps.get(region);
            if (memoryMap == null) {
                continue;
            }

            Set<BasedNumberElement> existingValues = memoryMap.get(offset);
            int threshold = eprec.getStoreThreshold(region, offset);
            if (existingValues.size() <= threshold) {
                continue;
            }
            if (countRegions(existingValues) > 5*threshold) {
                eprec.stopTracking(region, offset);
                result = writableCopy(original, result);
                result.getStore().set(region, offset, value.getBitWidth(),
                        BasedNumberElement.getTop(value.getBitWidth()));
            } else {
                eprec.trackRegionOnly(region, offset);
                result = writableCopy(original, result);
                result.getStore().set(region, offset, value.getBitWidth(),
                        new BasedNumberElement(value.getRegion(),
                                NumberElement.getTop(value.getBitWidth())));
            }
        }
        return result;
    }

    /**
     * Copy-on-first-write helper: returns {@code current} if it already is a
     * private copy, otherwise a new copy of {@code original}.
     */
    private static BasedNumberValuation writableCopy(BasedNumberValuation original,
            BasedNumberValuation current) {
        if (current != original) {
            return current;
        }
        return new BasedNumberValuation(original);
    }

    /** Adds every variable and memory value of {@code state} to the value sets in {@code eprec}. */
    private void recordValues(ExplicitPrecision eprec, BasedNumberValuation state) {
        // Collect all values for all variables
        for (Map.Entry<RTLVariable, BasedNumberElement> entry : state.getVariableValuation()) {
            eprec.varMap.put(entry.getKey(), entry.getValue());
        }

        // Collect all values for all memory areas
        PartitionedMemory<BasedNumberElement> store = state.getStore();
        for (EntryIterator<MemoryRegion, Long, BasedNumberElement> entryIt = store.entryIterator();
                entryIt.hasEntry(); entryIt.next()) {
            MemoryRegion region = entryIt.getLeftKey();
            SetMultimap<Long, BasedNumberElement> memoryMap = eprec.regionMaps.get(region);
            if (memoryMap == null) {
                memoryMap = HashMultimap.create();
                eprec.regionMaps.put(region, memoryMap);
            }
            memoryMap.put(entryIt.getRightKey(), entryIt.getValue());
        }
    }

    @Override
    public AbstractState initStartState(Location location) {
        return BasedNumberValuation.createInitialState();
    }

    /**
     * Creates a fresh {@link VpcPrecision} for {@code location}. The precision
     * is not stored in this analysis; VPCs are kept separately in the VPC map.
     */
    @Override
    public Precision initPrecision(Location location, StateTransformer transformer) {
        return new VpcPrecision(location);
    }

    /** Returns the number of distinct memory regions among {@code values}. */
    private int countRegions(Set<BasedNumberElement> values) {
        Set<MemoryRegion> regions = new HashSet<MemoryRegion>();
        for (BasedNumberElement e : values) {
            regions.add(e.getRegion());
        }
        return regions.size();
    }
}