/**
 * Copyright 2013-2015 Pierre Merienne
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.github.pmerienne.trident.ml.nlp;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Text classifier and feature extractor using Kullback-Leibler Distance.
 * <p>
 * See Brigitte Bigi, "Using Kullback-Leibler Distance for Text Categorization".
 * <p>
 * One {@link Vocabulary} is kept per class. A document is assigned to the
 * class whose vocabulary has the smallest (optionally normalized) distance to
 * the document's own vocabulary. Words that are absent from a vocabulary get
 * the back-off probability epsilon; words that are present have their
 * frequency scaled by a per-class coefficient (gamma) or a per-document
 * coefficient (beta).
 * <p>
 * Epsilon and the gammas are cached lazily and invalidated by
 * {@link #update(Integer, List)}. The class performs no synchronization.
 *
 * @author pmerienne
 */
public class KLDClassifier implements TextClassifier<Integer>, TextFeaturesExtractor, Serializable {

    private static final long serialVersionUID = 3869875629653284342L;

    /** Default for {@link #maxWordsPerClass}. */
    private static final int DEFAULT_MAX_WORDS_PER_CLASS = 500000;

    /** Default for {@link #thresholdFactor}. */
    private static final double DEFAULT_THRESHOLD_FACTOR = 10.0;

    /** Default for {@link #normalize}. */
    private static final boolean DEFAULT_NORMALIZE = true;

    /** Limit handed to {@code Vocabulary.limitWords} after every update. */
    private int maxWordsPerClass = DEFAULT_MAX_WORDS_PER_CLASS;

    /** Epsilon is {@code 1 / (thresholdFactor * largest class total count)}. */
    private double thresholdFactor = DEFAULT_THRESHOLD_FACTOR;

    /**
     * When true, each class distance is divided by the distance between an
     * empty document and that class.
     */
    private boolean normalize = DEFAULT_NORMALIZE;

    /** One vocabulary per class, indexed by class index. */
    private List<Vocabulary> classVocabularies = new ArrayList<Vocabulary>();

    /** Lazily computed gamma per class; {@code null} means "recompute". */
    private List<Double> gammas = new ArrayList<Double>();

    /**
     * Lazily computed epsilon; {@code null} means "recompute". The misspelled
     * name is kept on purpose: renaming a field changes the default
     * serialized form of this {@link Serializable} class.
     */
    private Double espilon = null;

    /**
     * Creates a classifier with default settings and no classes.
     */
    public KLDClassifier() {
    }

    /**
     * Creates a classifier with default settings.
     *
     * @param nbClasses number of classes; valid class indexes are
     *                  {@code 0 .. nbClasses - 1}
     */
    public KLDClassifier(int nbClasses) {
        this(nbClasses, DEFAULT_MAX_WORDS_PER_CLASS, DEFAULT_THRESHOLD_FACTOR, DEFAULT_NORMALIZE);
    }

    /**
     * Creates a classifier with a custom vocabulary limit.
     *
     * @param nbClasses        number of classes
     * @param maxWordsPerClass limit applied to each class vocabulary
     */
    public KLDClassifier(int nbClasses, int maxWordsPerClass) {
        this(nbClasses, maxWordsPerClass, DEFAULT_THRESHOLD_FACTOR, DEFAULT_NORMALIZE);
    }

    /**
     * Creates a fully configured classifier with one empty vocabulary per class.
     *
     * @param nbClasses        number of classes
     * @param maxWordsPerClass limit applied to each class vocabulary
     * @param thresholdFactor  divisor factor used when estimating epsilon
     * @param normalize        whether distances are normalized by the
     *                         distance of an empty document
     */
    public KLDClassifier(int nbClasses, int maxWordsPerClass, double thresholdFactor, boolean normalize) {
        this.maxWordsPerClass = maxWordsPerClass;
        this.thresholdFactor = thresholdFactor;
        this.normalize = normalize;

        for (int i = 0; i < nbClasses; i++) {
            this.classVocabularies.add(new Vocabulary());
            this.gammas.add(null);
        }
    }

    /**
     * Builds one feature per (class, word) pair: the word's contribution to
     * the distance between the document and that class.
     * <p>
     * The result has {@code nbClasses * vocabularySize} entries; the feature
     * of word {@code w} for class {@code c} is stored at
     * {@code c * vocabularySize + w}. Word positions follow the iteration
     * order of the global vocabulary set.
     *
     * @param documentWords words of the document
     * @return the feature vector
     */
    @Override
    public double[] extractFeatures(List<String> documentWords) {
        Vocabulary documentVocabulary = new Vocabulary(documentWords);
        Set<String> vocabulary = this.createGlobalVocabulary();
        int vocabularySize = vocabulary.size();
        int nbClasses = this.classVocabularies.size();
        double[] features = new double[vocabularySize * nbClasses];
        double beta = this.caculateBeta(documentVocabulary);

        int wordIndex = 0;
        for (String word : vocabulary) {
            // The document probability does not depend on the class: compute it once per word.
            double tpd = this.wordProbabilityInDocument(word, documentVocabulary, beta);
            for (int classIndex = 0; classIndex < nbClasses; classIndex++) {
                // Use the current class here; a fixed class 0 would make every class slice identical.
                double tpc = this.wordProbabilityInCategory(word, classIndex);
                features[classIndex * vocabularySize + wordIndex] = divergenceTerm(tpc, tpd);
            }
            wordIndex++;
        }

        return features;
    }

    /**
     * Adds a document's words to a class vocabulary and invalidates the
     * cached epsilon and that class's cached gamma.
     *
     * @param classIndex    index of the class the document belongs to
     * @param documentWords words of the document
     */
    @Override
    public void update(Integer classIndex, List<String> documentWords) {
        // Update class vocabulary
        Vocabulary classVocabulary = this.classVocabularies.get(classIndex);
        classVocabulary.addAll(documentWords);
        classVocabulary.limitWords(this.maxWordsPerClass);

        // Reset associated gamma
        this.gammas.set(classIndex, null);

        // Reset epsilon
        this.espilon = null;
    }

    /**
     * Returns the index of the class with the smallest distance to the document.
     *
     * @param documentWords words of the document
     * @return the index of the closest class, or {@code -1} when no distance
     *         is strictly below positive infinity (no classes, or only
     *         NaN/infinite distances)
     */
    @Override
    public Integer classify(List<String> documentWords) {
        double[] distances = this.distance(documentWords);

        // Find minimum distance
        int classIndex = -1;
        double minDistance = Double.POSITIVE_INFINITY;
        for (int i = 0; i < distances.length; i++) {
            if (distances[i] < minDistance) {
                minDistance = distances[i];
                classIndex = i;
            }
        }

        return classIndex;
    }

    /**
     * Computes the distance between a document and every class.
     * <p>
     * When normalization is enabled, each distance is divided by the distance
     * between an empty document and the same class.
     *
     * @param documentWords words of the document
     * @return one distance per class, indexed by class index
     */
    public double[] distance(List<String> documentWords) {
        int nbClasses = this.classVocabularies.size();
        double[] distances = new double[nbClasses];

        Vocabulary documentVocabulary = new Vocabulary(documentWords);
        double beta = this.caculateBeta(documentVocabulary);
        double betaZero = this.caculateBeta(new Vocabulary());

        for (int classIndex = 0; classIndex < nbClasses; classIndex++) {
            distances[classIndex] = this.distance(documentVocabulary, classIndex, beta);
            if (this.normalize) {
                distances[classIndex] /= this.distance(new Vocabulary(), classIndex, betaZero);
            }
        }

        return distances;
    }

    /**
     * Computes the unnormalized distance between a document vocabulary and a
     * class by summing the per-word terms over the global vocabulary.
     *
     * @param documentVocabulary vocabulary of the document
     * @param classIndex         index of the class
     * @param beta               coefficient for the document, as returned by
     *                           {@link #caculateBeta(Vocabulary)}
     * @return the distance; {@code 0.0} when the global vocabulary is empty
     */
    protected Double distance(Vocabulary documentVocabulary, int classIndex, double beta) {
        double total = 0.0;

        for (String word : this.createGlobalVocabulary()) {
            double tpc = this.wordProbabilityInCategory(word, classIndex);
            double tpd = this.wordProbabilityInDocument(word, documentVocabulary, beta);
            total += divergenceTerm(tpc, tpd);
        }

        return total;
    }

    /**
     * Returns the smoothed probability of a word in a class: epsilon when the
     * class vocabulary reports a frequency of 0 or NaN, otherwise the
     * frequency multiplied by the class gamma.
     *
     * @param word       the word
     * @param classIndex index of the class
     * @return the smoothed probability
     */
    protected Double wordProbabilityInCategory(String word, int classIndex) {
        Vocabulary classVocabulary = this.classVocabularies.get(classIndex);
        double frequency = classVocabulary.frequency(word);

        if (frequency == 0 || Double.isNaN(frequency)) {
            // Cached value: same result as estimateEpsilon() without rescanning
            // every class vocabulary for each unseen word.
            return this.getEpsilon();
        }
        return frequency * this.getGamma(classIndex);
    }

    /**
     * Returns the smoothed probability of a word in a document: epsilon when
     * the document vocabulary reports a frequency of 0 or NaN, otherwise the
     * frequency multiplied by {@code beta}.
     *
     * @param word               the word
     * @param documentVocabulary vocabulary of the document
     * @param beta               coefficient for the document
     * @return the smoothed probability
     */
    protected Double wordProbabilityInDocument(String word, Vocabulary documentVocabulary, double beta) {
        double frequency = documentVocabulary.frequency(word);

        if (frequency == 0 || Double.isNaN(frequency)) {
            return this.getEpsilon();
        }
        return frequency * beta;
    }

    /**
     * Returns epsilon, computing and caching it on first use after an update.
     *
     * @return the back-off probability given to absent words
     */
    protected double getEpsilon() {
        if (this.espilon == null) {
            this.espilon = this.estimateEpsilon();
        }
        return this.espilon;
    }

    /**
     * Computes epsilon as {@code 1 / (thresholdFactor * maxSize)}, where
     * {@code maxSize} is the largest {@code totalCount()} among the class
     * vocabularies (0 when there is none).
     *
     * @return epsilon; positive infinity when {@code maxSize} is 0 and
     *         {@code thresholdFactor} is a positive finite number
     */
    protected double estimateEpsilon() {
        int maxSize = 0;
        for (Vocabulary vocabulary : this.classVocabularies) {
            maxSize = Math.max(maxSize, vocabulary.totalCount());
        }
        return 1 / (this.thresholdFactor * maxSize);
    }

    /**
     * Returns the gamma of a class, computing and caching it on first use
     * after an update of that class.
     *
     * @param classIndex index of the class
     * @return the coefficient for the class
     */
    protected double getGamma(int classIndex) {
        Double gamma = this.gammas.get(classIndex);
        if (gamma == null) {
            gamma = this.calculateGamma(classIndex);
            this.gammas.set(classIndex, gamma);
        }
        return gamma;
    }

    /**
     * Computes gamma for a class: 1 minus epsilon for each word of the global
     * vocabulary that the class vocabulary does not contain.
     *
     * @param classIndex index of the class
     * @return the coefficient for the class
     */
    protected double calculateGamma(int classIndex) {
        double gamma = 1.0;
        double epsilon = this.getEpsilon();
        Vocabulary classVocabulary = this.classVocabularies.get(classIndex);

        for (String word : this.createGlobalVocabulary()) {
            if (!classVocabulary.contains(word)) {
                gamma -= epsilon;
            }
        }

        return gamma;
    }

    /**
     * Computes beta for a document: 1 minus epsilon for each word of the
     * global vocabulary that the document vocabulary does not contain.
     * <p>
     * The method name's spelling is kept because it is part of the protected API.
     *
     * @param documentVocabulary vocabulary of the document
     * @return the coefficient for the document
     */
    protected double caculateBeta(Vocabulary documentVocabulary) {
        double beta = 1.0;
        double epsilon = this.getEpsilon();

        for (String word : this.createGlobalVocabulary()) {
            if (!documentVocabulary.contains(word)) {
                beta -= epsilon;
            }
        }

        return beta;
    }

    /**
     * Collects the union of the word sets of all class vocabularies into a new set.
     *
     * @return a fresh set of every word known to at least one class
     */
    private Set<String> createGlobalVocabulary() {
        Set<String> vocabulary = new HashSet<String>();
        for (Vocabulary classVocabulary : this.classVocabularies) {
            vocabulary.addAll(classVocabulary.wordSet());
        }
        return vocabulary;
    }

    /**
     * Contribution of a single word to the distance.
     *
     * @param tpc smoothed probability of the word in the class
     * @param tpd smoothed probability of the word in the document
     * @return {@code (tpc - tpd) * ln(tpc / tpd)}
     */
    private static double divergenceTerm(double tpc, double tpd) {
        return (tpc - tpd) * Math.log(tpc / tpd);
    }
}