package dist.test; import util.linalg.DenseVector; import util.linalg.RectangularMatrix; import dist.Distribution; import dist.MixtureDistribution; import dist.DiscreteDistribution; import dist.MultivariateGaussian; import shared.DataSet; import shared.Instance; /** * Testing * @author Andrew Guillory gtg008g@mail.gatech.edu * @version 1.0 */ public class MixtureDistributionTest { /** * The test main * @param args ignored */ public static void main(String[] args) throws Exception { Instance[] instances = new Instance[100]; MultivariateGaussian mga = new MultivariateGaussian(new DenseVector(new double[] {100, 100, 100}), RectangularMatrix.eye(3).times(.01)); MultivariateGaussian mgb = new MultivariateGaussian(new DenseVector(new double[] {-1, -1, -1}), RectangularMatrix.eye(3).times(10)); for (int i = 0; i < instances.length; i++) { if (Distribution.random.nextBoolean()) { instances[i] = mga.sample(); } else { instances[i] = mgb.sample(); } System.out.println(instances[i]); } DataSet set = new DataSet(instances); MixtureDistribution md = new MixtureDistribution(new Distribution[] { new MultivariateGaussian(new DenseVector(new double[] {120, 80, 100}), RectangularMatrix.eye(3).times(1)), new MultivariateGaussian(new DenseVector(new double[] {-1, -6, -5}), RectangularMatrix.eye(3).times(1))}, DiscreteDistribution.random(2).getProbabilities()); System.out.println(md); for (int i = 0; i < 20; i++) { md.estimate(set); System.out.println(md); } System.out.println(md); for (int i = 0; i < 30; i++) { System.out.println(md.sample(null)); } } }