/**
* Copyright (c) 2011 Michael Kutschke.
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/legal/epl-v10.html
*
* Contributors:
* Michael Kutschke - initial API and implementation.
*/
package org.eclipse.recommenders.jayes.sampling;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.eclipse.recommenders.jayes.BayesNet;
import org.eclipse.recommenders.jayes.BayesNode;
import org.eclipse.recommenders.jayes.util.BayesNodeUtil;
/**
 * Sampler that produces one outcome name per node of a {@link BayesNet}.
 * <p>
 * Nodes are visited in an order derived from their child links (see
 * {@link #topsort(List)}), so that a node is sampled only after the nodes it is a
 * child of. Nodes present in the evidence map keep their evidence value and are not
 * sampled.
 * <p>
 * Typical use: {@link #setNetwork(BayesNet)}, optionally
 * {@link #setEvidence(Map)} and {@link #seed(long)}, then {@link #sample()} repeatedly.
 */
public class BasicSampler implements ISampler {

    /** Nodes of the current network in sampling order; {@code null} until a network is set. */
    private List<BayesNode> topologicallySortedNodes;

    /** Fixed node outcomes copied into every sample; never {@code null}. */
    private Map<BayesNode, String> evidence = Collections.emptyMap();

    private final Random random = new Random();

    /**
     * Draws one sample over all nodes of the current network.
     *
     * @return a new map from every node (evidence nodes included) to an outcome name
     * @throws IllegalStateException
     *             if no network has been set via {@link #setNetwork(BayesNet)}
     */
    @Override
    public Map<BayesNode, String> sample() {
        if (topologicallySortedNodes == null) {
            throw new IllegalStateException("No network set; call setNetwork(BayesNet) before sample()");
        }
        // Start from the evidence so that later nodes are sampled given those values.
        Map<BayesNode, String> result = new HashMap<BayesNode, String>(evidence);
        for (BayesNode node : topologicallySortedNodes) {
            if (!evidence.containsKey(node)) {
                int outcomeIndex = sampleOutcome(node, result);
                result.put(node, node.getOutcomeName(outcomeIndex));
            }
        }
        return result;
    }

    /**
     * Picks an outcome index for {@code node} by accumulating the probabilities returned
     * by {@link BayesNodeUtil#getSubCpt} and returning the first index whose cumulative
     * sum exceeds a uniform random number from {@code [0, 1)}.
     *
     * @param node
     *            the node to sample
     * @param currentSample
     *            outcomes chosen so far; passed through to {@code getSubCpt}
     *            (presumably it selects the distribution matching the parents' outcomes
     *            - confirm against BayesNodeUtil)
     * @return an outcome index, at most {@code node.getOutcomeCount() - 1}
     */
    private int sampleOutcome(BayesNode node, Map<BayesNode, String> currentSample) {
        double[] probs = BayesNodeUtil.getSubCpt(node, currentSample);
        double threshold = random.nextDouble();
        double cumulative = 0;
        int index = 0;
        for (double prob : probs) {
            cumulative += prob;
            if (threshold < cumulative) {
                break;
            }
            index++;
        }
        // If the cumulative sum never exceeds the threshold (e.g. rounding leaves the
        // total slightly below 1), index runs past the end; clamp to the last outcome.
        return Math.min(index, node.getOutcomeCount() - 1);
    }

    /**
     * Sets the network to sample from and precomputes the node visiting order.
     *
     * @param net
     *            the network whose nodes are to be sampled
     */
    @Override
    public void setNetwork(BayesNet net) {
        topologicallySortedNodes = topsort(net.getNodes());
    }

    /**
     * Orders the given nodes by reversed depth-first finishing order over child links.
     * Assuming the child links form an acyclic graph, every node precedes its children
     * in the result.
     *
     * @param list
     *            the nodes to order
     * @return a new list containing every node reachable from {@code list} exactly once
     */
    private List<BayesNode> topsort(List<BayesNode> list) {
        List<BayesNode> finished = new ArrayList<BayesNode>(list.size());
        Set<BayesNode> visited = new HashSet<BayesNode>();
        for (BayesNode node : list) {
            depthFirstSearch(node, visited, finished);
        }
        // Nodes were appended after their children; flip to get parents first.
        Collections.reverse(finished);
        return finished;
    }

    /**
     * Recursively visits {@code n} and its not-yet-visited descendants, appending each
     * node to {@code finished} once all of its children have been handled.
     *
     * @param n
     *            the node to visit
     * @param visited
     *            nodes already seen; updated in place
     * @param finished
     *            output list in finishing order; updated in place
     */
    private void depthFirstSearch(BayesNode n, Set<BayesNode> visited, List<BayesNode> finished) {
        if (visited.contains(n)) {
            return;
        }
        visited.add(n);
        for (BayesNode child : n.getChildren()) {
            depthFirstSearch(child, visited, finished);
        }
        finished.add(n);
    }

    /**
     * Sets the outcomes that are held fixed in every subsequent sample.
     * <p>
     * The map is stored by reference, not copied. Passing {@code null} clears the
     * evidence, which is equivalent to passing an empty map.
     *
     * @param evidence
     *            node-to-outcome-name assignments, or {@code null} for none
     */
    @Override
    public void setEvidence(Map<BayesNode, String> evidence) {
        if (evidence == null) {
            this.evidence = Collections.emptyMap();
        } else {
            this.evidence = evidence;
        }
    }

    /**
     * Re-seeds the internal random number generator, making subsequent samples
     * reproducible for a given seed.
     *
     * @param seed
     *            the seed passed to {@link Random#setSeed(long)}
     */
    @Override
    public void seed(long seed) {
        random.setSeed(seed);
    }

    /**
     * Delegates to {@link #setNetwork(BayesNet)}.
     *
     * @param net
     *            the network whose nodes are to be sampled
     * @deprecated use {@link #setNetwork(BayesNet)} instead
     */
    @Override
    @Deprecated
    public void setBN(BayesNet net) {
        setNetwork(net);
    }
}