/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * BayesNetB2.java * Copyright (C) 2001 Remco Bouckaert * */ package weka.classifiers.bayes; import java.io.*; import java.util.*; import weka.core.*; import weka.estimators.*; import weka.classifiers.*; /** * Class for a Bayes Network classifier based on Buntines hill climbing algorithm for * learning structure, but augmented to allow arc reversal as an operation. * Works with nominal variables only. * * @author Remco Bouckaert (rrb@xm.co.nz) * @version $Revision: 1.1.1.1 $ */ public class BayesNetB2 extends BayesNetB { /** * buildStructure determines the network structure/graph of the network * with Buntines greedy hill climbing algorithm, restricted by its initial * structure (which can be an empty graph, or a Naive Bayes graph. */ public void buildStructure() throws Exception { // determine base scores double[] fBaseScores = new double[m_Instances.numAttributes()]; int nNrOfAtts = m_Instances.numAttributes(); for (int iAttribute = 0; iAttribute < nNrOfAtts; iAttribute++) { fBaseScores[iAttribute] = CalcNodeScore(iAttribute); } // Determine initial structure by finding a good parent-set for classification // node using greedy search int iAttribute = m_Instances.classIndex(); double fBestScore = fBaseScores[iAttribute]; // ///////////////////////////////////////////////////////////////////////////////////////// /* * int nBestAttribute1 = -1; * int nBestAttribute2 = -1; * for (int iAttribute1 = 0; iAttribute1 < m_Instances.numAttributes(); iAttribute1++) { * if (iAttribute != iAttribute1) { * for (int iAttribute2 = 0; iAttribute2 < iAttribute1; iAttribute2++) { * if (iAttribute != iAttribute2) { * m_ParentSets[iAttribute].AddParent(iAttribute1, m_Instances); * double fScore = CalcScoreWithExtraParent(iAttribute, iAttribute2); * m_ParentSets[iAttribute].DeleteLastParent(m_Instances); * if (fScore > fBestScore) { * fBestScore = fScore; * nBestAttribute1 = iAttribute1; * nBestAttribute2 = iAttribute2; * } * } * } * } * } * if (nBestAttribute1 != -1) { * m_ParentSets[iAttribute].AddParent(nBestAttribute1, m_Instances); * m_ParentSets[iAttribute].AddParent(nBestAttribute2, m_Instances); * fBaseScores[iAttribute] = fBestScore; * System.out.println("Added " + nBestAttribute1 + " & " + nBestAttribute2); * } */ int m_nMaxNrOfClassifierParents = 4; // ///////////////////////////////////////////////////////////////////////////////////////// // double fBestScore = CalcNodeScore(iAttribute); boolean bProgress = true; while (bProgress && m_ParentSets[iAttribute].GetNrOfParents() < m_nMaxNrOfClassifierParents) { int nBestAttribute = -1; for (int iAttribute2 = 0; iAttribute2 < m_Instances.numAttributes(); iAttribute2++) { if (iAttribute != iAttribute2) { double fScore = CalcScoreWithExtraParent(iAttribute, iAttribute2); if (fScore > fBestScore) { fBestScore = fScore; nBestAttribute = iAttribute2; } } } if (nBestAttribute != -1) { m_ParentSets[iAttribute].AddParent(nBestAttribute, m_Instances); fBaseScores[iAttribute] = fBestScore; } else { bProgress = false; } } // Recalc Base scores // Correction for Naive Bayes structures: delete arcs from classification node to children for (int iParent = 0; iParent < m_ParentSets[iAttribute].GetNrOfParents(); iParent++) { int nParentNode = m_ParentSets[iAttribute].GetParent(iParent); if (IsArc(nParentNode, iAttribute)) { m_ParentSets[nParentNode].DeleteLastParent(m_Instances); } // recalc base scores fBaseScores[nParentNode] = CalcNodeScore(nParentNode); } // super.buildStructure(); // Do algorithm B from here onwards // cache scores & whether adding an arc makes sense boolean[][] bAddArcMakesSense = new boolean[nNrOfAtts][nNrOfAtts]; double[][] fScore = new double[nNrOfAtts][nNrOfAtts]; for (int iAttributeHead = 0; iAttributeHead < nNrOfAtts; iAttributeHead++) { if (m_ParentSets[iAttributeHead].GetNrOfParents() < m_nMaxNrOfParents) { // only bother maintaining scores if adding parent does not violate the upper bound on nr of parents for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) { bAddArcMakesSense[iAttributeHead][iAttributeTail] = AddArcMakesSense(iAttributeHead, iAttributeTail); if (bAddArcMakesSense[iAttributeHead][iAttributeTail]) { fScore[iAttributeHead][iAttributeTail] = CalcScoreWithExtraParent(iAttributeHead, iAttributeTail); } } } } bProgress = true; // go do the hill climbing while (bProgress) { bProgress = false; int nBestAttributeTail = -1; int nBestAttributeHead = -1; double fBestDeltaScore = 0.0; // find best arc to add for (int iAttributeHead = 0; iAttributeHead < nNrOfAtts; iAttributeHead++) { if (m_ParentSets[iAttributeHead].GetNrOfParents() < m_nMaxNrOfParents) { for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) { if (bAddArcMakesSense[iAttributeHead][iAttributeTail]) { // System.out.println("gain " + iAttributeTail + " -> " + iAttributeHead + ": "+ (fScore[iAttributeHead][iAttributeTail] - fBaseScores[iAttributeHead])); if (fScore[iAttributeHead][iAttributeTail] - fBaseScores[iAttributeHead] > fBestDeltaScore) { if (AddArcMakesSense(iAttributeHead, iAttributeTail)) { fBestDeltaScore = fScore[iAttributeHead][iAttributeTail] - fBaseScores[iAttributeHead]; nBestAttributeTail = iAttributeTail; nBestAttributeHead = iAttributeHead; } else { bAddArcMakesSense[iAttributeHead][iAttributeTail] = false; } } } } } } if (nBestAttributeHead >= 0) { // update network structure // System.out.println("Added " + nBestAttributeTail + " -> " + nBestAttributeHead); m_ParentSets[nBestAttributeHead].AddParent(nBestAttributeTail, m_Instances); if (m_ParentSets[nBestAttributeHead].GetNrOfParents() < m_nMaxNrOfParents) { // only bother updating scores if adding parent does not violate the upper bound on nr of parents fBaseScores[nBestAttributeHead] += fBestDeltaScore; // System.out.println(fScore[nBestAttributeHead][nBestAttributeTail] + " " + fBaseScores[nBestAttributeHead] + " " + fBestDeltaScore); for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) { bAddArcMakesSense[nBestAttributeHead][iAttributeTail] = AddArcMakesSense(nBestAttributeHead, iAttributeTail); if (bAddArcMakesSense[nBestAttributeHead][iAttributeTail]) { fScore[nBestAttributeHead][iAttributeTail] = CalcScoreWithExtraParent(nBestAttributeHead, iAttributeTail); // System.out.println(iAttributeTail + " -> " + nBestAttributeHead + ": " + fScore[nBestAttributeHead][iAttributeTail]); } } } bProgress = true; } } } // buildStructure /** * IsArc checks whether the arc from iAttributeTail to iAttributeHead already exists * * @param index of the attribute that becomes head of the arrow * @param index of the attribute that becomes tail of the arrow */ private boolean IsArc(int iAttributeHead, int iAttributeTail) { for (int iParent = 0; iParent < m_ParentSets[iAttributeHead].GetNrOfParents(); iParent++) { if (m_ParentSets[iAttributeHead].GetParent(iParent) == iAttributeTail) { return true; } } return false; } // IsArc /** * AddArcMakesSense checks whether adding the arc from iAttributeTail to iAttributeHead * does not already exists and does not introduce a cycle * * @param index of the attribute that becomes head of the arrow * @param index of the attribute that becomes tail of the arrow */ private boolean AddArcMakesSense(int iAttributeHead, int iAttributeTail) { if (iAttributeHead == iAttributeTail) { return false; } // sanity check: arc should not be in parent set already if (IsArc(iAttributeHead, iAttributeTail)) { return false; } // sanity check: arc should not introduce a cycle int nNodes = m_Instances.numAttributes(); boolean[] bDone = new boolean[nNodes]; for (int iNode = 0; iNode < nNodes; iNode++) { bDone[iNode] = false; } // check for cycles m_ParentSets[iAttributeHead].AddParent(iAttributeTail, m_Instances); for (int iNode = 0; iNode < nNodes; iNode++) { // find a node for which all parents are 'done' boolean bFound = false; for (int iNode2 = 0; !bFound && iNode2 < nNodes; iNode2++) { if (!bDone[iNode2]) { boolean bHasNoParents = true; for (int iParent = 0; iParent < m_ParentSets[iNode2].GetNrOfParents(); iParent++) { if (!bDone[m_ParentSets[iNode2].GetParent(iParent)]) { bHasNoParents = false; } } if (bHasNoParents) { bDone[iNode2] = true; bFound = true; } } } if (!bFound) { m_ParentSets[iAttributeHead].DeleteLastParent(m_Instances); return false; } } m_ParentSets[iAttributeHead].DeleteLastParent(m_Instances); return true; } // AddArcMakesCycle /** * ReverseArcMakesCycle checks whether the arc from iAttributeTail to * iAttributeHead exists and reversing does not introduce a cycle * * @param index of the attribute that is head of the arrow * @param index of the attribute that is tail of the arrow */ private boolean ReverseArcMakesCycle(int iAttributeHead, int iAttributeTail) { if (iAttributeHead == iAttributeTail) { return false; } // sanity check: arc should be in parent set already if (!IsArc(iAttributeHead, iAttributeTail)) { return false; } // sanity check: arc should not introduce a cycle int nNodes = m_Instances.numAttributes(); boolean[] bDone = new boolean[nNodes]; for (int iNode = 0; iNode < nNodes; iNode++) { bDone[iNode] = false; } // check for cycles m_ParentSets[iAttributeTail].AddParent(iAttributeHead, m_Instances); for (int iNode = 0; iNode < nNodes; iNode++) { // find a node for which all parents are 'done' boolean bFound = false; for (int iNode2 = 0; !bFound && iNode2 < nNodes; iNode2++) { if (!bDone[iNode2]) { boolean bHasNoParents = true; for (int iParent = 0; iParent < m_ParentSets[iNode2].GetNrOfParents(); iParent++) { if (!bDone[m_ParentSets[iNode2].GetParent(iParent)]) { // this one has a parent which is not 'done' UNLESS it is the arc to be reversed if (iNode2 != iAttributeHead || m_ParentSets[iNode2].GetParent(iParent) != iAttributeTail) { bHasNoParents = false; } } } if (bHasNoParents) { bDone[iNode2] = true; bFound = true; } } } if (!bFound) { m_ParentSets[iAttributeTail].DeleteLastParent(m_Instances); return false; } } m_ParentSets[iAttributeTail].DeleteLastParent(m_Instances); return true; } // ReverseArcMakesCycle /** * Main method for testing this class. * * @param argv the options */ public static void main(String[] argv) { try { System.out.println(Evaluation.evaluateModel(new BayesNetB2(), argv)); } catch (Exception e) { e.printStackTrace(); System.err.println(e.getMessage()); } } // main } // class BayesNetB2