package edu.cmu.minorthird.text.model; import java.io.BufferedOutputStream; import java.io.File; import java.io.FileNotFoundException; import java.io.FileOutputStream; import java.io.FileReader; import java.io.IOException; import java.io.LineNumberReader; import java.io.PrintStream; import java.util.HashMap; import java.util.Iterator; import java.util.Map; import edu.cmu.minorthird.text.BasicTextBase; import edu.cmu.minorthird.text.FancyLoader; import edu.cmu.minorthird.text.Span; import edu.cmu.minorthird.text.TextBase; /** Unigram Language Model * * @author William Cohen */ public class UnigramModel{ // avoid re-creating a zillion copies of Double(1), etc final private static Double[] CACHED_DOUBLES=new Double[10]; static{ for(int i=0;i<CACHED_DOUBLES.length;i++) CACHED_DOUBLES[i]=new Double(i); } // // private data // private Map<String,Double> freq=new HashMap<String,Double>(); private double total=0; /** Load a file where each line contains a <count,word> pair. */ public void load(File file) throws IOException,FileNotFoundException{ LineNumberReader in=new LineNumberReader(new FileReader(file)); String line; while((line=in.readLine())!=null){ String[] words=line.trim().split("\\s+"); if(words.length!=2) badLine(line,in); int n=0; try{ n=Integer.parseInt(words[0]); }catch(NumberFormatException e){ badLine(line,in); } total+=n; freq.put(words[1],getDouble(n)); } in.close(); } private void badLine(String line,LineNumberReader in){ throw new IllegalStateException("bad input at line "+in.getLineNumber()+ ": "+line); } /** Save a unigram model */ public void save(File file) throws IOException{ PrintStream out= new PrintStream(new BufferedOutputStream(new FileOutputStream(file))); for(Iterator<Map.Entry<String,Double>> i=freq.entrySet().iterator();i .hasNext();){ Map.Entry<String,Double> e=i.next(); out.println(e.getValue().intValue()+" "+e.getKey()); } out.close(); } // routine to use cached doubles, rather than cons up new doubles private Double getDouble(int n){ if(n<CACHED_DOUBLES.length) return CACHED_DOUBLES[n]; else return new Double(n); } /** Return log Prob(span|model). * Assuming indendence, this is sum log Prob(tok_i|model). */ public double score(Span span){ double sum=0; double prior=0.1/total; // lower than any word we've seen for(int i=0;i<span.size();i++){ int f=getFrequency(span.getToken(i).getValue().toLowerCase()); sum+=estimatedLogProb(f,total,prior,1.0); } return sum; } public double getTotalWordCount(){ return total; } public int getFrequency(String s){ String s1=s.toLowerCase(); Double f=freq.get(s1); if(f==null) return 0; else return f.intValue(); } public void incrementFrequency(String s){ String s1=s.toLowerCase(); freq.put(s1,getDouble(getFrequency(s1)+1)); } private double estimatedLogProb(double k,double n,double prior, double pseudoCounts){ return Math.log((k+prior*pseudoCounts)/(n+pseudoCounts)); } static public void main(String[] args) throws IOException{ if(args.length==0){ System.out.println("usage 1: modelfile span1 span2..."); System.out.println("usage 2: textbase modelfile"); } if(args.length==2){ UnigramModel model=new UnigramModel(); TextBase base=FancyLoader.loadTextLabels(args[0]).getTextBase(); for(Iterator<Span> i=base.documentSpanIterator();i.hasNext();){ Span s=i.next(); for(int j=0;j<s.size();j++){ model.incrementFrequency(s.getToken(j).getValue()); } } model.save(new File(args[1])); }else{ UnigramModel model=new UnigramModel(); model.load(new File(args[0])); BasicTextBase base=new BasicTextBase(); for(int i=1;i<args.length;i++){ base.loadDocument("argv."+i,args[i]); } for(Iterator<Span> j=base.documentSpanIterator();j.hasNext();){ Span s=j.next(); System.out.println(s.asString()+" => "+model.score(s)); for(int k=0;k<s.size();k++){ String w=s.getToken(k).getValue(); System.out.print(" "+w+":"+model.getFrequency(w)); } System.out.println(); } } } }