/* * Encog(tm) Core v3.4 - Java Version * http://www.heatonresearch.com/encog/ * https://github.com/encog/encog-java-core * Copyright 2008-2016 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ package org.encog.ml.bayesian.training; import org.encog.ml.MLMethod; import org.encog.ml.TrainingImplementationType; import org.encog.ml.bayesian.BayesianEvent; import org.encog.ml.bayesian.BayesianNetwork; import org.encog.ml.bayesian.training.estimator.BayesEstimator; import org.encog.ml.bayesian.training.estimator.SimpleEstimator; import org.encog.ml.bayesian.training.search.k2.BayesSearch; import org.encog.ml.bayesian.training.search.k2.SearchK2; import org.encog.ml.data.MLDataSet; import org.encog.ml.train.BasicTraining; import org.encog.neural.networks.training.propagation.TrainingContinuation; /** * Train a Bayesian network. */ public class TrainBayesian extends BasicTraining { /** * What phase of training are we in? */ private enum Phase { /** * Init phase. */ Init, /** * Searching for a network structure. */ Search, /** * Search complete. */ SearchDone, /** * Finding probabilities. */ Probability, /** * Finished training. */ Finish, /** * Training terminated. */ Terminated }; /** * The phase that training is currently in. */ private Phase p = Phase.Init; /** * The data used for training. */ private final MLDataSet data; /** * The network to train. */ private final BayesianNetwork network; /** * The maximum parents a node should have. */ private final int maximumParents; /** * The method used to search for the best network structure. */ private final BayesSearch search; /** * The method used to estimate the probabilities. */ private final BayesEstimator estimator; /** * The method used to setup the initial Bayesian network. */ private BayesianInit initNetwork = BayesianInit.InitNaiveBayes; /** * Used to hold the query. */ private String holdQuery; /** * Construct a Bayesian trainer. Use K2 to search, and the SimpleEstimator * to estimate probability. Init as Naive Bayes * * @param theNetwork * The network to train. * @param theData * The data to train. * @param theMaximumParents * The max number of parents. */ public TrainBayesian(BayesianNetwork theNetwork, MLDataSet theData, int theMaximumParents) { this(theNetwork, theData, theMaximumParents, BayesianInit.InitNaiveBayes, new SearchK2(), new SimpleEstimator()); } /** * Construct a Bayesian trainer. * @param theNetwork The network to train. * @param theData The data to train with. * @param theMaximumParents The maximum number of parents. * @param theInit How to init the new Bayes network. * @param theSearch The search method. * @param theEstimator The estimation mehod. */ public TrainBayesian(BayesianNetwork theNetwork, MLDataSet theData, int theMaximumParents, BayesianInit theInit, BayesSearch theSearch, BayesEstimator theEstimator) { super(TrainingImplementationType.Iterative); this.network = theNetwork; this.data = theData; this.maximumParents = theMaximumParents; this.search = theSearch; this.search.init(this, theNetwork, theData); this.estimator = theEstimator; this.estimator.init(this, theNetwork, theData); this.initNetwork = theInit; setError(1.0); } /** * Init to Naive Bayes. */ private void initNaiveBayes() { // clear out anything from before this.network.removeAllRelations(); // locate the classification target event BayesianEvent classificationTarget = this.network .getClassificationTargetEvent(); // now link everything to this event for (BayesianEvent event : this.network.getEvents()) { if (event != classificationTarget) { network.createDependency(classificationTarget, event); } } this.network.finalizeStructure(); } /** * Handle iterations for the Init phase. */ private void iterationInit() { this.holdQuery = this.network.getClassificationStructure(); switch (this.initNetwork) { case InitEmpty: this.network.removeAllRelations(); this.network.finalizeStructure(); break; case InitNoChange: break; case InitNaiveBayes: initNaiveBayes(); break; } this.p = Phase.Search; } /** * Handle iterations for the Search phase. */ private void iterationSearch() { if (!this.search.iteration()) { this.p = Phase.SearchDone; } } /** * Handle iterations for the Search Done phase. */ private void iterationSearchDone() { this.network.finalizeStructure(); this.network.reset(); this.p = Phase.Probability; } /** * Handle iterations for the Probability phase. */ private void iterationProbability() { if (!this.estimator.iteration()) { this.p = Phase.Finish; } } /** * Handle iterations for the Finish phase. */ private void iterationFinish() { this.network.defineClassificationStructure(this.holdQuery); setError(this.network.calculateError(this.data)); this.p = Phase.Terminated; } /** * {@inheritDoc} */ @Override public boolean isTrainingDone() { if (super.isTrainingDone()) return true; else return this.p == Phase.Terminated; } /** * {@inheritDoc} */ @Override public void iteration() { postIteration(); switch (p) { case Init: iterationInit(); break; case Search: iterationSearch(); break; case SearchDone: iterationSearchDone(); break; case Probability: iterationProbability(); break; case Finish: iterationFinish(); break; } preIteration(); } /** * {@inheritDoc} */ @Override public boolean canContinue() { return false; } /** * {@inheritDoc} */ @Override public TrainingContinuation pause() { return null; } /** * {@inheritDoc} */ @Override public void resume(TrainingContinuation state) { } /** * {@inheritDoc} */ @Override public MLMethod getMethod() { return this.network; } /** * @return the network */ public BayesianNetwork getNetwork() { return network; } /** * @return the maximumParents */ public int getMaximumParents() { return maximumParents; } /** * @return The search method. */ public BayesSearch getSearch() { return this.search; } /** * @return The init method. */ public BayesianInit getInitNetwork() { return initNetwork; } /** * Set the network init method. * @param initNetwork The init method. */ public void setInitNetwork(BayesianInit initNetwork) { this.initNetwork = initNetwork; } }