/* Copyright 2003, Carnegie Mellon, All Rights Reserved */ package edu.cmu.minorthird.classify.semisupervised; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import javax.swing.ButtonGroup; import javax.swing.JComponent; import javax.swing.JLabel; import javax.swing.JRadioButton; import javax.swing.JScrollPane; import javax.swing.JTable; import org.apache.log4j.Logger; import edu.cmu.minorthird.classify.ClassLabel; import edu.cmu.minorthird.classify.Classifier; import edu.cmu.minorthird.classify.Example; import edu.cmu.minorthird.classify.Explanation; import edu.cmu.minorthird.classify.Feature; import edu.cmu.minorthird.classify.Instance; import edu.cmu.minorthird.classify.WeightedSet; import edu.cmu.minorthird.util.MathUtil; import edu.cmu.minorthird.util.gui.ComponentViewer; import edu.cmu.minorthird.util.gui.Controllable; import edu.cmu.minorthird.util.gui.ControlledViewer; import edu.cmu.minorthird.util.gui.Viewer; import edu.cmu.minorthird.util.gui.ViewerControls; import edu.cmu.minorthird.util.gui.Visible; import gnu.trove.iterator.TObjectDoubleIterator; import gnu.trove.map.hash.TObjectDoubleHashMap; /** * @author Edoardo Airoldi * Date: Mar 15, 2004 */ public class MultinomialClassifier implements SemiSupervisedClassifier, Classifier,Visible,Serializable{ static final long serialVersionUID=20080207L; static Logger log=Logger.getLogger(MultinomialClassifier.class); private double SCALE; // set by learner if needed private List<String> classNames; private List<Double> classParameters; private Map<Feature,String> featureModels; private List<WeightedSet<Feature>> featureGivenClassParameters; private double featurePrior; private String unseenModel; // constructor public MultinomialClassifier(){ this.classNames=new ArrayList<String>(); this.classParameters=new ArrayList<Double>(); this.featureModels=new HashMap<Feature,String>(); this.featureGivenClassParameters=new ArrayList<WeightedSet<Feature>>(); this.featureGivenClassParameters.add(new WeightedSet<Feature>()); this.featurePrior=0.0; this.unseenModel=null; } // // methods in Classifier interface // @Override public ClassLabel classification(Instance instance){ double[] score=score(instance); // implement smoothing for *unseen* features //System.out.println("size="+score.length); int maxIndex=0; for(int i=0;i<score.length;i++){ //System.out.println("i="+i+" score="+score[i]); if(score[i]>score[maxIndex]){ maxIndex=i; } } //System.out.println( classNames.get(0)+","+score[0]+" "+classNames.get(1)+","+score[1]); return new ClassLabel(classNames.get(maxIndex)); } public double[] score(Instance instance){ //System.out.println(instance); //System.out.println( "class="+classNames.get(0)+" counts="+featureGivenClassParameters.get(0) ); //System.out.println( "class="+classNames.get(1)+" counts="+featureGivenClassParameters.get(1) ); double total=0.0; for(Iterator<Feature> j=instance.featureIterator();j.hasNext();){ Feature f=j.next(); total+=instance.getWeight(f); } double[] score=new double[classNames.size()]; for(int i=0;i<classNames.size();i++){ score[i]=0.0; //System.out.println("instance="+instance); for(Iterator<Feature> j=instance.featureIterator();j.hasNext();){ Feature f=j.next(); double featureCounts=instance.getWeight(f); double featureProb= featureGivenClassParameters.get(i).getWeight(f); //System.out.println("feature="+f+" counts="+featureCounts+" prob="+featureProb+" class="+classProb); String model=getFeatureModel(f); //System.out.println("feature="+f+" model="+model); if(model.equals("Poisson")){ score[i]+= -featureProb*total/SCALE+featureCounts*Math.log(featureProb); }else if(model.equals("Binomial")){ score[i]+=featureCounts*Math.log(featureProb); }else if(model.equals("unseen")){ score[i]+=0.0; }else{ System.out.println("error: model "+model+" not found!"); System.exit(1); } } double classProb=(classParameters.get(i)).doubleValue(); score[i]+=Math.log(classProb); } return score; } @Override public String explain(Instance instance){ StringBuffer buf=new StringBuffer(""); for(Iterator<Feature> j=instance.featureIterator();j.hasNext();){ // Feature f=j.next(); if(buf.length()>0) buf.append("\n + "); else buf.append(" "); //buf.append( f+"<"+instance.getWeight(f)+"*"+featureScore(f)+">"); } //buf.append( "\n + bias<"+featureScore( BIAS_TERM )+">" ); buf.append("\n = "+score(instance)); return buf.toString(); } @Override public Explanation getExplanation(Instance instance){ Explanation.Node top= new Explanation.Node("MultinomialClassifier Explanation"); Explanation.Node features=new Explanation.Node("Features"); for(Iterator<Feature> j=instance.featureIterator();j.hasNext();){ Feature f=j.next(); Explanation.Node featureEx= new Explanation.Node(f+"<"+instance.getWeight(f)); features.add(featureEx); } Explanation.Node bias=new Explanation.Node("bias"); features.add(bias); top.add(features); Explanation.Node score=new Explanation.Node("\n = "+score(instance)); top.add(score); Explanation ex=new Explanation(top); return ex; } // // Get, Set, Check // public void setScale(double value){ this.SCALE=value; } public double getPrior(){ return featurePrior; } public void setPrior(double pi){ this.featurePrior=pi; } public String getUnseenModel(){ return unseenModel; } public void setUnseenModel(String str){ this.unseenModel=str; } public double getLogLikelihood(Example example){ //System.out.println( example ); int idx=-1; for(int i=0;i<classNames.size();i++){ if(classNames.get(i).equals(example.getLabel().bestClassName())){ idx=i; break; } } //System.out.println( "class="+classNames.get(idx) ); Instance instance=example.asInstance(); double loglik=0.0; //System.out.println("instance="+instance); for(Iterator<Feature> j=instance.featureIterator();j.hasNext();){ Feature f=j.next(); double featureCounts=instance.getWeight(f); double featureProb=featureGivenClassParameters.get(idx).getWeight(f); // double classProb=((Double)classParameters.get(idx)).doubleValue(); //System.out.println("feature="+f+" counts="+featureCounts+" prob="+featureProb+" class="+classProb); String model=getFeatureModel(f); if(model.equals("Poisson")){ loglik+=-featureProb+featureCounts*Math.log(featureProb); }else if(model.equals("Binomial")){ loglik+=featureCounts*Math.log(featureProb); }else if(model.equals("unseen")){ System.out.println("unseen: "+f); }else{ System.out.println("error: model "+model+" not found!"); System.exit(1); } } return loglik; } public void reset(){ this.classParameters=new ArrayList<Double>(); this.featureGivenClassParameters=new ArrayList<WeightedSet<Feature>>(); //this.featureGivenClassParameters.add( new WeightedSet() ); } public boolean isPresent(ClassLabel label){ boolean isPresent=false; for(int i=0;i<classNames.size();i++){ if(classNames.get(i).equals(label.bestClassName())){ isPresent=true; } } return isPresent; } public void addValidLabel(ClassLabel label){ classNames.add(label.bestClassName()); } public ClassLabel getLabel(int i){ return new ClassLabel(classNames.get(i)); } public int indexOf(ClassLabel label){ return classNames.indexOf(label.bestClassName()); } public void setFeatureGivenClassParameter(Feature f,int j, double probabilityOfOccurrence){ WeightedSet<Feature> wset; try{ wset=featureGivenClassParameters.get(j); wset.add(f,probabilityOfOccurrence); featureGivenClassParameters.set(j,wset); }catch(Exception t){ wset=null; wset=new WeightedSet<Feature>(); wset.add(f,probabilityOfOccurrence); featureGivenClassParameters.add(j,wset); } } public void setClassParameter(int j,double probabilityOfOccurrence){ classParameters.add(j,new Double(probabilityOfOccurrence)); } public void setFeatureModel(Feature feature,String model){ featureModels.put(feature,model); } public String getFeatureModel(Feature feature){ try{ String model=featureModels.get(feature).toString(); return model; }catch(NullPointerException x){ return "unseen"; } } public Iterator<Feature> featureIterator(){ // 1. create a new WeightedSet with all features TObjectDoubleHashMap map=new TObjectDoubleHashMap(); for(int i=0;i<classNames.size();i++){ WeightedSet<Feature> wset=featureGivenClassParameters.get(i); for(Iterator<Feature> j=wset.iterator();j.hasNext();){ Feature f=j.next(); double w=wset.getWeight(f); map.put(f,w); } } // 2. create global feature iterator final TObjectDoubleIterator ti=map.iterator(); Iterator<Feature> i=new Iterator<Feature>(){ @Override public boolean hasNext(){ return ti.hasNext(); } @Override public Feature next(){ ti.advance(); return (Feature)ti.key(); } @Override public void remove(){ ti.remove(); } }; return i; } public Object[] keys(){ TObjectDoubleHashMap map=new TObjectDoubleHashMap(); for(int i=0;i<classNames.size();i++){ WeightedSet<Feature> wset=featureGivenClassParameters.get(i); for(Iterator<Feature> j=wset.iterator();j.hasNext();){ Feature f=j.next(); double w=wset.getWeight(f); map.put(f,w); } } return map.keys(); } // // GUI related stuff // @Override public Viewer toGUI(){ Viewer gui= new ControlledViewer(new MyViewer(),new MultinomialClassifierControls()); gui.setContent(this); return gui; } static private class MultinomialClassifierControls extends ViewerControls{ static final long serialVersionUID=20080207L; // how to sort // private JRadioButton absoluteValueButton; private JRadioButton valueButton; private JRadioButton nameButton; // private JRadioButton noneButton; @Override public void initialize(){ add(new JLabel("Sort by")); ButtonGroup group=new ButtonGroup(); ; nameButton=addButton("name",group,true); valueButton=addButton("weight",group,false); // absoluteValueButton=addButton("|weight|",group,false); } private JRadioButton addButton(String s,ButtonGroup group,boolean selected){ JRadioButton button=new JRadioButton(s,selected); group.add(button); add(button); button.addActionListener(this); return button; } } static private class MyViewer extends ComponentViewer implements Controllable{ static final long serialVersionUID=20080207L; private MultinomialClassifierControls controls=null; private MultinomialClassifier h=null; @Override public void applyControls(ViewerControls controls){ this.controls=(MultinomialClassifierControls)controls; setContent(h,true); revalidate(); } @Override public boolean canReceive(Object o){ return o instanceof MultinomialClassifier; } @Override public JComponent componentFor(Object o){ h=(MultinomialClassifier)o; Object[] keys=h.keys(); Object[][] tableData=new Object[keys.length][(h.classNames.size()+1)]; int k=0; for(Iterator<Feature> i=h.featureIterator();i.hasNext();){ Feature f=i.next(); tableData[k][0]=f; for(int l=0;l<h.classNames.size();l++){ tableData[k][(l+1)]= new Double(h.featureGivenClassParameters.get(l).getWeight(f)); ; } k++; } if(controls!=null){ Arrays.sort(tableData,new Comparator<Object[]>(){ @Override public int compare(Object[] ra,Object[] rb){ if(controls.nameButton.isSelected()) return ra[0].toString().compareTo(rb[0].toString()); Double da=(Double)ra[1]; Double db=(Double)rb[1]; if(controls.valueButton.isSelected()) return MathUtil.sign(db.doubleValue()-da.doubleValue()); else return MathUtil.sign(Math.abs(db.doubleValue())- Math.abs(da.doubleValue())); } }); } String[] columnNames=new String[(h.classNames.size()+1)]; columnNames[0]="Feature Name"; for(int i=0;i<h.classNames.size();i++){ columnNames[(i+1)]="Wgt "+h.classNames.get(i); } JTable table=new JTable(tableData,columnNames); monitorSelections(table,0); return new JScrollPane(table); } } @Override public String toString(){ return null; } }