package edu.cmu.minorthird.classify.semisupervised;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import edu.cmu.minorthird.classify.BasicDataset;
import edu.cmu.minorthird.classify.BasicFeatureIndex;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.ClassifierLearner;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.SampleDatasets;
/**
* Implementation of the methods described in:
* K. Nigam, A. McCallum, S. Thrun and T. Mitchell. Text Classification from
* labeled and unlabeled documents using EM. W. Cohen editor, Machine Learning,
* 1999.
*
* @author Edoardo Airoldi
* Date: Mar 13, 2004
*/
public class SemiSupervisedNaiveBayesLearner extends
    SemiSupervisedBatchClassifierLearner{

  /** Maximum number of EM iterations before giving up on convergence. */
  private int maxIterations=1000;

  /**
   * One-shot pool of unlabeled instances supplied via setInstancePool().
   * It is drained into a list at the start of batchTrain, because a
   * java.util.Iterator cannot be traversed more than once.
   */
  private Iterator<Instance> iteratorOverUnlabeled;

  /** Create a learner with the default iteration cap (1000). */
  public SemiSupervisedNaiveBayesLearner(){
  }

  /**
   * Create a learner with an explicit iteration cap.
   *
   * @param iterations maximum number of EM iterations
   */
  public SemiSupervisedNaiveBayesLearner(int iterations){
    this.maxIterations=iterations;
  }

  /** The schema is ignored: valid labels are collected from the training data. */
  @Override
  public void setSchema(ExampleSchema schema){
    // intentionally a no-op
  }

  /** Supply the pool of unlabeled instances used by the EM loop. */
  @Override
  public void setInstancePool(Iterator<Instance> i){
    this.iteratorOverUnlabeled=i;
  }

  /** Not supported by this learner — always returns null. */
  @Override
  public ExampleSchema getSchema(){
    return null;
  }

  /** Return a reset clone of this learner, or null if cloning fails. */
  @Override
  public ClassifierLearner copy(){
    ClassifierLearner learner=null;
    try{
      learner=(ClassifierLearner)this.clone();
      learner.reset();
    }catch(Exception e){
      e.printStackTrace();
    }
    return learner;
  }

  /**
   * Train a multinomial naive-Bayes classifier with EM over the labeled
   * dataset plus the unlabeled instance pool (Nigam et al.).
   *
   * @param dataset the labeled training examples
   * @return the trained classifier
   */
  @Override
  public Classifier batchTrain(SemiSupervisedDataset dataset){
    MultinomialClassifier mc=new MultinomialClassifier();

    // 1. collect the valid class labels from the labeled data
    int numberOfClasses=0;
    for(Iterator<Example> i=dataset.iterator();i.hasNext();){
      Example ex=i.next();
      if(!mc.isPresent(ex.getLabel())){
        mc.addValidLabel(ex.getLabel());
        numberOfClasses++;
      }
    }

    // Cache the unlabeled pool: the iterator can be consumed only once,
    // but every EM iteration must re-label the same instances.  (Previously
    // the exhausted iterator was re-read in the E-step, so labels assigned
    // before the loop were never revised.)
    List<Instance> unlabeledInstances=new ArrayList<Instance>();
    while(iteratorOverUnlabeled!=null&&iteratorOverUnlabeled.hasNext()){
      unlabeledInstances.add(iteratorOverUnlabeled.next());
    }

    // 2. initial parameter estimates from the labeled examples only
    estimateParameters(mc,dataset,numberOfClasses);

    // 3. tentatively label the unlabeled pool with the initial model
    Dataset unlabeledDataset=labelInstances(mc,unlabeledInstances);

    // 4. initialize log-likelihood
    double logLik=Double.NEGATIVE_INFINITY;
    double previousLogLik;

    // 5. EM loop until convergence or the iteration cap is reached
    int iter=0;
    boolean hasConverged=false;
    while(iter<maxIterations&&!hasConverged){
      previousLogLik=logLik;

      // 5.1. pool the labeled and the tentatively-labeled examples
      Dataset combinedDataset=new BasicDataset();
      for(Iterator<Example> i=dataset.iterator();i.hasNext();){
        combinedDataset.add(i.next());
      }
      for(Iterator<Example> i=unlabeledDataset.iterator();i.hasNext();){
        combinedDataset.add(i.next());
      }

      // 5.2. M-step: re-estimate parameters from ALL examples.  (Previously
      // the per-class counts iterated only the labeled dataset while the
      // feature index covered the combined one — inconsistent estimates.)
      mc.reset();
      estimateParameters(mc,combinedDataset,numberOfClasses);

      // 5.3. E-step: re-assign labels with the current parameters,
      // replacing (not accumulating onto) the previous assignment
      unlabeledDataset=labelInstances(mc,unlabeledInstances);

      // 5.4. complete-data log-likelihood
      logLik=0.0;
      for(Iterator<Example> i=combinedDataset.iterator();i.hasNext();){
        logLik+=mc.getLogLikelihood(i.next());
      }

      // 5.5. check convergence and iterate
      if(EMconverged(logLik,previousLogLik,1e-7,true)){
        hasConverged=true;
        System.out.println("EM converged!");
      }else{
        System.out.println("iteration="+(iter+1)+" log-likelihood="+logLik);
      }
      iter++;
    }

    // 6. return the trained classifier
    return mc;
  }

  //
  // private methods
  //

  /**
   * Estimate class priors and per-(feature,class) multinomial parameters
   * from the given data, using Laplace-smoothed maximum-likelihood
   * estimates, and store them in the classifier.  Shared by the initial
   * labeled-only estimation (step 2) and the M-step (step 5.2).
   *
   * @param mc the classifier whose parameters are set
   * @param data the examples (labeled, or labeled+estimated) to count over
   * @param numberOfClasses number of valid class labels registered in mc
   */
  private void estimateParameters(MultinomialClassifier mc,Dataset data,
      int numberOfClasses){
    BasicFeatureIndex index=new BasicFeatureIndex(data);
    double[] countsGivenClass=new double[numberOfClasses];
    double[] examplesGivenClass=new double[numberOfClasses];
    double numberOfExamples=data.size();
    double numberOfFeatures=index.numberOfFeatures();

    // total counts: examples per class and total feature mass per class
    for(Iterator<Example> i=data.iterator();i.hasNext();){
      Example ex=i.next();
      int classIndex=mc.indexOf(ex.getLabel());
      examplesGivenClass[classIndex]+=1.0;
      for(Iterator<Feature> j=index.featureIterator();j.hasNext();){
        countsGivenClass[classIndex]+=ex.getWeight(j.next());
      }
    }

    // class priors
    for(int j=0;j<numberOfClasses;j++){
      double probabilityOfOccurrence=
          estimateClassProbMLE(1.0,numberOfClasses,
              examplesGivenClass[j],numberOfExamples);
      mc.setClassParameter(j,probabilityOfOccurrence);
    }

    // per-(feature,class) conditional probabilities
    for(Iterator<Feature> i=index.featureIterator();i.hasNext();){
      Feature f=i.next();
      double[] countsFeatureGivenClass=new double[numberOfClasses];
      for(int j=0;j<index.size(f);j++){
        Example ex=index.getExample(f,j);
        countsFeatureGivenClass[mc.indexOf(ex.getLabel())]+=ex.getWeight(f);
      }
      for(int j=0;j<numberOfClasses;j++){
        double probabilityOfOccurrence=
            estimateFeatureProbMLE(1.0,numberOfFeatures,
                countsFeatureGivenClass[j],countsGivenClass[j]);
        mc.setFeatureGivenClassParameter(f,j,probabilityOfOccurrence);
      }
      mc.setFeatureModel(f,"Binomial");
    }
  }

  /**
   * Label every cached instance with the classifier's current prediction.
   *
   * @return a fresh dataset of (instance, estimated label) examples
   */
  private Dataset labelInstances(Classifier mc,List<Instance> instances){
    Dataset labeled=new BasicDataset();
    for(Instance instance:instances){
      ClassLabel estimatedClassLabel=mc.classification(instance);
      labeled.add(new Example(instance,estimatedClassLabel));
    }
    return labeled;
  }

  /**
   * Laplace-smoothed MLE of a class prior:
   * (prior + observed) / (numberOfClasses + total).
   */
  private double estimateClassProbMLE(double classPrior,double numberOfClasses,
      double observedCounts,double totalCounts){
    return (classPrior+observedCounts)/(numberOfClasses+totalCounts);
  }

  /**
   * Laplace-smoothed MLE of a feature-given-class probability:
   * (prior + observed) / (numberOfFeatures + total).
   */
  private double estimateFeatureProbMLE(double featurePrior,
      double numberOfFeatures,double observedCounts,double totalCounts){
    return (featurePrior+observedCounts)/(numberOfFeatures+totalCounts);
  }

  /*
   * We have converged if the slope of the log-likelihood function falls below
   * 'threshold', i.e., |f(t) - f(t-1)| / avg < threshold, where
   * avg = (|f(t)| + |f(t-1)|)/2 and f(t) is the log-likelihood at iteration t.
   *
   * This stopping criterion is from Numerical Recipes in C p423.
   *
   * Note: if we are doing MAP estimation (using priors), the likelihood can
   * decrease even though the mode of the posterior is increasing.
   */
  private boolean EMconverged(double loglik,double previousLoglik,
      double threshold,boolean checkIncreased){
    double epsilon=2.2204e-16; // guards against division by zero when both are 0
    boolean converged=false;
    if(checkIncreased){
      if(loglik-previousLoglik<-1e-3) // allow for a little imprecision
      {
        System.out.println("******likelihood decreased from "+previousLoglik+
            " to "+loglik);
      }
    }
    double deltaLoglik=Math.abs(loglik-previousLoglik);
    double avgLoglik=(Math.abs(loglik)+Math.abs(previousLoglik)+epsilon)/2;
    if((deltaLoglik/avgLoglik)<threshold){
      converged=true;
    }
    return converged;
  }

  //
  // Test SemiSupervisedNaiveBayesLearner
  //

  /** Smoke test: load and print the bundled sample dataset. */
  static public void main(String[] args){
    Dataset dataset=new BasicDataset();
    dataset=SampleDatasets.sampleData("bayesUnlabeled",false);
    System.out.println("SampleDatasets (bayesUnlabeled):\n"+dataset);
  }
}