/* Copyright 2006, Carnegie Mellon, All Rights Reserved */
package edu.cmu.minorthird.classify.relational;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.apache.log4j.Logger;
import edu.cmu.minorthird.classify.BatchClassifierLearner;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.DatasetClassifierTeacher;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.SGMExample;
import edu.cmu.minorthird.classify.Splitter;
import edu.cmu.minorthird.classify.algorithms.linear.MaxEntLearner;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.classify.transform.AugmentedInstance;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ParallelViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.TransformedViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
/**
 * Stacked Graphical Learning based on a BatchClassifier learner.
 * <p>
 * {@link #batchTrain} trains {@code stackingDepth+1} base classifiers. Level 0
 * is trained on the dataset as given. Each further level is trained on a dataset
 * produced by {@link #stackDataset}, in which every example that has outgoing
 * links is augmented with features aggregating the cross-validated predictions
 * made for its linked neighbors.
 *
 * @author Zhenzhen Kou
 */
public class StackedGraphicalLearner extends StackedBatchClassifierLearner{

  private static final Logger log=Logger.getLogger(StackedGraphicalLearner.class);

  /**
   * Schema used to enumerate class labels when building stacked features.
   * Overwritten by {@link #stackDataset} with the schema of the dataset being
   * stacked.
   */
  private ExampleSchema schema;

  /** Learner used to train the classifier at every stacking level. */
  private BatchClassifierLearner baseLearner;

  private StackingParams params;

  /** Bundle of parameters for the StackedGraphicalLearner. */
  public static class StackingParams{

    public int stackingDepth=1;

    public boolean useLogistic=true,useTargetPrediction=true,
        useConfidence=true;

    public Splitter<Example> splitter=new CrossValSplitter<Example>(5);

    // Mirrors the fold count of 'splitter' only when set via setCrossValSplits;
    // assigning 'splitter' directly does not update it.
    int crossValSplits=5;

    /**
     * If true, adjust all class confidences by passing them thru a logistic
     * function. NOTE(review): StackedGraphicalLearner itself never reads this
     * flag.
     */
    public boolean getUseLogisticOnConfidences(){
      return useLogistic;
    }

    public void setUseLogisticOnConfidences(boolean flag){
      useLogistic=flag;
    }

    /**
     * If true, use confidence in class predictions as weight for that feature.
     * NOTE(review): StackedGraphicalLearner itself never reads this flag.
     */
    public boolean getUseConfidences(){
      return useConfidence;
    }

    public void setUseConfidences(boolean flag){
      useConfidence=flag;
    }

    /**
     * Flag presumably controlling whether the prediction for the target
     * example itself is used as a stacked feature. NOTE(review):
     * StackedGraphicalLearner itself never reads this flag; confirm the
     * semantics with its consumers.
     */
    public boolean getUseTargetPrediction(){
      return useTargetPrediction;
    }

    public void setUseTargetPrediction(boolean flag){
      useTargetPrediction=flag;
    }

    /** Number of stacking levels above the base classifier (0 means no stacking). */
    public int getStackingDepth(){
      return stackingDepth;
    }

    /**
     * Sets the number of stacking levels above the base classifier.
     *
     * @param newStackingDepth number of stacking levels; must be non-negative
     * @throws IllegalArgumentException if {@code newStackingDepth} is negative
     */
    public void setStackingDepth(int newStackingDepth){
      // A negative depth would make batchTrain allocate an empty or
      // negative-sized classifier array; fail fast instead.
      if(newStackingDepth<0){
        throw new IllegalArgumentException(
            "stacking depth must be >= 0, got "+newStackingDepth);
      }
      this.stackingDepth=newStackingDepth;
    }

    /** Number of cross-validation splits to use in making predictions. */
    public int getCrossValSplits(){
      return crossValSplits;
    }

    /**
     * Replaces {@link #splitter} with a new {@link CrossValSplitter} using the
     * given number of folds.
     */
    public void setCrossValSplits(int newCrossValSplits){
      this.splitter=new CrossValSplitter<Example>(newCrossValSplits);
      crossValSplits=newCrossValSplits;
    }
  }

  /**
   * Returns the live parameter object; changes to it affect later training.
   */
  public StackingParams getParams(){
    return params;
  }

  /** Creates a learner using {@link MaxEntLearner} with stacking depth 1. */
  public StackedGraphicalLearner(){
    this(new MaxEntLearner(),1);
  }

  /** Creates a learner using the given base learner with stacking depth 1. */
  public StackedGraphicalLearner(BatchClassifierLearner baseLearner){
    this(baseLearner,1);
  }

  /**
   * Creates a learner using the given base learner and stacking depth.
   *
   * @throws IllegalArgumentException if {@code depth} is negative
   */
  public StackedGraphicalLearner(BatchClassifierLearner baseLearner,int depth){
    this.baseLearner=baseLearner;
    this.params=new StackingParams();
    params.setStackingDepth(depth);
  }

  /**
   * Creates a learner using {@link MaxEntLearner} with the given stacking depth.
   *
   * @throws IllegalArgumentException if {@code depth} is negative
   */
  public StackedGraphicalLearner(int depth){
    this(new MaxEntLearner(),depth);
  }

  @Override
  public final void setSchema(ExampleSchema schema){
    this.schema=schema;
  }

  @Override
  public final ExampleSchema getSchema(){
    return schema;
  }

  /**
   * Trains one base classifier per stacking level.
   *
   * @param dataset training data; the level-0 classifier is trained on it directly
   * @return a {@link StackedGraphicalClassifier} holding every level's classifier
   */
  @Override
  public Classifier batchTrain(RealRelationalDataset dataset){
    int levels=params.stackingDepth+1;
    Classifier[] m=new Classifier[levels];
    RealRelationalDataset stackedDataset=dataset;
    ProgressCounter pc=
        new ProgressCounter("training stacked learner","stacking level",levels);
    for(int d=0;d<levels;d++){
      m[d]=new DatasetClassifierTeacher(stackedDataset).train(baseLearner);
      // Only build the next augmented dataset if another level will train on it.
      if(d<levels-1){
        stackedDataset=stackDataset(stackedDataset);
      }
      pc.progress();
    }
    pc.finished();
    return new StackedGraphicalClassifier(m,params,dataset);
  }

  /**
   * Create a new dataset in which each instance has been augmented
   * with the features constructed from the *predicted* labels
   * of neighbor examples, where the prediction is made using
   * cross-validation.
   * <p>
   * Side effect: sets this learner's schema to {@code dataset.getSchema()}.
   */
  public RealRelationalDataset stackDataset(RealRelationalDataset dataset){
    RealRelationalDataset result=new RealRelationalDataset();
    RealRelationalDataset.Split s=dataset.split(params.splitter);
    schema=dataset.getSchema();
    ProgressCounter pc=
        new ProgressCounter("stack-labeling","fold",s.getNumPartitions());
    // Out-of-fold prediction for every example, keyed by example ID.
    Map<String,ClassLabel> predictions=new HashMap<String,ClassLabel>();
    for(int k=0;k<s.getNumPartitions();k++){
      RealRelationalDataset trainData=(RealRelationalDataset)s.getTrain(k);
      RealRelationalDataset testData=(RealRelationalDataset)s.getTest(k);
      log.info("splitting with "+params.splitter+", preparing to train on "+
          trainData.size()+" and test on "+testData.size());
      Classifier c=new DatasetClassifierTeacher(trainData).train(baseLearner);
      for(Iterator<Example> i=testData.iterator();i.hasNext();){
        SGMExample ex=(SGMExample)i.next();
        predictions.put(ex.getExampleID(),c.classification(ex));
      }
      log.info("splitting with "+params.splitter+", stored classified dataset");
      pc.progress();
    }
    Map<String,Map<String,Set<String>>> linksMap=
        CoreRelationalDataset.getLinksMap();
    Map<String,Set<String>> aggregators=RealRelationalDataset.getAggregators();
    for(Iterator<Example> i=dataset.iterator();i.hasNext();){
      SGMExample ex=(SGMExample)i.next();
      // NOTE(review): this path uses add(...) while stackTestDataset uses
      // addSGM(...); confirm both register the example identically.
      result.add(augmentExample(ex,linksMap,aggregators,predictions));
    }
    pc.finished();
    return result;
  }

  /**
   * Builds the stacked version of one example.
   * <p>
   * For every link type in {@code aggregators} that the example has links of,
   * and every aggregation operation registered for that type, adds one
   * feature per schema class:
   * <ul>
   * <li>{@code "COUNT"}: the number of linked neighbors predicted as that class;</li>
   * <li>{@code "EXISTS"}: 1 if at least one neighbor was predicted as that class;</li>
   * <li>any other operation: 0.</li>
   * </ul>
   * Neighbors without an entry in {@code predictions} are ignored.
   *
   * @return {@code ex} itself if it has no entry in {@code linksMap}, otherwise
   *         a new example with the same label and ID and the augmented instance
   */
  private SGMExample augmentExample(SGMExample ex,
      Map<String,Map<String,Set<String>>> linksMap,
      Map<String,Set<String>> aggregators,Map<String,ClassLabel> predictions){
    String exampleId=ex.getExampleID();
    if(!linksMap.containsKey(exampleId)){
      return ex;
    }
    // Outer key: link type; value: IDs of neighbors linked via that type.
    Map<String,Set<String>> neighborsByType=linksMap.get(exampleId);
    int numClasses=schema.getNumberOfClasses();

    // Upper bound on the number of new features; trimmed to 'index' at the end.
    int capacity=0;
    for(Set<String> operations:aggregators.values()){
      capacity+=operations.size()*numClasses;
    }
    String[] features=new String[capacity];
    double[] values=new double[capacity];
    int index=0;

    for(Map.Entry<String,Set<String>> entry:aggregators.entrySet()){
      String type=entry.getKey();
      if(!neighborsByType.containsKey(type)){
        continue;
      }
      Set<String> operations=entry.getValue();
      if(operations.isEmpty()){
        continue;
      }
      // Neighbor class counts do not depend on the operation: compute once.
      int[] counts=countNeighborPredictions(neighborsByType.get(type),
          predictions,numClasses);
      for(String agr:operations){
        for(int c=0;c<numClasses;c++){
          features[index]=stackFeatureName(type,agr,schema.getClassName(c));
          if("COUNT".equals(agr)){
            values[index]=counts[c];
          }else if("EXISTS".equals(agr)&&counts[c]>0){
            values[index]=1;
          }
          index++;
        }
      }
    }
    Instance stackedInstance=new AugmentedInstance(ex.asInstance(),
        Arrays.copyOf(features,index),Arrays.copyOf(values,index));
    return new SGMExample(stackedInstance,ex.getLabel(),ex.getExampleID());
  }

  /** Counts, per schema class index, how many neighbors were predicted as that class. */
  private int[] countNeighborPredictions(Set<String> neighborIds,
      Map<String,ClassLabel> predictions,int numClasses){
    int[] counts=new int[numClasses];
    for(String neighborId:neighborIds){
      ClassLabel predicted=predictions.get(neighborId);
      if(predicted!=null){
        counts[schema.getClassIndex(predicted.bestClassName())]++;
      }
    }
    return counts;
  }

  /**
   * Name of a stacked feature, of the form
   * {@code pred.<linkType>.<aggregator>.<className>}.
   * The parameter order matches the call site; the emitted names are unchanged
   * from earlier versions, so previously trained models remain compatible.
   */
  private static String stackFeatureName(String type,String agr,
      String predictedClassName){
    return "pred."+type+"."+agr+"."+predictedClassName;
  }

  /**
   * Classifier produced by {@link StackedGraphicalLearner#batchTrain}. Holds
   * one classifier per stacking level. It is an inner (non-static) class
   * because the stacking step reuses the enclosing learner's feature
   * construction and schema.
   */
  public class StackedGraphicalClassifier implements Classifier,Visible{

    /** Per-level classifiers; m[0] is the base (non-stacked) classifier. */
    private Classifier[] m;

    // Retained for reference only. The number of levels is taken from
    // m.length, because this object is shared with (and mutable through) the
    // learner, and later edits must not desynchronize it from m.
    private StackingParams params;

    public StackedGraphicalClassifier(Classifier[] m,StackingParams params,
        RealRelationalDataset ds){
      this.m=m;
      this.params=params;
    }

    /**
     * Classifies a single instance using only the level-0 classifier, since
     * no neighbor predictions are available for an isolated instance.
     */
    @Override
    public ClassLabel classification(Instance instance){
      return m[0].classification(instance);
    }

    /**
     * Collectively classifies a relational dataset. At each level, every
     * example is classified. The resulting predictions are then used to stack
     * the dataset for the next level.
     *
     * @return predictions from the last level, keyed by example ID
     */
    public Map<String,ClassLabel> classification(RealRelationalDataset dataset){
      Map<String,ClassLabel> rlt=new HashMap<String,ClassLabel>();
      RealRelationalDataset testData=dataset;
      for(int d=0;d<m.length;d++){
        for(Iterator<Example> i=testData.iterator();i.hasNext();){
          SGMExample ex=(SGMExample)i.next();
          rlt.put(ex.getExampleID(),m[d].classification(ex));
        }
        if(d<m.length-1){
          testData=stackTestDataset(testData,rlt);
        }
      }
      return rlt;
    }

    /**
     * Augments every example in {@code dataset} with features aggregated from
     * the given neighbor predictions.
     */
    public RealRelationalDataset stackTestDataset(
        RealRelationalDataset dataset,Map<String,ClassLabel> predictions){
      RealRelationalDataset result=new RealRelationalDataset();
      Map<String,Map<String,Set<String>>> linksMap=
          CoreRelationalDataset.getLinksMap();
      Map<String,Set<String>> aggregators=
          RealRelationalDataset.getAggregators();
      for(Iterator<Example> i=dataset.iterator();i.hasNext();){
        SGMExample ex=(SGMExample)i.next();
        result.addSGM(augmentExample(ex,linksMap,aggregators,predictions));
      }
      return result;
    }

    /** Weight the level-0 classification assigns to the named class. */
    public double score(Instance instance,String classLabelName){
      return classification(instance).getWeight(classLabelName);
    }

    @Override
    public String explain(Instance instance){
      return "sorry, not implemented yet";
    }

    @Override
    public Explanation getExplanation(Instance instance){
      return new Explanation(explain(instance));
    }

    /** Shows one sub-view per stacking level. */
    @Override
    public Viewer toGUI(){
      ParallelViewer v=new ParallelViewer();
      for(int i=0;i<m.length;i++){
        final int k=i;
        v.addSubView("Level "+k+" classifier",new TransformedViewer(
            new SmartVanillaViewer(m[k])){
          static final long serialVersionUID=20080202L;
          @Override
          public Object transform(Object o){
            StackedGraphicalClassifier s=(StackedGraphicalClassifier)o;
            return s.m[k];
          }
        });
      }
      v.setContent(this);
      return v;
    }
  }
}