/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * BayesNetB.java * Copyright (C) 2001 Remco Bouckaert * */ package weka.classifiers.bayes; import java.io.*; import java.util.*; import weka.core.*; import weka.estimators.*; import weka.classifiers.*; /** * Class for a Bayes Network classifier based on a hill climbing algorithm for * learning structure as described in Buntine, W. (1991). Theory refinement on * Bayesian networks. In Proceedings of Seventh Conference on Uncertainty in * Artificial Intelligence, Los Angeles, CA, pages 52--60. Morgan Kaufmann. * Works with nominal variables and no missing values only. * * @author Remco Bouckaert (rrb@xm.co.nz) * @version $Revision: 1.1.1.1 $ */ public class BayesNetB extends BayesNet { /** * buildStructure determines the network structure/graph of the network * with Buntines greedy hill climbing algorithm, restricted by its initial * structure (which can be an empty graph, or a Naive Bayes graph. */ public void buildStructure() throws Exception { // determine base scores double[] fBaseScores = new double[m_Instances.numAttributes()]; int nNrOfAtts = m_Instances.numAttributes(); for (int iAttribute = 0; iAttribute < nNrOfAtts; iAttribute++) { fBaseScores[iAttribute] = CalcNodeScore(iAttribute); } // B algorithm: greedy search (not restricted by ordering like K2) boolean bProgress = true; // cache scores & whether adding an arc makes sense boolean[][] bAddArcMakesSense = new boolean[nNrOfAtts][nNrOfAtts]; double[][] fScore = new double[nNrOfAtts][nNrOfAtts]; for (int iAttributeHead = 0; iAttributeHead < nNrOfAtts; iAttributeHead++) { if (m_ParentSets[iAttributeHead].GetNrOfParents() < m_nMaxNrOfParents) { // only bother maintaining scores if adding parent does not violate the upper bound on nr of parents for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) { bAddArcMakesSense[iAttributeHead][iAttributeTail] = AddArcMakesSense(iAttributeHead, iAttributeTail); if (bAddArcMakesSense[iAttributeHead][iAttributeTail]) { fScore[iAttributeHead][iAttributeTail] = CalcScoreWithExtraParent(iAttributeHead, iAttributeTail); } } } } // go do the hill climbing while (bProgress) { bProgress = false; int nBestAttributeTail = -1; int nBestAttributeHead = -1; double fBestDeltaScore = 0.0; // find best arc to add for (int iAttributeHead = 0; iAttributeHead < nNrOfAtts; iAttributeHead++) { if (m_ParentSets[iAttributeHead].GetNrOfParents() < m_nMaxNrOfParents) { for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) { if (bAddArcMakesSense[iAttributeHead][iAttributeTail]) { if (fScore[iAttributeHead][iAttributeTail] - fBaseScores[iAttributeHead] > fBestDeltaScore) { if (AddArcMakesSense(iAttributeHead, iAttributeTail)) { fBestDeltaScore = fScore[iAttributeHead][iAttributeTail] - fBaseScores[iAttributeHead]; nBestAttributeTail = iAttributeTail; nBestAttributeHead = iAttributeHead; } else { bAddArcMakesSense[iAttributeHead][iAttributeTail] = false; } } } } } } if (nBestAttributeHead >= 0) { // update network structure m_ParentSets[nBestAttributeHead].AddParent(nBestAttributeTail, m_Instances); if (m_ParentSets[nBestAttributeHead].GetNrOfParents() < m_nMaxNrOfParents) { // only bother updating scores if adding parent does not violate the upper bound on nr of parents fBaseScores[nBestAttributeHead] += fBestDeltaScore; for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) { bAddArcMakesSense[nBestAttributeHead][iAttributeTail] = AddArcMakesSense(nBestAttributeHead, iAttributeTail); if (bAddArcMakesSense[nBestAttributeHead][iAttributeTail]) { fScore[nBestAttributeHead][iAttributeTail] = CalcScoreWithExtraParent(nBestAttributeHead, iAttributeTail); } } } bProgress = true; } } } // buildStructure /** * AddArcMakesSense checks whether adding the arc from iAttributeTail to iAttributeHead * does not already exists and does not introduce a cycle * * @param iAttributeHead index of the attribute that becomes head of the arrow * @param iAttributeTail index of the attribute that becomes tail of the arrow * @return true if adding arc is allowed, otherwise false */ private boolean AddArcMakesSense(int iAttributeHead, int iAttributeTail) { if (iAttributeHead == iAttributeTail) { return false; } // sanity check: arc should not be in parent set already for (int iParent = 0; iParent < m_ParentSets[iAttributeHead].GetNrOfParents(); iParent++) { if (m_ParentSets[iAttributeHead].GetParent(iParent) == iAttributeTail) { return false; } } // sanity check: arc should not introduce a cycle int nNodes = m_Instances.numAttributes(); boolean[] bDone = new boolean[nNodes]; for (int iNode = 0; iNode < nNodes; iNode++) { bDone[iNode] = false; } // check for cycles m_ParentSets[iAttributeHead].AddParent(iAttributeTail, m_Instances); for (int iNode = 0; iNode < nNodes; iNode++) { // find a node for which all parents are 'done' boolean bFound = false; for (int iNode2 = 0; !bFound && iNode2 < nNodes; iNode2++) { if (!bDone[iNode2]) { boolean bHasNoParents = true; for (int iParent = 0; iParent < m_ParentSets[iNode2].GetNrOfParents(); iParent++) { if (!bDone[m_ParentSets[iNode2].GetParent(iParent)]) { bHasNoParents = false; } } if (bHasNoParents) { bDone[iNode2] = true; bFound = true; } } } if (!bFound) { m_ParentSets[iAttributeHead].DeleteLastParent(m_Instances); return false; } } m_ParentSets[iAttributeHead].DeleteLastParent(m_Instances); return true; } // AddArcMakesCycle /** * This will return a string describing the classifier. * @return The string. */ public String globalInfo() { return "This Bayes Network learning algorithm uses a hill climbing algorithm" + " without restriction on the order of variables"; } /** * Main method for testing this class. * * @param argv the options */ public static void main(String[] argv) { try { System.out.println(Evaluation.evaluateModel(new BayesNetB(), argv)); } catch (Exception e) { e.printStackTrace(); System.err.println(e.getMessage()); } } // main } // class BayesNetB