/* Copyright 2003, Carnegie Mellon, All Rights Reserved */
package edu.cmu.minorthird.classify.sequential;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.log4j.Logger;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.util.MathUtil;
import edu.cmu.minorthird.util.StringUtil;
/**
 * Beam search over label sequences.
 *
 * <p>
 * Each instance of a sequence is classified after being wrapped in an
 * {@link InstanceFromSequence} together with a history of the labels chosen
 * for the preceding instances. Every partial labeling kept in the beam is
 * extended with every possible class label, scored by the weight the
 * classifier assigns to that label, and only the top-scoring partial
 * labelings (at most {@link #getMaxBeamSize()}) are expanded at the next
 * position.
 *
 * <p>
 * The result of the most recent search is kept in this object and is read
 * back with {@link #viterbi(int)} and {@link #score(int)}, so a single
 * instance holds the state of one search at a time.
 *
 * @author William Cohen
 */
public class BeamSearcher implements SequenceConstants,Serializable{

  private static final long serialVersionUID=20080207L;

  private static final Logger log=Logger.getLogger(BeamSearcher.class);

  private static final boolean DEBUG=false;

  // parameters of beam searcher
  private final int historySize;

  private final String[] possibleClassLabels;

  private final Classifier classifier;

  private int beamSize=10;

  // caches current beam search
  transient private Beam beam=new Beam();

  private boolean caching=false;

  /** The sequence last searched; retained after the search only when caching. */
  transient private Instance[] instances;

  /** Length of the sequence last searched. */
  transient private int numInstances;

  /**
   * @param classifier classifier applied to the history-extended instances
   * @param historySize number of preceding labels made available to the
   *          classifier
   * @param schema supplies the class names the search chooses among
   * @throws IllegalArgumentException if the schema has fewer than two valid
   *           class names
   */
  public BeamSearcher(Classifier classifier,int historySize,ExampleSchema schema){
    this.classifier=classifier;
    this.historySize=historySize;
    this.possibleClassLabels=schema.validClassNames();
    if(possibleClassLabels.length<2){
      throw new IllegalArgumentException("possibleClassLabels.length="+
          possibleClassLabels.length+" <2 ???");
    }
  }

  /**
   * Restore the transient search state after deserialization. Field
   * initializers are not run when an object is deserialized, so without this
   * the beam would be null until the next search.
   */
  private void readObject(ObjectInputStream in) throws IOException,
      ClassNotFoundException{
    in.defaultReadObject();
    beam=new Beam();
  }

  /** Maximum number of beam entries expanded at each sequence position. */
  public int getMaxBeamSize(){
    return beamSize;
  }

  /**
   * Set the maximum number of beam entries expanded at each sequence
   * position. With a value below 1 no entry is expanded, so searching a
   * non-empty sequence leaves the beam empty.
   */
  public void setMaxBeamSize(int n){
    beamSize=n;
  }

  /** Whether the searched sequence is retained after a search completes. */
  public boolean isCaching(){
    return caching;
  }

  /** Choose whether the searched sequence is retained after a search. */
  public void setCaching(boolean caching){
    this.caching=caching;
  }

  /** Get the best label sequence, as determined by the beam search */
  public ClassLabel[] bestLabelSequence(Instance[] instances){
    doSearch(instances);
    return viterbi(0);
  }

  /**
   * Wrap an instance with the history produced for an empty label sequence,
   * i.e. the history seen by the first instance of a sequence.
   */
  static public Instance getBeamInstance(Instance instance,int historySize){
    String[] history=new String[historySize];
    InstanceFromSequence.fillHistory(history,new String[]{},0);
    return new InstanceFromSequence(instance,history);
  }

  /** Do a beam search. */
  public void doSearch(Instance[] sequence){
    search(sequence,null);
  }

  /**
   * Do a beam search, constraining the bestLabel for each classification to
   * match the non-null values in the template.
   *
   * <p>
   * Positions beyond the end of the template, and positions whose template
   * entry is null, are unconstrained. A null template leaves every position
   * unconstrained.
   */
  public void doSearch(Instance[] sequence,ClassLabel[] template){
    search(sequence,template);
  }

  /**
   * Shared implementation of both doSearch variants.
   *
   * @param sequence the instances to label
   * @param template optional per-position label constraints; may be null
   */
  private void search(Instance[] sequence,ClassLabel[] template){
    this.instances=sequence;
    if(DEBUG){
      log.debug("searching over a "+sequence.length+"-instance sequence");
      log.debug("beamSize="+beamSize+" classes="+
          StringUtil.toString(possibleClassLabels));
    }
    if(possibleClassLabels.length<2){
      throw new IllegalStateException("possibleClassLabels.length="+
          possibleClassLabels.length+" <2 ???");
    }
    beam=new Beam();
    beam.add(new BeamEntry());
    for(int i=0;i<sequence.length;i++){
      if(DEBUG){
        log.debug("predicting class for instance["+i+"]: "+
            sequence[i].getSource());
      }
      // the label this position is pinned to, if the template supplies one
      ClassLabel constraint=
          (template!=null&&i<template.length)?template[i]:null;
      String requiredLabel=
          constraint==null?null:constraint.bestClassName();
      Beam nextBeam=new Beam();
      // only the best beamSize entries of the (sorted) beam are expanded
      int width=Math.min(beam.size(),beamSize);
      for(int j=0;j<width;j++){
        BeamEntry entry=beam.get(j);
        if(DEBUG){
          log.debug("beam entry["+j+"]: "+entry);
        }
        // classify example based on this history
        Instance beamInstance=entry.getBeamInstance(sequence[i]);
        ClassLabel label=classifier.classification(beamInstance);
        if(DEBUG){
          log.debug("class for "+beamInstance+" => "+label);
        }
        // add all permitted classifications to the beam for the next iteration
        for(int el=0;el<possibleClassLabels.length;el++){
          String candidate=possibleClassLabels[el];
          if(constraint==null||requiredLabel.equals(candidate)){
            double w=label.getWeight(candidate);
            nextBeam.add(entry.extend(candidate,w));
            if(DEBUG){
              log.debug("extending beam with "+candidate+" score: "+w);
            }
          }else if(DEBUG){
            log.debug("discarding "+candidate+" as template mismatch");
          }
        }
      }
      nextBeam.sort();
      beam=nextBeam;
    }
    numInstances=sequence.length;
    if(!caching){
      this.instances=null;
    }
  }

  /** Return the number of solutions found in the beam. */
  public int getNumberOfSolutionsFound(){
    return beam.size();
  }

  /**
   * Retrieve the k-th best result of the previous beam search. To get the best,
   * use viterbi(0), the second best is viterbi(1), etc.
   *
   * @throws IndexOutOfBoundsException if the beam holds no k-th entry
   */
  public ClassLabel[] viterbi(int k){
    BeamEntry entry=beam.get(k);
    ClassLabel[] result=new ClassLabel[numInstances];
    for(int i=0;i<numInstances;i++){
      result[i]=entry.toClassLabel(i);
    }
    return result;
  }

  /** Total score of the k-th best result of the previous beam search. */
  public float score(int k){
    return (float)beam.get(k).score;
  }

  /**
   * Search the sequence and describe, instance by instance, how the best
   * labeling was scored.
   */
  public String explain(Instance[] sequence){
    StringBuilder buf=new StringBuilder();
    doSearch(sequence);
    BeamEntry targetEntry=beam.get(0);
    // replay the best path so each instance is explained with its own history
    BeamEntry entry=new BeamEntry();
    for(int i=0;i<sequence.length;i++){
      buf.append("Classification for instance ").append(i).append(" is ")
          .append(targetEntry.labels[i]).append(" (score ")
          .append(targetEntry.scores[i]).append("):\n");
      buf.append(classifier.explain(entry.getBeamInstance(sequence[i])));
      entry=entry.extend(targetEntry.labels[i],targetEntry.scores[i]);
      buf.append("\nRunning total score: ").append(entry.score).append("\n\n");
    }
    return buf.toString();
  }

  /**
   * Search the sequence and build a tree-structured explanation of how the
   * best labeling was scored, with one child node per instance.
   */
  public Explanation getExplanation(Instance[] sequence){
    doSearch(sequence);
    BeamEntry targetEntry=beam.get(0);
    // replay the best path so each instance is explained with its own history
    BeamEntry entry=new BeamEntry();
    Explanation.Node top=new Explanation.Node("BeamSearcher Classification");
    for(int i=0;i<sequence.length;i++){
      Explanation.Node seqEx=
          new Explanation.Node("Classification for instance "+i+" is "+
              targetEntry.labels[i]+" (score "+targetEntry.scores[i]+"):\n");
      // Explain the same history-extended instance that was classified during
      // the search (as explain() does), not the bare sequence element.
      Instance beamInstance=entry.getBeamInstance(sequence[i]);
      Explanation classifierEx=classifier.getExplanation(beamInstance);
      Explanation.Node explan=
          classifierEx==null?null:classifierEx.getTopNode();
      if(explan==null){
        explan=new Explanation.Node(classifier.explain(beamInstance));
      }
      seqEx.add(explan);
      entry=entry.extend(targetEntry.labels[i],targetEntry.scores[i]);
      Explanation.Node score=
          new Explanation.Node("\nRunning total score: "+entry.score+"\n\n");
      seqEx.add(score);
      top.add(seqEx);
    }
    return new Explanation(top);
  }

  /** The search space. */
  private class Beam{

    private final List<BeamEntry> list=new ArrayList<BeamEntry>();

    // maps last historySize labels -> best entry seen with that history
    private final Map<BeamKey,BeamEntry> keyMap=
        new HashMap<BeamKey,BeamEntry>();

    public BeamEntry get(int i){
      return list.get(i);
    }

    /**
     * Add an entry unless an entry with the same history and an equal or
     * higher score is already present; a lower-scoring entry with the same
     * history is replaced.
     */
    public void add(BeamEntry entry){
      BeamKey key=new BeamKey(entry);
      BeamEntry existingEntry=keyMap.get(key);
      if(existingEntry==null||existingEntry.score<entry.score){
        if(existingEntry!=null){
          list.remove(existingEntry);
        }
        list.add(entry);
        keyMap.put(key,entry);
      }
    }

    public int size(){
      return list.size();
    }

    /** Order the entries from highest to lowest total score. */
    public void sort(){
      Collections.sort(list);
    }
  }

  /** A single entry in the beam */
  private class BeamEntry implements Comparable<BeamEntry>{

    /* Labels assigned so far. */
    public String[] labels=new String[0];

    /* Score associated with each label assigned so far. */
    public double[] scores=new double[0];

    /** Total score of labels so far */
    public double score=0.0;

    /** Implement Comparable: entries with a higher score sort first. */
    @Override
    public int compareTo(BeamEntry other){
      // Double.compare rather than the sign of a difference, which is NaN
      // when both scores are the same infinity
      return Double.compare(other.score,score);
    }

    /** Convert i-th label stored to a class label */
    public ClassLabel toClassLabel(int i){
      return new ClassLabel(labels[i],scores[i]);
    }

    /** Extend the beam with one additional label */
    public BeamEntry extend(String label,double labelScore){
      int n=labels.length;
      BeamEntry result=new BeamEntry();
      result.labels=Arrays.copyOf(labels,n+1);
      result.scores=Arrays.copyOf(scores,n+1);
      result.labels[n]=label;
      result.scores[n]=labelScore;
      result.score=score+labelScore;
      return result;
    }

    /**
     * Wrap an instance with the history implied by the labels of this entry.
     * A fresh history array is used for every call so the returned instance
     * never shares state with another one.
     */
    public Instance getBeamInstance(Instance instance){
      String[] entryHistory=new String[historySize];
      fillHistory(entryHistory);
      return new InstanceFromSequence(instance,entryHistory);
    }

    /** Fill the given array with the history implied by this entry's labels. */
    public void fillHistory(String[] target){
      InstanceFromSequence.fillHistory(target,labels,labels.length);
    }

    @Override
    public String toString(){
      return "[entry: "+Arrays.toString(labels)+";"+Arrays.toString(scores)+
          "; score:"+score+"]";
    }
  }

  /**
   * Used to look for BeamEntry's that should be combined. BeamEntrys should be
   * combined in the beam (with the higher score being retained) if their
   * history is the same.
   */
  private class BeamKey{

    private final String[] keyHistory=new String[historySize];

    public BeamKey(BeamEntry entry){
      entry.fillHistory(keyHistory);
    }

    /** Order-sensitive hash of the history, consistent with equals. */
    @Override
    public int hashCode(){
      return Arrays.hashCode(keyHistory);
    }

    /** Two keys are equal when their histories match element by element. */
    @Override
    public boolean equals(Object o){
      if(!(o instanceof BeamKey)){
        return false;
      }
      return Arrays.equals(keyHistory,((BeamKey)o).keyHistory);
    }

    @Override
    public String toString(){
      StringBuilder path=new StringBuilder("[Key ");
      for(int i=0;i<keyHistory.length;i++){
        path.append(keyHistory[i]).append(' ');
      }
      path.append(']');
      return path.toString();
    }
  }
}