/* Copyright 2003-2004, Carnegie Mellon, All Rights Reserved */
package edu.cmu.minorthird.classify.sequential;
import edu.cmu.minorthird.classify.*;
import gnu.trove.map.hash.TObjectDoubleHashMap;
import gnu.trove.procedure.TObjectProcedure;
import gnu.trove.set.hash.THashSet;
import java.io.Serializable;
import java.util.*;
/**
* A more space-efficient version of a CandidateSegmentGroup.
*
* Space is saved by explicitly storing the instances for the
* unit-length segments, plus "deltas" for each non-unit length segment.
* Each "delta" encodes the difference between the segment instance
* and the sum of the unit-length instances it covers.
*
* @author William Cohen
*/
public class CompactCandidateSegmentGroup implements CandidateSegmentGroup,
    Serializable{

  static final long serialVersionUID=20080207L;

  private int maxWindowSize,sequenceLength,totalSize;

  private Set<String> classNameSet;

  private String subPopId;

  // Storage layout: unitInstance[start] holds the compressed instance for the
  // unit segment start..start+1. For the segment from start to start+L, the
  // label, source object and delta are stored at index [start][L-1] of
  // label, segmentSource and delta respectively. delta[start][0] is not
  // used: unit-length segments are always served from unitInstance.
  private Instance[] unitInstance;

  private Delta[][] delta;

  private ClassLabel[][] label;

  private Object[][] segmentSource;

  /**
   * Creates a compact copy of a candidate segment group.
   *
   * Unit-length segment instances are compressed with the factory and stored
   * directly; each longer segment that the group has an instance for is
   * stored as a Delta against the sum of the unit instances it covers.
   *
   * @param factory used to compress the unit instances and to canonicalize
   *   the features recorded in each Delta
   * @param group the group whose segments are copied
   */
  public CompactCandidateSegmentGroup(FeatureFactory factory,
      CandidateSegmentGroup group){
    // The length of the original sequence
    this.sequenceLength=group.getSequenceLength();
    // The maximum length of any sliding window
    this.maxWindowSize=group.getMaxWindowSize();
    this.totalSize=group.size();
    this.classNameSet=group.classNameSet();
    this.subPopId=group.getSubpopulationId();
    unitInstance=new Instance[sequenceLength];
    delta=new Delta[sequenceLength][maxWindowSize];
    label=new ClassLabel[sequenceLength][maxWindowSize];
    segmentSource=new Object[sequenceLength][maxWindowSize];
    // All unit instances must be in place before any Delta is built, because
    // a Delta is computed relative to the sum of unitInstance[start..end-1].
    for(int i=0;i<sequenceLength;i++){
      unitInstance[i]=factory.compress(group.getSubsequenceInstance(i,i+1));
    }
    for(int i=0;i<sequenceLength;i++){
      // a segment can neither exceed the window size nor run past the end
      // of the sequence
      int maxEnd=Math.min(sequenceLength,i+maxWindowSize);
      for(int j=i+1;j<=maxEnd;j++){
        // fetch once: the group may build a fresh instance on every call
        Instance segmentInstance=group.getSubsequenceInstance(i,j);
        if(segmentInstance!=null){
          int k=j-i-1;
          label[i][k]=group.getSubsequenceLabel(i,j);
          segmentSource[i][k]=segmentInstance.getSource();
          // unit-length segments are answered from unitInstance, so a delta
          // for them would never be read
          if(j-i>1){
            delta[i][k]=new Delta(factory,i,j,segmentInstance);
          }
        }
      }
    }
  }

  //
  // helpers to construct feature sets and/or compute weights
  //

  /**
   * Returns the binary features that occur in any of unitInstance[start]
   * ... unitInstance[end-1], or in otherInstance. Equivalently, the binary
   * features of the sum of
   * {unitInstance[start],...,unitInstance[end-1],otherInstance}.
   *
   * @param otherInstance an extra instance to include; may be null
   * @return a new mutable set, owned by the caller
   */
  private Set<Feature> binaryFeatureSet(int start,int end,Instance otherInstance){
    Set<Feature> s=new HashSet<Feature>();
    for(int i=start;i<end;i++){
      addAll(s,unitInstance[i].binaryFeatureIterator());
    }
    if(otherInstance!=null){
      addAll(s,otherInstance.binaryFeatureIterator());
    }
    return s;
  }

  /**
   * Analogous to binaryFeatureSet, but collects numeric features.
   *
   * @param otherInstance an extra instance to include; may be null
   * @return a new mutable set, owned by the caller
   */
  private Set<Feature> numericFeatureSet(int start,int end,Instance otherInstance){
    Set<Feature> s=new HashSet<Feature>();
    for(int i=start;i<end;i++){
      addAll(s,unitInstance[i].numericFeatureIterator());
    }
    if(otherInstance!=null){
      addAll(s,otherInstance.numericFeatureIterator());
    }
    return s;
  }

  /**
   * Analogous to binaryFeatureSet, but collects both binary and numeric
   * features.
   *
   * @param otherInstance an extra instance to include; may be null
   * @return a new mutable set, owned by the caller
   */
  private Set<Feature> featureSet(int start,int end,Instance otherInstance){
    // binaryFeatureSet already returns a fresh set, so extend it in place
    Set<Feature> s=binaryFeatureSet(start,end,otherInstance);
    s.addAll(numericFeatureSet(start,end,otherInstance));
    return s;
  }

  /** Adds every feature produced by the iterator to the set. */
  private void addAll(Set<Feature> s,Iterator<Feature> i){
    while(i.hasNext()){
      s.add(i.next());
    }
  }

  /** Returns the sum of the weights of f over unitInstance[start..end-1]. */
  private double getSumWeight(int start,int end,Feature f){
    double w=0;
    for(int i=start;i<end;i++){
      w+=unitInstance[i].getWeight(f);
    }
    return w;
  }

  /**
   * Encodes the difference between a segment instance and the sum of the
   * weights of the unit instances between start and end.
   */
  private class Delta implements Serializable{

    static final long serialVersionUID=20080207L;

    /** Features whose segment weight is non-zero and differs from the unit
     *  sum, mapped to (segment weight - summed unit weight). */
    public TObjectDoubleHashMap<Feature> deltaWeight=
        new TObjectDoubleHashMap<Feature>();

    /** Features whose weight in the segment instance is exactly zero. */
    public THashSet<Feature> zeroWeights=new THashSet<Feature>();

    /**
     * Records, for every feature of the covered unit instances or of the
     * segment instance, how the segment weight deviates from the unit sum.
     */
    public Delta(FeatureFactory factory,int start,int end,
        Instance segmentInstance){
      for(Feature raw:featureSet(start,end,segmentInstance)){
        // replace the feature with its canonical version, so that variant
        // versions are not stored in the deltaWeight, zeroWeights tables
        Feature f=factory.getFeature(raw);
        double segmentWeight=segmentInstance.getWeight(f);
        if(segmentWeight==0){
          zeroWeights.add(f);
        }else{
          double sumWeight=getSumWeight(start,end,f);
          if(segmentWeight!=sumWeight){
            deltaWeight.put(f,segmentWeight-sumWeight);
          }
        }
      }
    }
  }

  /**
   * A segment instance reconstructed on demand from the unit instances it
   * covers plus the stored Delta for that segment.
   */
  private class DeltaInstance extends AbstractInstance implements Serializable{

    static final long serialVersionUID=20080207L;

    private int start,end;

    private Delta diff;

    /** Builds the instance for segment start..end from the stored delta. */
    public DeltaInstance(int start,int end){
      this.start=start;
      this.end=end;
      this.diff=delta[start][end-start-1];
      this.source=segmentSource[start][end-start-1];
      this.subpopulationId=subPopId;
    }

    /** Builds an instance from an explicitly supplied delta; for debugging. */
    public DeltaInstance(int start,int end,Delta initDelta,Object initSource,
        String initSubPopId){
      this.start=start;
      this.end=end;
      this.diff=initDelta;
      this.source=initSource;
      this.subpopulationId=initSubPopId;
    }

    /**
     * Returns 0 for features recorded as zero-weight in the segment;
     * otherwise the summed unit weight plus the stored correction. Features
     * without a correction rely on the map's no-entry value (0.0 for a
     * default-constructed trove map).
     */
    @Override
    public double getWeight(Feature f){
      if(diff.zeroWeights.contains(f)){
        return 0;
      }
      return getSumWeight(start,end,f)+diff.deltaWeight.get(f);
    }

    /**
     * Binary features of the covered unit instances, minus zero-weight ones.
     * NOTE(review): features that are new in the segment (absent from all
     * unit instances) live only in deltaWeight, so they are reported by
     * numericFeatureIterator, not here.
     */
    @Override
    public Iterator<Feature> binaryFeatureIterator(){
      return adjust(binaryFeatureSet(start,end,null),diff.zeroWeights,null)
          .iterator();
    }

    /** Numeric features of the covered units plus all corrected features,
     *  minus zero-weight ones. */
    @Override
    public Iterator<Feature> numericFeatureIterator(){
      return adjust(numericFeatureSet(start,end,null),diff.zeroWeights,
          diff.deltaWeight).iterator();
    }

    /** All features of the covered units plus all corrected features, minus
     *  zero-weight ones. */
    @Override
    public Iterator<Feature> featureIterator(){
      return allFeatures().iterator();
    }

    /** Returns the number of distinct features yielded by featureIterator(). */
    @Override
    public int numFeatures(){
      return allFeatures().size();
    }

    /** The distinct feature set behind featureIterator() and numFeatures(). */
    private Set<Feature> allFeatures(){
      return adjust(featureSet(start,end,null),diff.zeroWeights,
          diff.deltaWeight);
    }

    /**
     * Removes every feature in exclude from set and, if include is non-null,
     * adds every key of include. Mutates and returns the given set.
     */
    private Set<Feature> adjust(final Set<Feature> set,
        THashSet<Feature> exclude,TObjectDoubleHashMap<Feature> include){
      // like set.removeAll(exclude) but faster
      exclude.forEach(new TObjectProcedure<Feature>(){
        @Override
        public boolean execute(Feature f){
          set.remove(f);
          return true; // indicates it's ok to invoke this procedure again
        }
      });
      if(include!=null){
        // like set.addAll( include.keySet() ) but faster
        include.forEachKey(new TObjectProcedure<Feature>(){
          @Override
          public boolean execute(Feature f){
            set.add(f);
            return true; // indicates it's ok to invoke this procedure again
          }
        });
      }
      return set;
    }
  }

  //
  // implement the rest of the interface...
  //

  /**
   * Returns the example for the segment start..end. Unit-length segments
   * always yield an example built from the stored unit instance; longer
   * segments yield null when the source group had no instance for them.
   */
  @Override
  public Example getSubsequenceExample(int start,int end){
    if(end-start==1){
      return new Example(unitInstance[start],label[start][0]);
    }
    if(delta[start][end-start-1]!=null){
      return new Example(new DeltaInstance(start,end),label[start][end-start-1]);
    }
    return null;
  }

  /** Return the class label associated with getSubsequenceExample(start,end).
   */
  @Override
  public ClassLabel getSubsequenceLabel(int start,int end){
    return label[start][end-start-1];
  }

  /**
   * Return the instance corresponding to the segment from positions
   * start...end. Unit-length segments return the stored (compressed) unit
   * instance itself rather than an Example wrapper, matching the bare
   * DeltaInstance returned for longer segments. Returns null for longer
   * segments that the source group had no instance for.
   */
  @Override
  public Instance getSubsequenceInstance(int start,int end){
    if(end-start==1){
      return unitInstance[start];
    }
    if(delta[start][end-start-1]!=null){
      return new DeltaInstance(start,end);
    }
    return null;
  }

  @Override
  public int getSequenceLength(){
    return sequenceLength;
  }

  @Override
  public int getMaxWindowSize(){
    return maxWindowSize;
  }

  @Override
  public String getSubpopulationId(){
    return subPopId;
  }

  @Override
  public int size(){
    return totalSize;
  }

  @Override
  public Set<String> classNameSet(){
    return classNameSet;
  }
}