/** SegmentAStar.java
*
* @author imran
* @since 1.2
* @version 1.3
*/
package iitb.CRF;
import gnu.trove.list.array.TIntArrayList;
import iitb.AStar.AStarSearch;
import iitb.AStar.BoundUpdate;
import iitb.AStar.State;
import iitb.CRF.SegmentViterbi.LabelConstraints;
import iitb.CRF.SparseViterbi.Iter;
import iitb.Utils.OptimizedSparseDoubleMatrix1D;
import iitb.Utils.OptimizedSparseDoubleMatrix2D;
import iitb.Utils.StaticObjectHeap;
import java.util.ArrayList;
import cern.colt.function.tdouble.IntIntDoubleFunction;
import cern.colt.matrix.tdouble.DoubleMatrix1D;
import cern.colt.matrix.tdouble.DoubleMatrix2D;
/**
 * A*-search decoder for {@link SegmentCRF} models that labels a sequence with
 * variable-length segments.
 *
 * <p>Each search state ({@code SegmentState}) is a labeled segment identified by
 * its end position ({@code pos}) and its length ({@code ell}); the start state has
 * {@code pos == -1}. Before searching, {@code getSolutionBound} runs a forward
 * Viterbi pass, whose best score is the initial lower bound used for pruning, and a
 * backward Viterbi pass, whose best score is the upper bound and whose per-position
 * contexts supply the heuristic {@code h} of every state. If A* does not reach a
 * goal state, {@code bestLabelSequence} falls back to a Viterbi solution.
 */
public class SegmentAStar extends AStarInference {
    private static final long serialVersionUID = 8124L;
    /** Slack added to a successor's f-score before comparing it with {@code lowerBound}. */
    double delta = 0.001;
    /** Best known solution score; successors whose f-score + delta falls below it are pruned. */
    double lowerBound = 0;
    int forwardViterbiBeamSize = 1; //can set to any value
    int backwardViterbiBeamSize = 1; //one is sufficient
    /**
     * Cached transition scores, indexed [segmentEnd][segmentLength]; row = label of the
     * expanding state, column = label of the successor segment.
     */
    DoubleMatrix2D[][] Mi;
    /** Cached per-label segment scores, indexed [segmentEnd][segmentLength]. */
    DoubleMatrix1D Ri[][];
    /** Contexts of the backward Viterbi pass; used as the heuristic {@code h} of each state. */
    SparseViterbi.Context context[];
    SegmentViterbi segmentViterbi;
    SegmentViterbi backwardSegmentViterbi;
    /** Label constraints of the current sequence, or null when none apply. */
    SegmentViterbi.LabelConstraints labelConstraints = null;
    /** Enumerates the lengths of the segments that can follow a given end position. */
    SegmentIterForward iter;
    /** Copies Mi into optimizedSparseMi; created and used only when {@code sparseMatrix} is set. */
    OptimizedSparseMatrixMapper stateGenerator;
    /** Scratch list that collects the successors of the state being expanded. */
    ArrayList<SegmentState> states;
    /** Label set handed to the successors currently being built (label-constraint bookkeeping). */
    CloneableIntSet nextLabelsOnPath = null;
    /** Length and end position of the successor segment currently being expanded. */
    int succEll, succPos;
    /** Copy of the forward Viterbi solution that provided the initial lower bound. */
    Soln lbSoln;
    /** Sparse copy of Mi, built only when {@code sparseMatrix} is set. */
    OptimizedSparseDoubleMatrix2D optimizedSparseMi[][];
    StaticHeapOptimizedSparseDoubleMatrix1D staticHeapOptSparseDoubleMatrix1D;
    StaticHeapOptimizedSparseDoubleMatrix2D staticHeapOptSparseDoubleMatrix2D;
    /** Whether successors are generated from optimizedSparseMi instead of Mi (miscOption "sparse"). */
    boolean sparseMatrix = false;
    /** Weight vector of the current search. */
    double lambda[] = null;
    /** Counter used to give every SegmentState a unique id. */
    int stateCount = 0;

    /**
     * Creates a decoder for {@code model}.
     *
     * @param model segment CRF whose options configure the search
     * @param bs    value stored in {@code beamsize}
     */
    public SegmentAStar(SegmentCRF model, int bs) {
        super();
        beamsize = bs;
        this.model = model;
        getParameters();
        aStar = new AStarSearch((boundUpdate ? new AStarBoundUpdate() : null), avgStatesPerExpansion,
                maxExpansions, queueSizeLimit, debug);
        segmentViterbi = new SegmentViterbi(model, forwardViterbiBeamSize);
        backwardSegmentViterbi = new SegmentViterbi(model, backwardViterbiBeamSize);
        iter = new SegmentIterForward(segmentViterbi.new SegmentIter());
        states = new ArrayList<SegmentState>();
        if (sparseMatrix) {
            staticHeapOptSparseDoubleMatrix1D = new StaticHeapOptimizedSparseDoubleMatrix1D(0);
            staticHeapOptSparseDoubleMatrix2D = new StaticHeapOptimizedSparseDoubleMatrix2D(0);
            stateGenerator = new OptimizedSparseMatrixMapper();
        }
    }

    /** Reads the inherited A* options plus the "sparse" miscOption (default "false"). */
    protected void getParameters() {
        super.getParameters();
        sparseMatrix = Boolean.parseBoolean(model.params.miscOptions.getProperty("sparse", "false"));
    }

    /**
     * Labels {@code dataSeq} with segments, writing them via
     * {@code dataSeq.setSegment(start, end, label)}.
     *
     * <p>If A* reaches a goal state, that path is written and its g-score returned.
     * In that case the inherited {@code goalState} field is advanced along the path.
     * Otherwise a Viterbi solution is written instead: the one computed from the
     * partial A* path, or the initial lower-bound solution if that scores higher or
     * Viterbi finds none.
     *
     * @param dataSeq sequence to label; modified in place
     * @param lambda  weight vector passed through to the searches
     * @return score of the written labeling. If no labeling could be written, this is
     *         the g-score of the state A* stopped at, or 0 if there is none.
     */
    public double bestLabelSequence(CandSegDataSequence dataSeq, double lambda[]) {
        aStarSearch(dataSeq, lambda, false);
        // goalState is null when the search fails outright, so guard before reading its score.
        double score = (goalState != null ? goalState.g() : 0);
        if (goalState != null && goalState.goalState()) {
            // Walk the path backwards, writing one segment per state, until the start
            // state (pos < 0) is reached.
            int segStart;
            do {
                segStart = goalState.prevPos() + 1;
                dataSeq.setSegment(segStart, goalState.pos, goalState.y);
                goalState = goalState.predecessor;
            } while (goalState != null && goalState.pos >= 0);
            // pos is a segment END, so check the start of the first segment instead.
            assert segStart == 0 : "first segment starts at " + segStart;
            return score;
        }
        // Error! Failure in A* search: find a solution with Viterbi, passing it the partial
        // A* path, and keep the initial lower-bound solution if Viterbi does no better.
        Soln soln = getViterbiSoln(dataSeq, lambda, (SegmentState) goalState);
        if (soln == null || (lbSoln != null && Double.compare(soln.score, lbSoln.score) < 0)) {
            soln = lbSoln;
        }
        if (soln != null) {
            score = soln.score;
            for (Soln ybest = soln; ybest != null; ybest = ybest.prevSoln) {
                dataSeq.setSegment(ybest.prevPos() + 1, ybest.pos, ybest.label);
            }
        }
        return score;
    }

    /** Same as {@code aStarSearch(dataSeq, lambda, calcCorrectScore, true)}. */
    public double aStarSearch(DataSequence dataSeq, double lambda[],
            boolean calcCorrectScore) {
        return aStarSearch(dataSeq, lambda, calcCorrectScore, true);
    }

    /**
     * Runs the A* search and stores its result in the inherited {@code goalState} field.
     *
     * @param dataSeq          sequence to search; cast to CandSegDataSequence when
     *                         {@code checkConstraints} is true
     * @param lambda           weight vector
     * @param calcCorrectScore not used by this implementation
     * @param checkConstraints whether to look up label constraints for the sequence
     * @return g-score of the state returned by the search, or 0 if there is none
     */
    public double aStarSearch(DataSequence dataSeq, double lambda[],
            boolean calcCorrectScore, boolean checkConstraints) {
        labelConstraints = null;
        if (checkConstraints) {
            labelConstraints = LabelConstraints.checkConstraints((CandSegDataSequence) dataSeq, labelConstraints);
        }
        this.dataSeq = dataSeq;
        this.lambda = lambda;
        //init the forward-edge iterator
        iter.init(dataSeq);
        //get upper and lower bound for solution
        if (!getSolutionBound(dataSeq, lambda)) {
            goalState = null;
            return 0;
        }
        //perform AStar search
        goalState = (AStarState) aStar.performAStarSearch(getStartState());
        return (goalState != null ? goalState.g() : 0);
    }

    /**
     * Caches the segment scores, then computes the lower bound (forward Viterbi) and the
     * upper bound plus heuristic contexts (backward Viterbi).
     *
     * @return always true
     */
    private boolean getSolutionBound(DataSequence dataSeq, double[] lambda) {
        cacheMis(dataSeq, lambda);
        //get lower bound using constrained viterbi
        segmentViterbi.viterbiSearch(dataSeq, lambda, Mi, Ri, true, false);
        lbSoln = copyViterbiSolution(segmentViterbi.getBestSoln(0));
        lowerBound = (lbSoln != null ? lbSoln.score : 0);//assume that viterbi returns some solution
        //get upper bound using backward-viterbi, to be used for heuristics
        backwardSegmentViterbi.viterbiSearchBackward(dataSeq, lambda, Mi, Ri, false);
        ubSoln = backwardSegmentViterbi.getBestSoln(0);
        upperBound = (ubSoln != null ? ubSoln.score : 0);
        //store context from backward viterbi for heuristic calculations
        context = backwardSegmentViterbi.context;
        return true;
    }

    /**
     * Copies a Soln chain (label, pos, score and prevSoln links).
     * NOTE(review): the copy presumably guards against segmentViterbi reusing its Soln
     * objects in later searches; confirm.
     *
     * @return head of the copied chain, or null if {@code soln} is null
     */
    private Soln copyViterbiSolution(Soln soln) {
        Soln copiedSoln = null, lastSoln = null;
        while (soln != null) {
            Soln tempSoln = new Soln(soln.label, soln.pos);
            tempSoln.score = soln.score;
            if (lastSoln != null)
                lastSoln.prevSoln = tempSoln;
            if (copiedSoln == null) {
                copiedSoln = tempSoln;
            }
            soln = soln.prevSoln;
            lastSoln = tempSoln;
        }
        return copiedSoln;
    }

    /** Start state: pos -1, no label, h = upper bound, g = 0; an empty label set if constrained. */
    private AStarState getStartState() {
        return new SegmentState(-1, 0, -1, upperBound, 0, null, (labelConstraints != null ? new CloneableIntSet() : null));
    }

    /**
     * Runs the forward Viterbi search, passing it the path that ends at {@code curState},
     * and returns its best solution.
     */
    Soln getViterbiSoln(DataSequence dataSeq, double lambda[], SegmentState curState) {
        Soln soln = getSoln(curState);
        cacheMis(dataSeq, lambda);
        segmentViterbi.viterbiSearch(dataSeq, lambda, Mi, Ri, soln, true, false);
        return segmentViterbi.getBestSoln(0);
    }

    /**
     * Caches Mi and Ri computed by segmentViterbi for {@code lambda}, and builds
     * the sparse copies when {@code sparseMatrix} is set.
     */
    private void cacheMis(DataSequence dataSeq, double[] lambda) {
        // Use the argument (previously the field was read and the parameter ignored).
        segmentViterbi.cacheMis(dataSeq, lambda);
        Mi = segmentViterbi.getMis();
        Ri = segmentViterbi.getRis();
        if (sparseMatrix) {
            createOptimizedSparseMatrices(Mi);
        }
    }

    /** Fills optimizedSparseMi with sparse copies of the nonzero entries of {@code Mi}. */
    private void createOptimizedSparseMatrices(DoubleMatrix2D Mi[][]) {
        optimizedSparseMi = new OptimizedSparseDoubleMatrix2D[Mi.length][];
        for (int i = 0; i < Mi.length; i++) {
            optimizedSparseMi[i] = new OptimizedSparseDoubleMatrix2D[Mi[i].length];
            for (int j = 0; j < Mi[i].length; j++) {
                optimizedSparseMi[i][j] = (OptimizedSparseDoubleMatrix2D) staticHeapOptSparseDoubleMatrix2D.getObject();
                stateGenerator.init(optimizedSparseMi[i][j]);
                Mi[i][j].forEachNonZero(stateGenerator);
            }
        }
    }

    /**
     * Converts the A* path ending at {@code curState} into a Soln chain. The head
     * corresponds to curState and prevSoln links lead back toward the start; the start
     * state (pos < 0) is excluded.
     *
     * @return head of the chain, or null if curState is null or the start state
     */
    private Soln getSoln(SegmentState curState) {
        Soln nextSoln = null, soln = null;
        while (curState != null && curState.pos >= 0) {
            Soln curSoln = new Soln(curState.y, curState.pos);
            // NOTE(review): the cast suggests Soln.score is a float; confirm.
            curSoln.score = (float) curState.g();
            if (nextSoln != null)
                nextSoln.setPrevSoln(curSoln, nextSoln.score);
            else {
                soln = curSoln;
            }
            nextSoln = curSoln;
            curState = (SegmentState) curState.predecessor;
        }
        return soln;
    }

    /** Supplies OptimizedSparseDoubleMatrix1D instances through StaticObjectHeap. */
    class StaticHeapOptimizedSparseDoubleMatrix1D extends StaticObjectHeap {
        public StaticHeapOptimizedSparseDoubleMatrix1D(int initCapacity) {
            super(initCapacity);
        }
        protected Object newObject() {
            return new OptimizedSparseDoubleMatrix1D();
        }
        protected Object getObject() {
            return getFreeObject();
        }
    }

    /** Supplies OptimizedSparseDoubleMatrix2D instances through StaticObjectHeap. */
    class StaticHeapOptimizedSparseDoubleMatrix2D extends StaticObjectHeap {
        public StaticHeapOptimizedSparseDoubleMatrix2D(int initCapacity) {
            super(initCapacity);
        }
        protected Object newObject() {
            return new OptimizedSparseDoubleMatrix2D();
        }
        protected Object getObject() {
            return getFreeObject();
        }
    }

    /**
     * Forward-edge iterator. For every predecessor end position {@code p}, it lists the
     * lengths of the segments that may follow it, by inverting the per-end-position
     * lengths produced by the wrapped segment iterator. The bucket for the start state
     * ({@code p == -1}) is stored in the last slot, which no real predecessor can use
     * because segment lengths are positive.
     */
    class SegmentIterForward extends SparseViterbi.Iter {
        int nc;
        DataSequence dataSeq;
        Iter iter;
        int startPos, index;
        TIntArrayList segments[];

        public SegmentIterForward(Iter iter) {
            segmentViterbi.super();
            this.iter = iter;
        }

        /** Rebuilds the buckets for {@code dataSeq}. */
        public void init(DataSequence dataSeq) {
            this.dataSeq = dataSeq;
            segments = new TIntArrayList[dataSeq.length()];
            for (int j = dataSeq.length() - 1; j >= 0; j--)
                segments[j] = new TIntArrayList();
            cacheEdges();
        }

        private void cacheEdges() {
            int ell;
            for (int i = dataSeq.length() - 1; i >= 0; i--) {
                iter.start(i, dataSeq);
                while ((ell = iter.nextEll(i)) > 0) {
                    int prevEnd = i - ell;
                    if (prevEnd >= 0) {
                        segments[prevEnd].add(ell);
                    } else if (prevEnd == -1) {
                        // Segment starting at position 0: follows the start state.
                        segments[segments.length - 1].add(ell);
                    }
                    // prevEnd < -1 would be a segment starting before position 0; skip it.
                }
            }
        }

        /** Prepares to enumerate the lengths of segments following end position {@code i}. */
        protected void start(int i, DataSequence dataSeq) {
            startPos = i;
            index = (i == -1 ? segments[segments.length - 1].size() : segments[i].size());
        }

        /** Returns the next successor length for position {@code i}, or -1 when exhausted. */
        protected int nextEll(int i) {
            if (i == -1)
                return index > 0 ? segments[segments.length - 1].get(--index) : -1;
            return (index > 0 ? segments[i].get(--index) : -1);
        }
    }

    /**
     * Visitor over a dense matrix's nonzero entries. It copies each (row, column, value)
     * into the target sparse matrix and creates a row only when the row's first nonzero
     * entry is seen.
     */
    class OptimizedSparseMatrixMapper implements IntIntDoubleFunction {
        private OptimizedSparseDoubleMatrix2D sparse2D;
        private OptimizedSparseDoubleMatrix1D sparse1D;

        public OptimizedSparseMatrixMapper() {
        }

        public void init(OptimizedSparseDoubleMatrix2D sparse2D) {
            this.sparse2D = sparse2D;
        }

        public double apply(int yp, int yi, double val) {
            if ((sparse1D = sparse2D.getRow(yp)) == null) {
                sparse1D = (OptimizedSparseDoubleMatrix1D) staticHeapOptSparseDoubleMatrix1D.getObject();
                sparse1D.clear();
                sparse2D.setRow(yp, sparse1D);
            }
            sparse1D.setQuick(yi, val);
            return val;
        }
    }

    /**
     * Tightens the lower bound during search. It scores a Viterbi solution built from
     * the current state's path and raises {@code lowerBound} if that score is higher.
     */
    class AStarBoundUpdate implements BoundUpdate {
        public double getLowerBound(State curState) {
            Soln soln = getViterbiSoln(dataSeq, lambda, (SegmentState) curState);
            if (soln != null && soln.score > lowerBound)
                lowerBound = soln.score;
            return lowerBound - 2 * delta;
        }
    }

    /** A* state: a segment with label {@code y} ending at {@code pos} with length {@code ell}. */
    class SegmentState extends AStarState implements OptimizedSparseDoubleMatrix1D.ForEachNonZeroReadOnly {
        int id;
        int ell;//pos is used as end of a segment whose "length" "ell"

        public SegmentState(int pos, int ell, int label, double h, double g, AStarState predecessor, CloneableIntSet labelsOnPath) {
            super(pos, label, h, g, predecessor, null);
            this.ell = ell;
            id = stateCount++;
            this.labelsOnPath = labelsOnPath;
        }

        /**
         * Builds one successor per (next segment length, label) whose f-score + delta
         * is not below {@code lowerBound}.
         */
        State[] generateSucessors() {
            iter.start(pos, dataSeq);
            states.clear();
            while ((succEll = iter.nextEll(pos)) > 0) {
                succPos = pos + succEll;
                if (pos == -1) {
                    // Start state: the first segment is scored by Ri alone.
                    createSuccessors(Ri[succPos][succEll]);
                } else if (sparseMatrix) {
                    createSuccessors(optimizedSparseMi[succPos][succEll].getRow(y));
                } else {
                    createSuccessors(Mi[succPos][succEll].viewRow(y));
                }
            }
            return states.toArray(new SegmentState[states.size()]);
        }

        /** Offers every nonzero entry of a dense score row as a successor label. */
        private void createSuccessors(DoubleMatrix1D miRow) {
            prepareNextLabelsOnPath();
            for (int yi = (int) (miRow.size() - 1); yi >= 0; yi--) {
                double val = miRow.getQuick(yi);
                if (val != 0) {
                    addSuccessor(yi, val);
                }
            }
        }

        /** Offers every nonzero entry of a sparse score row as a successor label. */
        private void createSuccessors(OptimizedSparseDoubleMatrix1D optimizedMiRow) {
            // Sparse rows exist only for labels with a nonzero entry (see
            // OptimizedSparseMatrixMapper), so a missing row means no successors.
            if (optimizedMiRow == null) {
                return;
            }
            prepareNextLabelsOnPath();
            optimizedMiRow.forEachNonZero(this);
        }

        /** forEachNonZero callback: offers label {@code yi} with score {@code val}. */
        public void apply(int yi, double val) {
            addSuccessor(yi, val);
        }

        /** Computes nextLabelsOnPath for this state's successors when label constraints apply. */
        private void prepareNextLabelsOnPath() {
            if (labelConstraints != null) {
                nextLabelsOnPath = (CloneableIntSet) labelsOnPath.clone();
                if (y != -1 && labelConstraints.conflicting(y))
                    nextLabelsOnPath.add(y);
            }
        }

        /**
         * Adds a successor segment (succPos, succEll) labeled {@code yi} unless it
         * violates the label constraints or is pruned by the lower bound.
         */
        private void addSuccessor(int yi, double val) {
            if (labelConstraints != null && prevPos() >= 0 && !labelConstraints.valid(labelsOnPath, yi, y)) {
                return;
            }
            // From the start state val already comes from Ri; otherwise val is the
            // transition score and the segment's own Ri score is added.
            double succG = (pos == -1 ? val + g : val + Ri[succPos][succEll].get(yi) + g);
            double succH = context[succPos + 1].getEntry(yi).solns[0].score;
            if (Double.compare((succG + succH + delta), lowerBound) >= 0) {
                states.add(new SegmentState(succPos, succEll, yi, succH, succG, this, nextLabelsOnPath));
            }
        }

        /** True when this segment ends at the last position of the sequence. */
        public boolean goalState() {
            return pos == (dataSeq.length() - 1);
        }

        public String toString() {
            return id + ">Pos:"
                + pos
                + " ell="
                + ell
                + " Label="
                + y
                + " score="
                + (g - (predecessor != null ? predecessor.g() : 0))
                + " g="
                + g
                + " h="
                + h
                + " f="
                + (h + g)
                + (predecessor != null ? " Par=(" + predecessor.pos + ", "
                    + predecessor.y + ")" : "");
        }
    }

    public static void main(String[] args) {
    }
}