/* * File State.java * * Copyright (C) 2010 Remco Bouckaert remco@cs.auckland.ac.nz * * This file is part of BEAST2. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package beast.core; import java.io.ByteArrayInputStream; import java.io.File; import java.io.IOException; import java.io.PrintStream; import java.nio.file.Files; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import javax.xml.parsers.DocumentBuilderFactory; import javax.xml.parsers.ParserConfigurationException; import org.w3c.dom.Document; import org.w3c.dom.Node; import org.w3c.dom.NodeList; import org.xml.sax.SAXException; import beast.core.util.Log; @Description("The state represents the current point in the state space, and " + "maintains values of a set of StateNodes, such as parameters and trees. " + "Furthermore, the state manages which parts of the model need to be stored/restored " + "and notified that recalculation is appropriate.") public class State extends BEASTObject { public final Input<List<StateNode>> stateNodeInput = new Input<>("stateNode", "anything that is part of the state", new ArrayList<>()); final public Input<Integer> m_storeEvery = new Input<>("storeEvery", "store the state to disk every X number of samples so that we can " + "resume computation later on if the process failed half-way.", -1); // public Input<Boolean> m_checkPoint = // new Input<>("checkpoint", "keep saved states (every X samples).", false); /** * The components of the state, for instance tree & parameters. * This represents the current state, but a copy is kept so that when * an operation is applied to the State but the proposal is not accepted, * the state can be restored. This is currently implemented by having * Operators call getEditableStateNode() at which point the requested * StateNode is copied. * Access through getNrStatNodes() and getStateNode(.). */ protected StateNode[] stateNode; /** * number of state nodes * */ private int nrOfStateNodes; public int getNrOfStateNodes() { return nrOfStateNodes; } /** * pointers to memory allocated to stateNodes and storedStateNodes * */ private StateNode[] stateNodeMem; /** * File name used for storing the state, either periodically or at the end of an MCMC chain * so that the chain can be resumed. */ private String stateFileName = "state.backup.xml"; /** The following members are involved in calculating the set of * CalculatioNodes that need to be notified when an operation * has been applied to the State. The Calculation nodes are then * store/restore/accepted/check dirtiness in partial order. */ /** * Maps a BEASTObject to a list of Outputs. * This map only contains those plug-ins that have a path to the posterior * */ private HashMap<BEASTInterface, List<BEASTInterface>> outputMap; /** * Same as m_outputMap, but only for StateNodes indexed by the StateNode number * We need this since the StateNode changes regularly, so unlike the output map * for BEASTObjects cannot be accessed by the current StateNode as key. */ private List<CalculationNode>[] stateNodeOutputs; /** * Code that represents configuration of StateNodes that have changed * during an operation. * <p/> * Every time an operation requests a StateNode, an entry is added to changeStateNodes * changeStateNodes records how many StateNodes are changed. * The code is reset when the state is stored, and every time a StateNode * is requested by an operator, changeStateNodes is updated. */ private int[] changeStateNodes; private int nrOfChangedStateNodes; /** * Maps the changed states node code to * the set of calculation nodes that is potentially affected by an operation * */ Trie trie; /** * class for quickly finding which calculation nodes need to be updated * due to state-node changes */ class Trie { List<CalculationNode> list; final Trie[] children; Trie() { children = new Trie[stateNode.length]; } /** * get entry from Trie, return null if no entry is present yet * * @param pos */ List<CalculationNode> get(final int pos) { if (pos == 0) { return list; } final Trie child = children[changeStateNodes[pos - 1]]; if (child == null) { return null; } return child.get(pos - 1); } /** * set entry int Trie, create new entries if no entry is present yet * */ void set(final List<CalculationNode> list, final int pos) { if (pos == 0) { this.list = list; return; } Trie child = children[changeStateNodes[pos - 1]]; if (child == null) { child = new Trie(); children[changeStateNodes[pos - 1]] = child; } child.set(list, pos - 1); } } @Override public void initAndValidate() { } public void initialise() { stateNode = stateNodeInput.get().toArray(new StateNode[0]); for (int i = 0; i < stateNode.length; i++) { stateNode[i].index = i; } // make itself known for (StateNode state : stateNode) { state.state = this; } nrOfStateNodes = stateNode.length; // allocate memory for StateNodes and a copy. stateNodeMem = new StateNode[nrOfStateNodes * 2]; for (int i = 0; i < nrOfStateNodes; i++) { stateNodeMem[i] = stateNode[i]; stateNodeMem[nrOfStateNodes + i] = stateNodeMem[i].copy(); } // set up data structure for encoding which StateNodes change by an operation changeStateNodes = new int[stateNode.length]; //Arrays.fill(changeStateNodes, -1); nrOfChangedStateNodes = 0; trie = new Trie(); // add the empty list for the case none of the StateNodes have changed trie.list = new ArrayList<>(); } // initAndValidate /** * return currently valid state node. This is typically called from a * CalculationNode for inspecting the value of a StateNode, not for * changing it. To change a StateNode, say from an Operator, * getEditableStateNode() should be called. * */ public StateNode getStateNode(final int _id) { return stateNode[_id]; } /** * Return StateNode that can be changed, but later restored * if necessary. If there is no copy stored already, a copy is * made first, and the StateNode is marked as being dirty. * <p/> * NB This should only be called from an Operator that wants to * change the particular StateNode through the Input.get(Operator) * method on the input associated with this StateNode. */ protected StateNode getEditableStateNode(int _id, Operator operator) { for (int i = 0; i < nrOfChangedStateNodes; i++) { if (changeStateNodes[i] == _id) { return stateNode[_id]; } } changeStateNodes[nrOfChangedStateNodes++] = _id; return stateNode[_id]; } /** * Store a State before applying an operation proposal to the state. * This copies the state for possible later restoration * but does not affect any inputs, which are all still connected * to the original StateNodes * <p/> * Also, store the state to disk for resumption of analysis later on. * * @param sample chain state number * @return true if stored to disk */ public void store(final int sample) { //Arrays.fill(changeStateNodes, -1); nrOfChangedStateNodes = 0; } /** * Restore a State after rejecting the operation proposal. * This assigns the state to the stored state. * NB this does not affect any Inputs connected to any stateNode. * */ public void restore() { for (int i = 0; i < nrOfChangedStateNodes; i++) { stateNode[changeStateNodes[i]].restore(); } } /** * Visit all calculation nodes in partial order determined by the BEASTObject-input relations * (i.e. if A is input of B then A < B). There are 4 operations that can be propagated this * way: * <p/> * store() makes sure all calculation nodes store their internal state * <p/> * checkDirtiness() makes all calculation nodes check whether they give a different answer * when interrogated by one of its outputs * <p/> * accept() allows all calculation nodes to mark themselves as being clean without further * calculation * <p/> * restore() if a proposed state is not accepted, all calculation nodes need to restore * themselves */ public void storeCalculationNodes() { final List<CalculationNode> currentSetOfCalculationNodes = getCurrentCalculationNodes(); for (final CalculationNode calculationNode : currentSetOfCalculationNodes) { calculationNode.store(); } } public void checkCalculationNodesDirtiness() { final List<CalculationNode> currentSetOfCalculationNodes = getCurrentCalculationNodes(); for (final CalculationNode calculationNode : currentSetOfCalculationNodes) { calculationNode.checkDirtiness(); } } public void restoreCalculationNodes() { final List<CalculationNode> currentSetOfCalculationNodes = getCurrentCalculationNodes(); for (final CalculationNode calculationNode : currentSetOfCalculationNodes) { calculationNode.restore(); } } public void acceptCalculationNodes() { final List<CalculationNode> currentSetOfCalculationNodes = getCurrentCalculationNodes(); for (final CalculationNode calculationNode : currentSetOfCalculationNodes) { calculationNode.accept(); } } /** * set name of state file, used when storing/restoring the state to disk * */ public void setStateFileName(final String fileName) { if (fileName != null) { stateFileName = fileName; } } /** * Print state to file. This is called either periodically or at the end * of an MCMC chain, so that the state can be resumed later on. * * @param sample TODO */ public void storeToFile(final int sample) { try { PrintStream out = new PrintStream(stateFileName + ".new"); out.print(toXML(sample)); //out.print(new XMLProducer().toXML(this)); out.close(); File newStateFile = new File(stateFileName + ".new"); File oldStateFile = new File(stateFileName); oldStateFile.delete(); // newStateFile.renameTo(oldStateFile); -- unstable under windows Files.move(newStateFile.toPath(), oldStateFile.toPath(), java.nio.file.StandardCopyOption.REPLACE_EXISTING); } catch (Exception e) { e.printStackTrace(); } } /** * convert state to XML string, * The state can be reconstructed using the fromXML() method * * @param sample TODO* */ public String toXML(final int sample) { final StringBuilder buf = new StringBuilder(); buf.append("<itsabeastystatewerein version='2.0' sample='").append(sample).append("'>\n"); for (final StateNode node : stateNode) { buf.append(node.toXML()); } buf.append("</itsabeastystatewerein>\n"); return buf.toString(); } /** * Restore state from an XML fragment * */ public void fromXML(final String xml) { try { final DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance(); final Document doc = factory.newDocumentBuilder().parse(new ByteArrayInputStream(xml.getBytes())); doc.normalize(); final NodeList nodes = doc.getElementsByTagName("*"); final Node topNode = nodes.item(0); final NodeList children = topNode.getChildNodes(); for (int childIndex = 0; childIndex < children.getLength(); childIndex++) { final Node child = children.item(childIndex); if (child.getNodeType() == Node.ELEMENT_NODE) { final String id = child.getAttributes().getNamedItem("id").getNodeValue(); int stateNodeIndex = 0; while (!stateNode[stateNodeIndex].getID().equals(id)) { stateNodeIndex++; } final StateNode stateNode2 = stateNode[stateNodeIndex].copy(); stateNode2.fromXML(child); stateNode[stateNodeIndex].assignFromFragile(stateNode2); } } } catch (Exception e) { e.printStackTrace(); System.exit(1); } } /** * restore a state from file for resuming an MCMC chain * @throws ParserConfigurationException * @throws IOException * @throws SAXException * */ public void restoreFromFile() throws SAXException, IOException, ParserConfigurationException { Log.info.println("Restoring from file"); DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance(); Document doc = factory.newDocumentBuilder().parse(new File(stateFileName)); doc.normalize(); final NodeList nodes = doc.getElementsByTagName("*"); final Node topNode = nodes.item(0); final NodeList children = topNode.getChildNodes(); for (int childIndex = 0; childIndex < children.getLength(); childIndex++) { final Node child = children.item(childIndex); if (child.getNodeType() == Node.ELEMENT_NODE) { Node idNode = child.getAttributes().getNamedItem("id"); if (idNode != null) { final String id = idNode.getNodeValue(); int stateNodeIndex = 0; // An init node without ID - should not bring the house down, does it? // I have not checked if the state is restored correctly or not (JH) while (stateNode[stateNodeIndex].getID() != null && !stateNode[stateNodeIndex].getID().equals(id)) { stateNodeIndex++; if (stateNodeIndex >= stateNode.length) { Log.warning.println("Cannot restore statenode id " + id + " -- item is ignored"); break; } } if (stateNodeIndex < stateNode.length) { final StateNode stateNode2 = stateNode[stateNodeIndex].copy(); stateNode2.fromXML(child); stateNode[stateNodeIndex].assignFromFragile(stateNode2); } } else { Log.warning.println("Cannot restore statenode without id -- item is ignored"); } } } } @Override public String toString() { if (stateNode == null) { return ""; } final StringBuilder buf = new StringBuilder(); for (final StateNode node : stateNode) { buf.append(node.toString()); buf.append("\n"); } return buf.toString(); } /** * Set dirtiness to all StateNode, this means that * apart from marking all StateNode.someThingIsDirty as isDirty * parameters mark all their dimension as isDirty and * trees mark all their nodes as isDirty. */ public void setEverythingDirty(final boolean isDirty) { for (final StateNode node : stateNode) { node.setEverythingDirty(isDirty); } if (isDirty) { // happens only during debugging and start of MCMC chain for (int i = 0; i < stateNode.length; i++) { changeStateNodes[i] = i; } nrOfChangedStateNodes = stateNode.length; } } /** * Sets the posterior, needed to calculate paths of CalculationNode * that need store/restore/requireCalculation checks. * As a side effect, outputs for every beastObject in the model are calculated. * NB the output map only contains outputs on a path to the posterior BEASTObject! */ @SuppressWarnings("unchecked") public void setPosterior(BEASTObject posterior) { // first, calculate output map that maps BEASTObjects on a path // to the posterior to the list of output BEASTObjects. Strictly // speaking, this is a bit of overkill, since only // CalculationNodes need to be taken in account, but for // debugging purposes (developer forgot to derive from CalculationNode) // we keep track of the lot. outputMap = new HashMap<>(); outputMap.put(posterior, new ArrayList<>()); boolean progress = true; List<BEASTInterface> beastObjects = new ArrayList<>(); beastObjects.add(posterior); while (progress) { progress = false; // loop over plug-ins, till no more plug-ins can be added // efficiency is no issue here for (int i = 0; i < beastObjects.size(); i++) { BEASTInterface beastObject = beastObjects.get(i); try { for (BEASTInterface inputBEASTObject : beastObject.listActiveBEASTObjects()) { if (!outputMap.containsKey(inputBEASTObject)) { outputMap.put(inputBEASTObject, new ArrayList<>()); beastObjects.add(inputBEASTObject); progress = true; } if (!outputMap.get(inputBEASTObject).contains(beastObject)) { outputMap.get(inputBEASTObject).add(beastObject); progress = true; } } } catch (Exception e) { e.printStackTrace(); } } } // Set of array of StateNode outputs. Since the StateNodes have a potential // to be changing objects (when store/restore is applied) it is necessary // to use another method to find the outputs, an array in this case. stateNodeOutputs = new List[stateNode.length]; for (int i = 0; i < stateNode.length; i++) { stateNodeOutputs[i] = new ArrayList<>(); if (outputMap.containsKey(stateNode[i])) { for (BEASTInterface beastObject : outputMap.get(stateNode[i])) { if (beastObject instanceof CalculationNode) { stateNodeOutputs[i].add((CalculationNode) beastObject); } else { throw new RuntimeException("DEVELOPER ERROR: output of StateNode (" + stateNode[i].getID() + ") should be a CalculationNode, but " + beastObject.getClass().getName() + " is not."); } } } else { Log.warning.println("\nWARNING: StateNode (" + stateNode[i].getID() + ") found that has no effect on posterior!\n"); } } } // setPosterior /** * return current set of calculation nodes based on the set of StateNodes that have changed * */ private List<CalculationNode> getCurrentCalculationNodes() { List<CalculationNode> calcNodes = trie.get(nrOfChangedStateNodes); if (calcNodes != null) { // the list is pre-calculated return calcNodes; } // we need to calculate the list of CalculationNodes now try { calcNodes = calculateCalcNodePath(); } catch (Exception e) { e.printStackTrace(); System.exit(1); } trie.set(calcNodes, nrOfChangedStateNodes); // System.err.print(Arrays.toString(changeStateNodes) + ":"); // for (CalculationNode node : calcNodes) { // Log.warning.print(node.m_sID + " "); // } // System.err.println(); return calcNodes; } // getCurrentCalculationNodes /** * Collect all CalculationNodes on a path from any StateNode that is changed (as * indicated by m_changedStateNodeCode) to the posterior. Return the list in * partial order as determined by the BEASTObjects input relations. * @throws IllegalAccessException * @throws IllegalArgumentException */ private List<CalculationNode> calculateCalcNodePath() throws IllegalArgumentException, IllegalAccessException { final List<CalculationNode> calcNodes = new ArrayList<>(); // for (int i = 0; i < stateNode.length; i++) { // if (m_changedStateNodeCode.get(i)) { for (int k = 0; k < nrOfChangedStateNodes; k++) { int i = changeStateNodes[k]; // go grab the path to the Runnable // first the outputs of the StateNodes that is changed boolean progress = false; for (CalculationNode node : stateNodeOutputs[i]) { if (!calcNodes.contains(node)) { calcNodes.add(node); progress = true; } } // next the path following the outputs while (progress) { progress = false; // loop over beastObjects till no more beastObjects can be added // efficiency is no issue here, assuming the graph remains // constant for (int calcNodeIndex = 0; calcNodeIndex < calcNodes.size(); calcNodeIndex++) { CalculationNode node = calcNodes.get(calcNodeIndex); for (BEASTInterface output : outputMap.get(node)) { if (output instanceof CalculationNode) { final CalculationNode calcNode = (CalculationNode) output; if (!calcNodes.contains(calcNode)) { calcNodes.add(calcNode); progress = true; } } else { throw new RuntimeException("DEVELOPER ERROR: found a" + " non-CalculatioNode (" +output.getClass().getName() +") on path between StateNode and Runnable"); } } } } // } } // put calc nodes in partial order for (int i = 0; i < calcNodes.size(); i++) { CalculationNode node = calcNodes.get(i); List<BEASTInterface> inputList = node.listActiveBEASTObjects(); for (int j = calcNodes.size() - 1; j > i; j--) { if (inputList.contains(calcNodes.get(j))) { // swap final CalculationNode node2 = calcNodes.get(j); calcNodes.set(j, node); calcNodes.set(i, node2); j = 0; i--; } } } return calcNodes; } // calculateCalcNodePath public double robustlyCalcPosterior(final Distribution posterior) { store(-1); setEverythingDirty(true); //state.storeCalculationNodes(); checkCalculationNodesDirtiness(); final double logLikelihood = posterior.calculateLogP(); setEverythingDirty(false); acceptCalculationNodes(); return logLikelihood; } public double robustlyCalcNonStochasticPosterior(Distribution posterior) { store(-1); setEverythingDirty(true); storeCalculationNodes(); checkCalculationNodesDirtiness(); final double logLikelihood = posterior.getNonStochasticLogP(); setEverythingDirty(false); acceptCalculationNodes(); return logLikelihood; } } // class State