/*
* Encog(tm) Core v3.4 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.ml.hmm.distributions;
import java.util.Arrays;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
/**
* A discrete distribution is a distribution with a finite set of states that it
* can be in.
*
*/
public class DiscreteDistribution implements StateDistribution {
/**
* The serial id.
*/
private static final long serialVersionUID = 1L;
/**
* The probabilities of moving between states.
*/
private final double[][] probabilities;
/**
* Construct a discrete distribution with the specified probabilities.
* @param theProbabilities The probabilities.
*/
public DiscreteDistribution(final double[][] theProbabilities) {
if (theProbabilities.length == 0) {
throw new IllegalArgumentException("Invalid empty array");
}
this.probabilities = new double[theProbabilities.length][];
for (int i = 0; i < theProbabilities.length; i++) {
if (theProbabilities[i].length == 0) {
throw new IllegalArgumentException("Invalid empty array");
}
this.probabilities[i] = new double[theProbabilities[i].length];
for (int j = 0; j < probabilities[i].length; j++) {
if ((this.probabilities[i][j] = theProbabilities[i][j]) < 0.0) {
throw new IllegalArgumentException();
}
}
}
}
/**
* Construct a discrete distribution.
* @param cx The count of each.
*/
public DiscreteDistribution(final int[] cx) {
this.probabilities = new double[cx.length][];
for (int i = 0; i < cx.length; i++) {
int c = cx[i];
this.probabilities[i] = new double[c];
for (int j = 0; j < c; j++) {
this.probabilities[i][j] = 1.0 / c;
}
}
}
/**
* @return A clone of the distribution.
*/
@Override
public DiscreteDistribution clone() {
try {
return (DiscreteDistribution) super.clone();
} catch (final CloneNotSupportedException e) {
throw new AssertionError(e);
}
}
/**
* Fit this distribution to the specified data.
* @param co THe data to fit to.
*/
@Override
public void fit(final MLDataSet co) {
if (co.size() < 1) {
throw new IllegalArgumentException("Empty observation set");
}
for (int i = 0; i < this.probabilities.length; i++) {
for (int j = 0; j < this.probabilities[i].length; j++) {
this.probabilities[i][j] = 0.0;
}
for (final MLDataPair o : co) {
this.probabilities[i][(int) o.getInput().getData(i)]++;
}
for (int j = 0; j < this.probabilities[i].length; j++) {
this.probabilities[i][j] /= co.size();
}
}
}
/**
* Fit this distribution to the specified data, with weights.
* @param co The data to fit to.
* @param weights The weights.
*/
@Override
public void fit(final MLDataSet co, final double[] weights) {
if ((co.size() < 1) || (co.size() != weights.length)) {
throw new IllegalArgumentException();
}
for (int i = 0; i < this.probabilities.length; i++) {
Arrays.fill(this.probabilities[i], 0.0);
int j = 0;
for (final MLDataPair o : co) {
this.probabilities[i][(int) o.getInput().getData(i)] += weights[j++];
}
}
}
/**
* Generate a random sequence.
* @return The next element.
*/
@Override
public MLDataPair generate() {
final MLData result = new BasicMLData(this.probabilities.length);
for (int i = 0; i < this.probabilities.length; i++) {
double rand = Math.random();
result.setData(i, this.probabilities[i].length - 1);
for (int j = 0; j < (this.probabilities[i].length - 1); j++) {
if ((rand -= this.probabilities[i][j]) < 0.0) {
result.setData(i, j);
break;
}
}
}
return new BasicMLDataPair(result);
}
/**
* Determine the probability of the specified data pair.
* @param o THe data pair.
*/
@Override
public double probability(final MLDataPair o) {
double result = 1;
for (int i = 0; i < this.probabilities.length; i++) {
if (o.getInput().getData(i) > (this.probabilities[i].length - 1)) {
throw new IllegalArgumentException("Wrong observation value");
}
result *= this.probabilities[i][(int) o.getInput().getData(i)];
}
return result;
}
/**
* @return The state probabilities.
*/
public double[][] getProbabilities() {
return this.probabilities;
}
}