/*
* Copyright 2011 JBoss Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.drools.chance.distribution.probability.dirichlet;
import org.drools.chance.degree.Degree;
import org.drools.chance.degree.interval.IntervalDegree;
import org.drools.chance.degree.simple.SimpleDegree;
import org.drools.chance.distribution.DiscreteProbabilityDistribution;
import org.drools.chance.core.util.ValueSortedMap;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
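/**
* Discrete probability distribution of Dirichlet type: every value of the
* domain is associated with an alpha weight, and the statistics exposed by
* this class are derived from those weights and their total mass.
* TODO complete the documentation
* @param <T> the type of the domain values
*/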
public class DirichletDistribution<T> implements DiscreteProbabilityDistribution<T> {

    /** Alpha weight associated with each value of the domain. */
    private ValueSortedMap<T,Double> alphaWeightMap = new ValueSortedMap<T,Double>();

    /**
     * Total weight. The formulas in this class use it as the sum of the alpha
     * weights; it is not recomputed here, it is set through {@link #setMass(double)}.
     */
    private double mass = 0;

    /** Flag exposed through {@link #isNormalized()} / {@link #setNormalized(boolean)}. */
    private boolean normalized;

    /**
     * @param value a domain value
     * @return a {@link SimpleDegree} wrapping {@link #getExpectation(Object)} of the value
     */
    public Degree getDegree(T value) {
        return new SimpleDegree(getExpectation(value));
    }

    /**
     * Alias of {@link #getDegree(Object)}.
     * @param value a domain value
     * @return the degree associated with the value
     */
    public Degree get(T value) {
        return getDegree(value);
    }

    /**
     * @return the number of values that currently have an alpha weight
     */
    public Number domainSize() {
        return alphaWeightMap.size();
    }

    /**
     * @return the number of values that currently have an alpha weight
     */
    public int size() {
        return alphaWeightMap.size();
    }

    /** Package-private constructor: creates an empty distribution with zero mass. */
    DirichletDistribution() { }

    /**
     * inner accessor, to be used by strategies
     * @return the value-weight map (the internal instance, not a copy)
     */
    ValueSortedMap<T,Double> getAlphaWeights() {
        return alphaWeightMap;
    }

    /**
     * inner accessor, to be used by strategies
     * @return the total weight
     */
    double getMass() {
        return mass;
    }

    /**
     * inner setter, to be used by strategies
     * @param m the new total weight
     */
    void setMass(double m) {
        mass = m;
    }

    /**
     * @return the key set of the value-weight map
     */
    public Set<T> getSupport() {
        return alphaWeightMap.keySet();
    }

    /**
     * @return an iterator over the key set of the value-weight map
     */
    public Iterator<T> iterator() {
        return alphaWeightMap.keySet().iterator();
    }

    /**
     * Expected probability of a value: {@code alpha / mass}.
     * @param value a domain value
     * @return the expectation, or 0 if the value is unknown or the mass is not positive
     */
    public double getExpectation(T value) {
        if (!alphaWeightMap.containsKey(value)) {
            return 0;
        }
        return mass > 0 ? (alphaWeightMap.get(value) / mass) : 0;
    }

    /**
     * Variance of the probability of a value:
     * {@code alpha * (mass - alpha) / (mass^2 * (mass + 1))}.
     * @param value a domain value
     * @return the variance, or 0 if the value is unknown or the mass is not positive
     */
    public double getVariance(T value) {
        if (!alphaWeightMap.containsKey(value)) {
            return 0;
        }
        double alpha = alphaWeightMap.get(value);
        return mass > 0
                ? ((alpha * (mass - alpha)) / (mass * mass * (mass + 1)))
                : 0;
    }

    /**
     * Mode of the probability of a value: {@code (alpha - 1) / (mass - size())}.
     * <p>
     * NOTE(review): the denominator is zero when the mass equals the number of
     * values, in which case the result is NaN or infinite, and the result is
     * negative for alpha below 1. These cases are not guarded here - confirm
     * the intended fallback with the callers.
     * @param value a domain value
     * @return the mode, or 0 if the value is unknown or the mass is not positive
     */
    public double getMode(T value) {
        if (!alphaWeightMap.containsKey(value)) {
            return 0;
        }
        double alpha = alphaWeightMap.get(value);
        return mass > 0 ? ((alpha - 1) / (mass - size())) : 0;
    }

    /**
     * Covariance between the probabilities of two values:
     * {@code -(a1 * a2) / (mass^2 * (mass + 1))} for two distinct values, and
     * the variance when both arguments denote the same value.
     * @param value1 a domain value
     * @param value2 a domain value
     * @return the covariance, or 0 if either value is unknown or the mass is not positive
     */
    public double getCovariance(T value1, T value2) {
        if (!alphaWeightMap.containsKey(value1) || !alphaWeightMap.containsKey(value2)) {
            return 0;
        }
        // The covariance of a component with itself is its variance; the
        // off-diagonal formula below would yield a negative number instead.
        if (value1 == value2 || (value1 != null && value1.equals(value2))) {
            return getVariance(value1);
        }
        double a1 = alphaWeightMap.get(value1);
        double a2 = alphaWeightMap.get(value2);
        return mass > 0 ? (-(a1 * a2) / (mass * mass * (mass + 1))) : 0;
    }

    /**
     * @return An ordered simple distribution based on the maximum likelihood principle:
     *         each value is mapped to a {@link SimpleDegree} wrapping its mode
     */
    public Map<T, Degree> getDistribution() {
        ValueSortedMap<T,Degree> vsMap = new ValueSortedMap<T,Degree>();
        for (T x : alphaWeightMap.keySet()) {
            vsMap.put(x, new SimpleDegree(getMode(x)));
        }
        return vsMap;
    }

    /**
     * Evaluates the Dirichlet density defined by the alpha weights at the point
     * given by the test distribution:
     * {@code Gamma(mass) / prod Gamma(alpha_i) * prod x_i^(alpha_i - 1)}.
     * The computation is carried out in log space and exponentiated at the end.
     * Values missing from the test distribution are treated as {@code x = 1},
     * i.e. they only contribute through the normalisation constant.
     * @param testDistribution the distribution whose degrees provide the point x
     * @return the density at that point, or 0 if the mass is not positive
     */
    public double getLikelihood(DiscreteProbabilityDistribution<T> testDistribution) {
        // Same convention as the other statistics: nothing to evaluate without mass.
        if (mass <= 0) {
            return 0;
        }
        Map<T,Degree> distr = testDistribution.getDistribution();

        // log of the normalisation constant 1/B(alpha) = Gamma(mass) / prod Gamma(alpha_i)
        double logDensity = lnGamma(mass);
        for (T key : alphaWeightMap.keySet()) {
            double alpha = alphaWeightMap.get(key);
            double x = distr.containsKey(key)
                    ? distr.get(key).getValue()
                    : 1.0;
            logDensity -= lnGamma(alpha);
            // x^(alpha-1) is 1 when alpha == 1, even for x == 0; skipping the
            // term avoids 0 * log(0), which evaluates to NaN.
            if (alpha != 1.0) {
                logDensity += (alpha - 1) * Math.log(x);
            }
        }
        return Math.exp(logDensity);
    }

    /**
     * @param testDistribution the distribution to evaluate
     * @return a {@link SimpleDegree} wrapping {@link #getLikelihood(DiscreteProbabilityDistribution)}
     */
    public SimpleDegree getLikelihoodDegree(DiscreteProbabilityDistribution<T> testDistribution) {
        return new SimpleDegree(getLikelihood(testDistribution));
    }

    /**
     * Not implemented yet.
     * @param values ignored
     * @param threshold ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public IntervalDegree getLikelihoodRange(Collection<T> values, double threshold) {
        throw new UnsupportedOperationException("TODO");
    }

    /**
     * @return the serialized content wrapped as {@code (Dirichlet) : {...}}
     */
    @Override
    public String toString() {
        return "(Dirichlet) : {" + serialize() + "}";
    }

    /**
     * @return a comma separated list of {@code value/degree} pairs, one per known value
     */
    public String serialize() {
        StringBuilder sb = new StringBuilder();
        Iterator<T> iter = alphaWeightMap.keySet().iterator();
        while (iter.hasNext()) {
            T elem = iter.next();
            sb.append(elem).append("/").append(getDegree(elem).getValue());
            if (iter.hasNext()) {
                sb.append(", ");
            }
        }
        return sb.toString();
    }

    /**
     * @return always true
     */
    public boolean isDiscrete() {
        return true;
    }

    /**
     * @return the value of the normalized flag
     */
    public boolean isNormalized() {
        return normalized;
    }

    /**
     * @param normalized the new value of the normalized flag
     */
    public void setNormalized(boolean normalized) {
        this.normalized = normalized;
    }

    /**
     * Uses Lanczos' approx to compute logGamma(x)
     * @param x the argument
     * @return lnGamma(x)
     */
    protected static double lnGamma(double x) {
        double tmp = (x - 0.5) * Math.log(x + 4.5) - (x + 4.5);
        double ser = 1.0 + 76.18009173 / (x + 0) - 86.50532033 / (x + 1)
                + 24.01409822 / (x + 2) - 1.231739516 / (x + 3)
                + 0.00120858003 / (x + 4) - 0.00000536382 / (x + 5);
        return tmp + Math.log(ser * Math.sqrt(2 * Math.PI));
    }

    /**
     * Uses Lanczos' approx to compute Gamma(x)
     * @param x the argument
     * @return Gamma(x)
     */
    static double gamma(double x) {
        return Math.exp(lnGamma(x));
    }

    /** Euler-Mascheroni constant, used by the closed-form cases of {@link #diGamma(double)}. */
    protected static double GAMMA = 0.5772156649015328606065120900824024;

    /** Natural logarithm of 2, used by the duplication formula in {@link #diGamma(double)}. */
    protected static double LN2 = Math.log(2);

    /** Coefficients of the Chebyshev series evaluated by {@link #diGamma(double)}. */
    protected static double Kncoe[] = { .30459198558715155634315638246624251,
            .72037977439182833573548891941219706, -.12454959243861367729528855995001087,
            .27769457331927827002810119567456810e-1, -.67762371439822456447373550186163070e-2,
            .17238755142247705209823876688592170e-2, -.44817699064252933515310345718960928e-3,
            .11793660000155572716272710617753373e-3, -.31253894280980134452125172274246963e-4,
            .83173997012173283398932708991137488e-5, -.22191427643780045431149221890172210e-5,
            .59302266729329346291029599913617915e-6, -.15863051191470655433559920279603632e-6,
            .42459203983193603241777510648681429e-7, -.11369129616951114238848106591780146e-7,
            .304502217295931698401459168423403510e-8, -.81568455080753152802915013641723686e-9,
            .21852324749975455125936715817306383e-9, -.58546491441689515680751900276454407e-10,
            .15686348450871204869813586459513648e-10, -.42029496273143231373796179302482033e-11,
            .11261435719264907097227520956710754e-11, -.30174353636860279765375177200637590e-12,
            .80850955256389526647406571868193768e-13, -.21663779809421233144009565199997351e-13,
            .58047634271339391495076374966835526e-14, -.15553767189204733561108869588173845e-14,
            .41676108598040807753707828039353330e-15, -.11167065064221317094734023242188463e-15 } ;

    /**
     * Digamma function approximation (unverified)
     * Source:
     * http://arXiv.org/abs/math.CA/0403344
     * http://www.strw.leidenuniv.nl/~mathar/progs/digamma.c
     * <p>
     * Arguments outside 1..3 are recursively mapped towards that interval
     * (reflection for negative x, recurrence for x below 1, duplication for
     * x above 3); x equal to 1, 2 or 3 uses a closed form, and the remaining
     * arguments are evaluated with a Chebyshev series in (x - 2).
     * @param x the argument
     * @return diGamma(x)
     */
    double diGamma(double x)
    {
        /* force into the interval 1..3 */
        if (x < 0.0)
            return diGamma(1.0 - x) + Math.PI / Math.tan(Math.PI * (1.0 - x)); /* reflection formula */
        else if (x < 1.0)
            return diGamma(1.0 + x) - 1.0 / x;
        else if (x == 1.0)
            return -GAMMA;
        else if (x == 2.0)
            return 1.0 - GAMMA;
        else if (x == 3.0)
            return 1.5 - GAMMA;
        else if (x > 3.0)
            /* duplication formula */
            return 0.5 * (diGamma(0.5 * x) + diGamma(0.5 * (x + 1.0))) + LN2;
        else
        {
            double Tn_1 = 1.0;      /* T_{n-1}(x), started at n=1 */
            double Tn = x - 2.0;    /* T_{n}(x) , started at n=1 */
            double resul = Kncoe[0] + Kncoe[1] * Tn;
            x -= 2.0;               /* the series is expressed in (x - 2) */
            for (int n = 2; n < Kncoe.length; n++)
            {
                double Tn1 = 2.0 * x * Tn - Tn_1; /* Chebyshev recursion, Eq. 22.7.4 Abramowitz-Stegun */
                resul += Kncoe[n] * Tn1;
                Tn_1 = Tn;
                Tn = Tn1;
            }
            return resul;
        }
    }
}