/*
* Copyright 2011 JBoss Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.drools.chance.distribution.probability.dirichlet;
import org.drools.chance.degree.Degree;
import org.drools.chance.degree.DegreeType;
import org.drools.chance.degree.ChanceDegreeTypeRegistry;
import org.drools.chance.distribution.DiscreteDomainDistribution;
import org.drools.chance.distribution.DiscreteProbabilityDistribution;
import org.drools.chance.distribution.Distribution;
import org.drools.chance.distribution.DistributionStrategies;
import org.drools.chance.core.util.ValueSortedMap;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.*;
/**
* Strategy and level III factory for Dirichlet probability distributions
* @param <T>
*/
/**
 * Strategy and level III factory for Dirichlet probability distributions.
 * <p>
 * A {@link DirichletDistribution} is handled here through its alpha weights (one weight per
 * domain value) and its mass. Merging adds the incoming weights onto the alpha weights and grows
 * the mass by the total weight that was added.
 *
 * @param <T> the type of the values in the distribution's domain
 */
public class DirichletDistributionStrategy<T> implements DistributionStrategies<T> {

    private final DegreeType degreeType;
    private final Class<T> domainType;

    /** Cache for {@link #getDegreeStringConstructor()}; {@code null} until first resolved. */
    private Constructor<?> degreeStringConstr = null;

    DirichletDistributionStrategy(DegreeType degreeType, Class<T> domainType) {
        this.degreeType = degreeType;
        this.domainType = domainType;
    }

    /**
     * Returns the constructor that {@link ChanceDegreeTypeRegistry#getConstructorByString} reports
     * for {@link #degreeType}, looking it up on first use and caching it afterwards.
     * NOTE(review): not called from anywhere in this class.
     */
    private Constructor<?> getDegreeStringConstructor() {
        if (degreeStringConstr == null) {
            degreeStringConstr = ChanceDegreeTypeRegistry.getSingleInstance().getConstructorByString(degreeType);
        }
        return degreeStringConstr;
    }

    /**
     * Merges {@code newBit} into {@code current} in place. Every weight carried by {@code newBit} is
     * added to the alpha weight of the same value in {@code current} (values missing from
     * {@code current} are created), and the mass of {@code current} grows by the total weight added.
     *
     * @param current the Dirichlet distribution to update; it is modified and returned
     * @param newBit  a {@link DirichletDistribution} (its alpha weights are added) or a
     *                {@link DiscreteProbabilityDistribution} (its degree values are added)
     * @return {@code current}, after the update
     * @throws UnsupportedOperationException if the pair of distributions is not supported
     */
    public Distribution<T> merge(Distribution<T> current,
                                 Distribution<T> newBit) {
        Map<T, Double> increments = incomingWeights(current, newBit);
        DirichletDistribution<T> curr = (DirichletDistribution<T>) current;
        curr.setMass(addWeights(curr.getAlphaWeights(), increments, curr.getMass()));
        return curr;
    }

    /** Same as {@link #merge(Distribution, Distribution)}; {@code strategy} is ignored. */
    public Distribution<T> merge(Distribution<T> current,
                                 Distribution<T> newBit, String strategy) {
        return merge(current, newBit);
    }

    /** Same as {@link #merge(Distribution, Distribution)}; {@code params} are ignored. */
    public Distribution<T> merge(Distribution<T> current,
                                 Distribution<T> newBit, Object... params) {
        return merge(current, newBit);
    }

    /**
     * Like {@link #merge(Distribution, Distribution)}, but leaves {@code current} untouched and
     * returns a new Dirichlet distribution. Its alpha weights are the per-value sums of both inputs,
     * and its mass is the sum of those alpha weights.
     *
     * @throws UnsupportedOperationException if the pair of distributions is not supported
     */
    public Distribution<T> mergeAsNew(Distribution<T> current,
                                      Distribution<T> newBit) {
        Map<T, Double> increments = incomingWeights(current, newBit);
        DirichletDistribution<T> curr = (DirichletDistribution<T>) current;
        DirichletDistribution<T> merged = new DirichletDistribution<T>();
        // The mass of the result is recomputed from its weights, not taken from curr.getMass().
        double mass = addWeights(merged.getAlphaWeights(), curr.getAlphaWeights(), 0.0);
        mass = addWeights(merged.getAlphaWeights(), increments, mass);
        merged.setMass(mass);
        return merged;
    }

    /** Same as {@link #mergeAsNew(Distribution, Distribution)}; {@code strategy} is ignored. */
    public Distribution<T> mergeAsNew(Distribution<T> current,
                                      Distribution<T> newBit, String strategy) {
        return mergeAsNew(current, newBit);
    }

    /** Same as {@link #mergeAsNew(Distribution, Distribution)}; {@code params} are ignored. */
    public Distribution<T> mergeAsNew(Distribution<T> current,
                                      Distribution<T> newBit, Object... params) {
        return mergeAsNew(current, newBit);
    }

    /**
     * Checks that {@code current} is a Dirichlet distribution able to absorb {@code newBit}, and
     * returns a snapshot of the weights {@code newBit} contributes, keyed by domain value.
     * The snapshot is a fresh map, so the caller may modify the alpha weights of {@code current}
     * while iterating over it, even when {@code newBit == current}.
     */
    private Map<T, Double> incomingWeights(Distribution<T> current, Distribution<T> newBit) {
        if (current instanceof DirichletDistribution) {
            // Dirichlet is tested first so it wins if it also implements the discrete type.
            if (newBit instanceof DirichletDistribution) {
                return new HashMap<T, Double>(((DirichletDistribution<T>) newBit).getAlphaWeights());
            }
            // Test the type that is actually cast below, so mismatches end in the
            // UnsupportedOperationException rather than a ClassCastException.
            if (newBit instanceof DiscreteProbabilityDistribution) {
                Map<T, Degree> degrees = ((DiscreteProbabilityDistribution<T>) newBit).getDistribution();
                Map<T, Double> weights = new HashMap<T, Double>();
                for (Map.Entry<T, Degree> entry : degrees.entrySet()) {
                    double x = entry.getValue().getValue();
                    weights.put(entry.getKey(), x);
                }
                return weights;
            }
        }
        throw new UnsupportedOperationException("Dirichlet Strategies : unable to merge "
                + current.getClass().getName() + " with " + newBit.getClass().getName());
    }

    /**
     * Adds every weight in {@code increments} onto the same key of {@code target}, creating keys that
     * are missing. Returns {@code total} plus the sum of the added weights.
     * {@code increments} must not be {@code target} itself, because the loop writes into
     * {@code target} while iterating {@code increments}.
     */
    private double addWeights(Map<T, Double> target, Map<T, Double> increments, double total) {
        double sum = total;
        for (Map.Entry<T, Double> entry : increments.entrySet()) {
            double x = entry.getValue();
            Double existing = target.get(entry.getKey());
            target.put(entry.getKey(), existing == null ? x : existing + x);
            sum += x;
        }
        return sum;
    }

    /** Not implemented: always returns {@code null}. */
    public Distribution<T> remove(Distribution<T> current, Distribution<T> newBit) {
        return null;
    }

    /** Not implemented: always returns {@code null}. */
    public Distribution<T> remove(Distribution<T> current, Distribution<T> newBit, String strategy) {
        return null;
    }

    /** Not implemented: always returns {@code null}. */
    public Distribution<T> remove(Distribution<T> current, Distribution<T> newBit, Object... params) {
        return null;
    }

    /** Not implemented: always returns {@code null}. */
    public Distribution<T> removeAsNew(Distribution<T> current, Distribution<T> newBit) {
        return null;
    }

    /** Not implemented: always returns {@code null}. */
    public Distribution<T> removeAsNew(Distribution<T> current, Distribution<T> newBit, String strategy) {
        return null;
    }

    /** Not implemented: always returns {@code null}. */
    public Distribution<T> removeAsNew(Distribution<T> current, Distribution<T> newBit, Object... params) {
        return null;
    }

    /** Not implemented: leaves {@code distr} unchanged. */
    public void normalize(Distribution<T> distr) {
        // intentionally empty: no normalization is performed for Dirichlet distributions yet
    }

    /** Returns a new Dirichlet distribution with no alpha weights. */
    public Distribution<T> newDistribution() {
        return new DirichletDistribution<T>();
    }

    /**
     * Returns a new Dirichlet distribution that gives every focal element an alpha weight of 1.0.
     * Its mass is the number of focal elements.
     */
    public Distribution<T> newDistribution(Set<T> focalElements) {
        DirichletDistribution<T> dist = new DirichletDistribution<T>();
        for (T value : focalElements) {
            dist.getAlphaWeights().put(value, 1.0);
        }
        dist.setMass(focalElements.size());
        return dist;
    }

    /**
     * Returns a new Dirichlet distribution whose alpha weights are the degree values of
     * {@code elements}. Its mass is the sum of those values.
     */
    public Distribution<T> newDistribution(Map<? extends T, ? extends Degree> elements) {
        DirichletDistribution<T> dist = new DirichletDistribution<T>();
        double m = 0;
        for (Map.Entry<? extends T, ? extends Degree> entry : elements.entrySet()) {
            double x = entry.getValue().getValue();
            dist.getAlphaWeights().put(entry.getKey(), x);
            m += x;
        }
        dist.setMass(m);
        return dist;
    }

    /**
     * Returns the first key of the alpha-weight map, or {@code null} if there are no weights.
     * NOTE(review): presumably {@link ValueSortedMap} orders keys by descending weight, which
     * makes this the mode. Confirm.
     */
    public T toCrispValue(Distribution<T> dist) {
        ValueSortedMap<T, Double> aw = ((DirichletDistribution<T>) dist).getAlphaWeights();
        return aw.isEmpty() ? null : (T) aw.keySet().iterator().next();
    }

    /** Same as {@link #toCrispValue(Distribution)}; {@code strategy} is ignored. */
    public T toCrispValue(Distribution<T> dist, String strategy) {
        return toCrispValue(dist);
    }

    /** Same as {@link #toCrispValue(Distribution)}; {@code params} are ignored. */
    public T toCrispValue(Distribution<T> dist, Object... params) {
        return toCrispValue(dist);
    }

    /**
     * Draws a value from the support by walking it in iteration order. The walk accumulates
     * {@code dist.getDegree(elem)} for each element and stops once the running total exceeds a
     * uniform random number in [0, 1).
     * NOTE(review): assumes the degrees over the support add up to about 1.
     *
     * @return the sampled value; {@code null} if the support is empty. If rounding leaves the
     *         total below the random number, the last element of the support is returned.
     */
    public T sample(Distribution<T> dist) {
        DirichletDistribution<T> diric = (DirichletDistribution<T>) dist;
        Iterator<T> iter = diric.getSupport().iterator();
        double p = Math.random();
        double acc = 0.0;
        T result = null;
        // '<=' makes p == 0.0 (which Math.random can return) still pick an element.
        // hasNext() guards against an empty support and against rounding shortfalls.
        while (acc <= p && iter.hasNext()) {
            T elem = iter.next();
            result = elem;
            acc += dist.getDegree(elem).getValue();
        }
        return result;
    }

    /** Same as {@link #sample(Distribution)}; {@code strategy} is ignored. */
    public T sample(Distribution<T> dist, String strategy) {
        return sample(dist);
    }

    /** Same as {@link #sample(Distribution)}; {@code params} are ignored. */
    public T sample(Distribution<T> dist, Object... params) {
        return sample(dist);
    }

    /** Builds a distribution from a single observation of {@code value} with weight 1.0. */
    public Distribution<T> toDistribution(T value) {
        return buildDistributionFromSingleObservation(value, 1.0);
    }

    /**
     * Builds a distribution from a single observation of {@code value}. With strategy
     * {@code "spike"} the weight is {@link Double#MAX_VALUE}; any other strategy behaves like
     * {@link #toDistribution(Object)}.
     */
    public Distribution<T> toDistribution(T value, String strategy) {
        if ("spike".equals(strategy)) {
            return buildDistributionFromSingleObservation(value, Double.MAX_VALUE);
        }
        return toDistribution(value);
    }

    /** Same as {@link #toDistribution(Object)}; {@code params} are ignored. */
    public Distribution<T> toDistribution(T value, Object... params) {
        return toDistribution(value);
    }

    /** Returns a distribution whose only alpha weight is {@code wgt} on {@code value}; mass is {@code wgt}. */
    protected Distribution<T> buildDistributionFromSingleObservation(T value, double wgt) {
        DirichletDistribution<T> dist = new DirichletDistribution<T>();
        dist.getAlphaWeights().put(value, wgt);
        dist.setMass(wgt);
        return dist;
    }

    /**
     * Parses a distribution written as comma-separated {@code value/weight} entries, e.g.
     * {@code "a/2.0, b/1.0"}.
     * <ul>
     *   <li>Values are built from their text by {@link #parseValue(String)}.</li>
     *   <li>Weights are parsed as doubles.</li>
     *   <li>Blank entries are skipped.</li>
     *   <li>A value that appears more than once has its weights summed.</li>
     * </ul>
     * The mass of the result is the sum of all weights.
     *
     * @throws IllegalArgumentException if an entry has no weight, a weight is not a number, or the
     *                                  domain type rejects a value's text
     * @throws IllegalStateException    if values of the domain type cannot be built from a String
     */
    public Distribution<T> parse(String distrAsString) {
        DirichletDistribution<T> dist = new DirichletDistribution<T>();
        double m = 0;
        StringTokenizer tok = new StringTokenizer(distrAsString, ",");
        while (tok.hasMoreTokens()) {
            String pair = tok.nextToken().trim();
            if (pair.length() == 0) {
                continue;
            }
            StringTokenizer sub = new StringTokenizer(pair, "/");
            if (sub.countTokens() < 2) {
                throw new IllegalArgumentException("Dirichlet Strategies : expected value/weight but found '"
                        + pair + "'");
            }
            T value = parseValue(sub.nextToken().trim());
            double x = Double.parseDouble(sub.nextToken().trim());
            // Sum repeated values so that the mass stays equal to the sum of the alpha weights.
            Double previous = dist.getAlphaWeights().get(value);
            dist.getAlphaWeights().put(value, previous == null ? x : previous + x);
            m += x;
        }
        dist.setMass(m);
        return dist;
    }

    /**
     * Builds a domain value from its text. Enum domain types are matched by constant name; any
     * other domain type goes through its public {@code String} constructor.
     */
    private T parseValue(String text) {
        if (domainType.isEnum()) {
            for (T constant : domainType.getEnumConstants()) {
                if (((Enum<?>) constant).name().equals(text)) {
                    return constant;
                }
            }
            throw new IllegalArgumentException("Dirichlet Strategies : '" + text
                    + "' is not a constant of " + domainType.getName());
        }
        try {
            return domainType.getConstructor(String.class).newInstance(text);
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException("Dirichlet Strategies : " + domainType.getName()
                    + " has no public String constructor", e);
        } catch (InstantiationException e) {
            throw new IllegalStateException("Dirichlet Strategies : cannot instantiate "
                    + domainType.getName(), e);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Dirichlet Strategies : cannot access the String constructor of "
                    + domainType.getName(), e);
        } catch (InvocationTargetException e) {
            // The constructor itself rejected the text (e.g. Integer("abc")).
            throw new IllegalArgumentException("Dirichlet Strategies : cannot convert '" + text
                    + "' to " + domainType.getName(), e.getCause());
        }
    }
}