/** AStarInference.java
*
* @author Imran Mansuri
* @since 1.2
* @version 1.3
*
* A* search
*/
package iitb.CRF;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.set.hash.TIntHashSet;
import iitb.AStar.AStarSearch;
import iitb.AStar.State;
import iitb.CRF.Viterbi.Entry;
import iitb.Model.Model;
import java.io.Serializable;
import java.util.BitSet;
import cern.colt.matrix.tdouble.DoubleMatrix1D;
import cern.colt.matrix.tdouble.DoubleMatrix2D;
import cern.colt.matrix.tdouble.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D;
public class AStarInference implements Serializable {
private static final long serialVersionUID = 81236L;
CRF model;
Viterbi viterbi;
int beamsize;
DataSequence dataSeq;
TIntObjectHashMap<BitSet> conflictingLabels;
Model graphModel;
DoubleMatrix2D Mi[];
DoubleMatrix1D Ri[];
boolean constraintCheck = true;
Soln ubSoln; //Upper bound solution, without constraints
int path[];
float[] ubScores;
float ubScore[][];
float upperBound = 0;
AStarState goalState;
AStarSearch aStar;
long maxExpansions = Long.MAX_VALUE;
long queueSizeLimit = Long.MAX_VALUE;
int avgStatesPerExpansion = 50;
boolean boundUpdate = false;
int forwardViterbiBeamSize = 1;
int backwardViterbiBeamSize = 1;
boolean debug = false;
protected class CloneableIntSet extends TIntHashSet implements Cloneable{
public Object clone() {
try {
return super.clone();
} catch (CloneNotSupportedException cnse) {
return null; // it's supported
}
}
}
protected AStarInference() {
}
public AStarInference(CRF model, int bs) {
this(model, bs, null, null);
getParameters();
}
protected void getParameters() {
if(model.params.miscOptions.getProperty("maxExpansions") != null){
try{
maxExpansions = Long.parseLong(model.params.miscOptions.getProperty("maxExpansions"));
}catch(NumberFormatException nfe){}
}
if(model.params.miscOptions.getProperty("queueSizeLimit") != null){
try{
queueSizeLimit = Long.parseLong(model.params.miscOptions.getProperty("queueSizeLimit"));
}catch(NumberFormatException nfe){}
}
if(model.params.miscOptions.getProperty("avgStatesPerExpansion") != null){
try{
avgStatesPerExpansion = Integer.parseInt(model.params.miscOptions.getProperty("avgStatesPerExpansions"));
}catch(NumberFormatException nfe){}
}
if(model.params.miscOptions.getProperty("boundUpdate") != null){
try{
boundUpdate = Boolean.valueOf(model.params.miscOptions.getProperty("boundUpdate")).booleanValue();
}catch(Exception nfe){}
}
if(model.params.miscOptions.getProperty("forwardViterbiBeamSize") != null){
try{
forwardViterbiBeamSize = Integer.parseInt(model.params.miscOptions.getProperty("forwardViterbiBeamSize"));
}catch(NumberFormatException nfe){}
}
if(model.params.miscOptions.getProperty("backwardViterbiBeamSize") != null){
try{
backwardViterbiBeamSize = Integer.parseInt(model.params.miscOptions.getProperty("backwardViterbiBeamSize"));
}catch(NumberFormatException nfe){}
}
if (model.params.miscOptions.getProperty("beamSize") != null)
{
try{
beamsize = Integer.parseInt(model.params.miscOptions
.getProperty("beamSize"));
}catch(NumberFormatException nfe){}
}
if(model.params.debugLvl > 2)
debug = true;
}
public AStarInference(CRF model, int bs, TIntObjectHashMap<TIntHashSet> confLabelMap,
Model graphModel) {
this.model = model;
this.graphModel = graphModel;
beamsize = bs;
getParameters();
aStar = new AStarSearch(null, avgStatesPerExpansion, maxExpansions, queueSizeLimit, debug);
viterbi = new Viterbi(model, forwardViterbiBeamSize);
Mi = new DenseDoubleMatrix2D[0];
Ri = new DenseDoubleMatrix1D[0];
initConflictLables(confLabelMap);
}
private void initConflictLables(TIntObjectHashMap<TIntHashSet> confLabelMap) {
if (confLabelMap == null || confLabelMap.size() == 0)
return;
conflictingLabels = new TIntObjectHashMap<BitSet>();
int keys[] = confLabelMap.keys();
TIntHashSet labelSet;
BitSet bitSet;
for (int i = 0; i < keys.length; i++) {
labelSet = (TIntHashSet) confLabelMap.get(keys[i]);
bitSet = new BitSet();
int labelArray[] = labelSet.toArray();
for (int j = 0; j < labelArray.length; j++) {
bitSet.set(labelArray[j]);
}
conflictingLabels.put(keys[i], bitSet);
}
}
int nonMatchCount = 0;
/*
* Equivalent of viterbiSearch
*/
public void bestLabelSequence(DataSequence dataSeq, double lambda[]) {
double corrScore = aStarSearch(dataSeq, lambda, true);
//constraintCheck = false;
int pos;
//check whether the search succeeded or not
boolean nonMatch = false;
int lastPos = 0, lastLabel = 0;
if (goalState != null) {
do {
pos = goalState.pos;
dataSeq.set_y(pos, goalState.y);
goalState = goalState.predecessor;
} while (goalState != null && goalState.pos >= 0);
assert (pos == 0);
} else {
System.err.println("Error! Failure in A* search");
}
return;
}
Entry winningLabel[][];
public double aStarSearch(DataSequence dataSeq, double lambda[],
boolean calcCorrectScore) {
this.dataSeq = dataSeq;
// allocate data structures
allocateScratch(model.numY, dataSeq.length());
if (!getUpperBoundSolution(dataSeq, lambda)) {
goalState = null;
return 0;
}
//perform AStar search
goalState = (AStarState) aStar.performAStarSearch(getStartState());
return goalState.g();
//return 0;
}
float scores[];
Soln lastUbSoln;
private boolean getUpperBoundSolution(DataSequence dataSeq, double[] lambda) {
int seqLength = dataSeq.length(), pos = 0;
viterbi.viterbiSearchBackward(dataSeq, lambda, Mi, Ri, false);
winningLabel = viterbi.winningLabel;
ubScore = new float[dataSeq.length()][model.numY];
for (pos = 0; pos < dataSeq.length(); pos++) {
for (int y = 0; y < model.numY; y++) {
ubScore[pos][y] = winningLabel[y][pos].get(0).score;
}
}
return true;
}
/*
* Store all Mi, Ri matrices
*/
void fillArray(DataSequence dataSeq, double lambda[], boolean calcScore) {
int numY = model.numY;
computeMi(dataSeq, lambda);
}
private void computeMi(DataSequence dataSeq, double lambda[]) {
int seqLength = dataSeq.length();
for (int pos = 0; pos < seqLength; pos++) {
// compute Mi.
Trainer.computeLogMi(model.featureGenerator, lambda, dataSeq, pos,
Mi[pos], Ri[pos], false);
}
}
void allocateScratch(int numY, int seqLength) {
if (Mi.length < seqLength) { //what if Mi.length > seqLength, I hope
DoubleMatrix2D tempMi[] = Mi;
DoubleMatrix1D tempRi[] = Ri;
int i;
Mi = new DenseDoubleMatrix2D[seqLength];
Ri = new DenseDoubleMatrix1D[seqLength];
for (i = 0; i < tempMi.length; i++) {
Mi[i] = tempMi[i];
Ri[i] = tempRi[i];
}
for (; i < seqLength; i++) {
Mi[i] = new DenseDoubleMatrix2D(numY, numY);
Ri[i] = new DenseDoubleMatrix1D(numY);
}
}
}
private AStarState getStartState() {
return (conflictingLabels == null ? new AStarState(-1, -1, upperBound,
0, null, null) : new AStarState(-1, -1, upperBound, 0, null,
new BitSet()));
}
int numSolutions() {
return 1;
}//bs not supported for now
Soln getBestSoln(int k) {
return null;
}
public TIntObjectHashMap<BitSet> getConflictingLabels() {
return conflictingLabels;
}
public void setConflictingLabels(TIntObjectHashMap<BitSet> conflictingLabels) {
this.conflictingLabels = conflictingLabels;
}
class AStarState extends State {
int pos;
int y;
boolean valid = true;
AStarState predecessor;
BitSet assignedLabels;
protected CloneableIntSet labelsOnPath;
public AStarState(int pos, int label, double h, double g,
AStarState predecessor, BitSet assignedLabels) {
super(g, h);
this.pos = pos;
this.predecessor = predecessor;
this.y = label;
this.assignedLabels = assignedLabels;
checkValidity();
}
public double estimate() {
return h + g;
}
public State[] getSuccessors() {
return generateSucessors();
}
State[] generateSucessors() {
AStarState[] successors = new AStarState[model.numY];
int nextPos = pos + 1;
//for (int nextY = model.edgeGen.first(label); nextY < model.numY;
// nextY = model.edgeGen.next(nextY,label)){
BitSet succLabels = null;
if (assignedLabels != null) {
succLabels = (BitSet) assignedLabels.clone();
if (y >= 0) {
succLabels.set(graphModel.label(y));
}
}
for (int nextY = 0; nextY < model.numY; nextY++) {
double succScore = 0;
succScore += (nextPos > 0 ? Mi[nextPos].get(y, nextY) : 0);
succScore += Ri[nextPos].get(nextY) + g;
if (assignedLabels != null) {
if (!conflicting(succLabels, nextY))
successors[nextY] = new AStarState(nextPos, nextY,
ubScore[nextPos][nextY], succScore, this,
succLabels);//todo
else
successors[nextY] = null;//invalid state
} else
successors[nextY] = new AStarState(nextPos, nextY,
ubScore[nextPos][nextY], succScore, this, null);
}
return successors;
}
public double g() {
return g;
}
public double h() {
return h;
}
public boolean goalState() {
return pos == (dataSeq.length() - 1);
}
public boolean validState() {
return valid;
}
public void checkValidity() {
valid = !conflicting(assignedLabels, y);
}
public boolean conflicting(BitSet assignedLabels, int stateId) {
if (constraintCheck && assignedLabels != null
&& conflictingLabels.get(graphModel.label(stateId)) != null) {
return (assignedLabels.intersects((BitSet) conflictingLabels
.get(graphModel.label(stateId))));
}
return false;
}
public String toString() {
return "Pos:"
+ pos
+ " Y="
+ graphModel.label(y)
+ " h="
+ h
+ " g="
+ g
+ " f="
+ (h + g)
+ (predecessor != null ? " Par=(" + predecessor.pos + ", "
+ predecessor.y + ")" : "");
}
public int prevPos() {
if(predecessor == null)
return -1;
else
return predecessor.pos;
}
}
};