/*
* Encog(tm) Java Examples v3.4
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-examples
*
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.examples.ml.bayesian.words;
import java.util.ArrayList;
import java.util.List;
import org.encog.mathutil.probability.CalcProbability;
import org.encog.ml.bayesian.BayesianEvent;
import org.encog.ml.bayesian.BayesianNetwork;
import org.encog.ml.bayesian.EventType;
import org.encog.ml.bayesian.query.enumerate.EnumerationQuery;
import org.encog.util.Format;
import org.encog.util.text.BagOfWords;
public class BayesianWordAnalyzer {
private int k;
private BagOfWords classBag;
private BagOfWords notClassBag;
private BagOfWords totalBag;
private final String className;
private final String notClassName;
private CalcProbability messageProbability;
private String lastProblem;
private final int classSampleCount;
private final int notClassSampleCount;
public BayesianWordAnalyzer(
int theK,
String theClassName,
String[] classStrings,
String theNotClassName,
String[] notClassStrings)
{
this.k = theK;
this.className = theClassName;
this.notClassName = theNotClassName;
this.classSampleCount = classStrings.length;
this.notClassSampleCount = notClassStrings.length;
this.classBag = new BagOfWords(this.k);
this.notClassBag = new BagOfWords(this.k);
this.totalBag = new BagOfWords(this.k);
for(String line: classStrings) {
this.classBag.process(line);
totalBag.process(line);
}
for(String line: notClassStrings) {
this.notClassBag.process(line);
totalBag.process(line);
}
this.classBag.setLaplaceClasses(totalBag.getUniqueWords());
this.notClassBag.setLaplaceClasses(totalBag.getUniqueWords());
this.messageProbability = new CalcProbability(this.k);
messageProbability.addClass(this.classSampleCount);
messageProbability.addClass(this.notClassSampleCount);
}
public List<String> separateSpaces(String str) {
List<String> result = new ArrayList<String>();
StringBuilder word = new StringBuilder();
for (int i = 0; i < str.length(); i++) {
char ch = str.charAt(i);
if (ch != '\'' && !Character.isLetterOrDigit(ch)) {
if (word.length() > 0) {
result.add(word.toString());
word.setLength(0);
}
} else {
word.append(ch);
}
}
if (word.length() > 0) {
result.add(word.toString());
}
return result;
}
public double probability(String m) {
List<String> words = separateSpaces(m);
BayesianNetwork network = new BayesianNetwork();
BayesianEvent spamEvent = network.createEvent(this.className);
int index = 0;
for( String word: words) {
BayesianEvent event = network.createEvent(word+index);
network.createDependency(spamEvent, event);
index++;
}
network.finalizeStructure();
//SamplingQuery query = new SamplingQuery(network);
EnumerationQuery query = new EnumerationQuery(network);
double probSpam = messageProbability.calculate(0);
spamEvent.getTable().addLine(probSpam, true);
query.defineEventType(spamEvent, EventType.Outcome);
query.setEventValue(spamEvent, true);
index = 0;
for( String word: words) {
String word2 = word+index;
BayesianEvent event = network.getEvent(word2);
event.getTable().addLine(this.classBag.probability(word), true, true); // spam
event.getTable().addLine(this.notClassBag.probability(word), true, false); // ham
query.defineEventType(event, EventType.Evidence);
query.setEventValue(event, true);
index++;
}
//query.setSampleSize(100000000);
query.execute();
this.lastProblem = query.getProblem();
//System.out.println(query.getProblem());
return query.getProbability();
}
/**
* @return the className
*/
public String getClassName() {
return className;
}
/**
* @return the notClassName
*/
public String getNotClassName() {
return notClassName;
}
public double getClassProbability() {
return this.messageProbability.calculate(0);
}
public double getNotClassProbability() {
return this.messageProbability.calculate(1);
}
public double probabilityWordClass(String word) {
StringBuilder s = new StringBuilder();
s.append("P(");
s.append(word);
s.append("|");
s.append(this.className);
s.append(")");
this.lastProblem = s.toString();
return this.classBag.probability(word);
}
public double probabilityWordNotClass(String word) {
StringBuilder s = new StringBuilder();
s.append("P(");
s.append(word);
s.append("|");
s.append(this.notClassName);
s.append(")");
this.lastProblem = s.toString();
return this.notClassBag.probability(word);
}
public String getLastProblem() {
return this.lastProblem;
}
}