/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.commons.math4.distribution; import java.util.ArrayList; import java.util.List; import org.apache.commons.math4.exception.DimensionMismatchException; import org.apache.commons.math4.exception.MathArithmeticException; import org.apache.commons.math4.exception.NotPositiveException; import org.apache.commons.math4.exception.util.LocalizedFormats; import org.apache.commons.rng.UniformRandomProvider; import org.apache.commons.math4.util.Pair; /** * Class for representing <a href="http://en.wikipedia.org/wiki/Mixture_model"> * mixture model</a> distributions. * * @param <T> Type of the mixture components. * * @since 3.1 */ public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution> extends AbstractMultivariateRealDistribution { /** Normalized weight of each mixture component. */ private final double[] weight; /** Mixture components. */ private final List<T> distribution; /** * Creates a mixture model from a list of distributions and their * associated weights. * * @param components Distributions from which to sample. * @throws NotPositiveException if any of the weights is negative. * @throws DimensionMismatchException if not all components have the same * number of variables. */ public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) { super(components.get(0).getSecond().getDimension()); final int numComp = components.size(); final int dim = getDimension(); double weightSum = 0; for (int i = 0; i < numComp; i++) { final Pair<Double, T> comp = components.get(i); if (comp.getSecond().getDimension() != dim) { throw new DimensionMismatchException(comp.getSecond().getDimension(), dim); } if (comp.getFirst() < 0) { throw new NotPositiveException(comp.getFirst()); } weightSum += comp.getFirst(); } // Check for overflow. if (Double.isInfinite(weightSum)) { throw new MathArithmeticException(LocalizedFormats.OVERFLOW); } // Store each distribution and its normalized weight. distribution = new ArrayList<>(); weight = new double[numComp]; for (int i = 0; i < numComp; i++) { final Pair<Double, T> comp = components.get(i); weight[i] = comp.getFirst() / weightSum; distribution.add(comp.getSecond()); } } /** {@inheritDoc} */ @Override public double density(final double[] values) { double p = 0; for (int i = 0; i < weight.length; i++) { p += weight[i] * distribution.get(i).density(values); } return p; } /** * Gets the distributions that make up the mixture model. * * @return the component distributions and associated weights. */ public List<Pair<Double, T>> getComponents() { final List<Pair<Double, T>> list = new ArrayList<>(weight.length); for (int i = 0; i < weight.length; i++) { list.add(new Pair<>(weight[i], distribution.get(i))); } return list; } /** {@inheritDoc} */ @Override public MultivariateRealDistribution.Sampler createSampler(UniformRandomProvider rng) { return new MixtureSampler(rng); } /** * Sampler. */ private class MixtureSampler implements MultivariateRealDistribution.Sampler { /** RNG */ private final UniformRandomProvider rng; /** Sampler for each of the distribution in the mixture. */ private final MultivariateRealDistribution.Sampler[] samplers; /** * @param generator RNG. */ MixtureSampler(UniformRandomProvider generator) { rng = generator; samplers = new MultivariateRealDistribution.Sampler[weight.length]; for (int i = 0; i < weight.length; i++) { samplers[i] = distribution.get(i).createSampler(rng); } } /** {@inheritDoc} */ @Override public double[] sample() { // Sampled values. double[] vals = null; // Determine which component to sample from. final double randomValue = rng.nextDouble(); double sum = 0; for (int i = 0; i < weight.length; i++) { sum += weight[i]; if (randomValue <= sum) { // pick model i vals = samplers[i].sample(); break; } } if (vals == null) { // This should never happen, but it ensures we won't return a null in // case the loop above has some floating point inequality problem on // the final iteration. vals = samplers[weight.length - 1].sample(); } return vals; } } }