package org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.impl;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.exceptions.CountsNotCompleteException;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.BayesLearner;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.BayesianProbabilitiesEstimator;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.Candidate;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.ConditionalProbabilityTable;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.CountTable;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.RandomVariable;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.SearchConsumer;
/**
 * Learns the structure and parameters of a Bayesian network by running a
 * {@link BasicGreedyHillClimber} search for a fixed amount of time and then
 * asking the configured {@link BayesianProbabilitiesEstimator} for one
 * {@link ConditionalProbabilityTable} per random variable of the best
 * candidate reported by the search.
 *
 * <p>This class acts as the {@link SearchConsumer} of the search it starts:
 * the search reports results through {@link #notifyNewSearchResult}, and the
 * latest result flagged as absolute best is kept for building the final
 * network.</p>
 */
public class BayesLearnerImpl implements SearchConsumer, BayesLearner{

    private static final Logger logger = LoggerFactory.getLogger(BayesLearnerImpl.class);

    /** Engine that stores measurements and provides counts, priors and CPTs. */
    private final BayesianProbabilitiesEstimator learningEngine;

    /**
     * Clone of the last candidate reported as absolute best. Written by
     * {@link #notifyNewSearchResult} and read by {@code runLearning}; volatile
     * so the write is visible if the search invokes the callback from another
     * thread (presumably the case, since {@code runLearning} sleeps while the
     * search is running).
     */
    private volatile BayesianNetworkCandidate bestCandidate;

    /**
     * Creates a learner backed by a {@link SimpleCountingLearningEngine}.
     */
    public BayesLearnerImpl() {
        this.learningEngine = new SimpleCountingLearningEngine();
    }

    /**
     * Creates a learner backed by the given estimator.
     *
     * @param alt the estimator to use instead of the default engine
     */
    public BayesLearnerImpl(BayesianProbabilitiesEstimator alt) {
        this.learningEngine = alt;
    }

    /**
     * Learns a network from raw data: the data is turned into joint
     * measurements, the learning engine is reset and fed with them, and a
     * search is run starting from a candidate over all random variables.
     *
     * @param millisecs how long the search is allowed to run, in milliseconds
     * @param data the data passed to
     *        {@code SimpleJointMeasurement.computeFromData}
     * @return one CPT per random variable of the best candidate found, or
     *         {@code null} if no best candidate is available
     */
    public Map<RandomVariable, ConditionalProbabilityTable> runLearning(int millisecs, String data){
        // rvMap is handed in empty; its values are used as the node set
        // afterwards, so computeFromData is expected to fill it.
        Map<String, RandomVariable> rvMap = new HashMap<String, RandomVariable>();
        SimpleJointMeasurement[] sjmFile = SimpleJointMeasurement.computeFromData(rvMap, data);

        Set<RandomVariable> allnodesset = new HashSet<RandomVariable>(rvMap.values());

        this.learningEngine.resetTables();
        this.learningEngine.clearMeasurements();

        // Must be read AFTER the node set has been populated; reading it from
        // the still-empty set always yielded a parent limit of 0.
        int defaultmaxNumberParentsPerNode = allnodesset.size();

        for (int r = 0; r < sjmFile.length; r++) {
            logger.debug("Adding sjm to Learning engine: {}", sjmFile[r]);
            this.learningEngine.addMeasurement(sjmFile[r]);
        }

        BayesianNetworkCandidate starterBN = new BayesianNetworkCandidate(
                this.learningEngine, allnodesset, defaultmaxNumberParentsPerNode);
        return runLearning(millisecs, starterBN, defaultmaxNumberParentsPerNode);
    }

    /**
     * Runs the greedy hill-climbing search from the given starting point for
     * {@code millisecs} milliseconds, stops it, and builds the CPTs of the
     * best candidate reported so far.
     *
     * @param millisecs how long the calling thread sleeps before stopping the
     *        search, in milliseconds
     * @param startingPoint the candidate to start from; must be a
     *        {@link BayesianNetworkCandidate}
     * @param maxNumberParentsPerNode parent limit handed to the candidates
     *        generator
     * @return one CPT per random variable of the best candidate, or
     *         {@code null} if {@code startingPoint} is {@code null} or not a
     *         {@link BayesianNetworkCandidate}, or if no best candidate is
     *         available
     */
    public Map<RandomVariable, ConditionalProbabilityTable> runLearning(int millisecs, Candidate startingPoint, int maxNumberParentsPerNode){
        // instanceof is false for null, so this also covers the null case.
        if (!(startingPoint instanceof BayesianNetworkCandidate)) {
            return null;
        }

        BayesianNetworkCandidatesGenerator bcg = new BayesianNetworkCandidatesGenerator(this.learningEngine,
                ((BayesianNetworkCandidate) startingPoint).getNodes(), maxNumberParentsPerNode);
        bcg.initialise(startingPoint);

        BasicGreedyHillClimber bghc = new BasicGreedyHillClimber(bcg, this);
        bghc.startSearch();
        try {
            Thread.sleep(millisecs);
        } catch (InterruptedException e) {
            // Keep the interrupt visible to callers; the search is still
            // stopped below and whatever was found so far is returned.
            Thread.currentThread().interrupt();
            logger.warn("Interrupted while waiting for the search to run", e);
        }
        bghc.stopSearch();
        logger.debug("Stop! at: "+System.currentTimeMillis());

        // NOTE(review): bestCandidate is not cleared before the search starts,
        // so a result from an earlier run may be used if this run reports no
        // absolute best -- confirm whether that is intended.
        // Read the field once so the whole network is built from one candidate.
        BayesianNetworkCandidate best = this.bestCandidate;
        if (best == null) {
            return null;
        }
        return buildConditionalProbabilityTables(best);
    }

    /**
     * Same as {@link #runLearning(int, Candidate, int)} with the parent limit
     * set to {@code startingPoint.candidateSize()}.
     *
     * @param millisecs how long the search is allowed to run, in milliseconds
     * @param startingPoint the candidate to start from
     * @return the learned CPTs, or {@code null} under the same conditions as
     *         the three-argument overload (including a {@code null}
     *         {@code startingPoint})
     */
    public Map<RandomVariable, ConditionalProbabilityTable> runLearning(int millisecs, Candidate startingPoint) {
        // Mirror the three-argument overload instead of failing with a
        // NullPointerException on candidateSize().
        if (startingPoint == null) {
            return null;
        }
        return runLearning(millisecs, startingPoint, startingPoint.candidateSize());
    }

    /**
     * Search callback. Every reported candidate is passed to
     * {@code BasicGreedyHillClimber.updateResultFiles} for
     * {@code closefound.txt}; a candidate flagged as absolute best is
     * additionally passed for {@code bestfound.txt} and a clone of it is
     * remembered as the current best candidate.
     */
    public void notifyNewSearchResult(Candidate newCandidate,
            double oldBestscore, boolean stoppedExternally, int counter,
            long genCounter, long randomRestartsCounter, boolean isAbsoluteBest,
            boolean foundSignificantlyBetter) {

        if (isAbsoluteBest) {
            String bestFile = "."+java.io.File.separator+"bestfound.txt";
            BasicGreedyHillClimber.updateResultFiles(bestFile, oldBestscore,
                    newCandidate, counter, genCounter, randomRestartsCounter,
                    true);
            this.bestCandidate = (BayesianNetworkCandidate) newCandidate
                    .cloneCandidate();
        }

        String closeFile = "."+java.io.File.separator+"closefound.txt";
        BasicGreedyHillClimber.updateResultFiles(closeFile, oldBestscore,
                newCandidate, counter, genCounter, randomRestartsCounter,
                !foundSignificantlyBetter);
    }

    /**
     * @return the estimator this learner was constructed with
     */
    public BayesianProbabilitiesEstimator getLearningEngine() {
        return this.learningEngine;
    }

    /**
     * Builds the CPTs for an already known network structure without running
     * a search. The candidate's fitness is computed and logged first.
     *
     * @param bestCandidate the network structure to compute parameters for
     * @return one CPT per random variable of the candidate, or {@code null}
     *         if {@code bestCandidate} is {@code null}
     */
    public Map<RandomVariable, ConditionalProbabilityTable> learnParametersOnly(BayesianNetworkCandidate bestCandidate) {
        logger.info("BestCandidate = "+bestCandidate);
        if (bestCandidate == null) {
            return null;
        }

        double logScore = bestCandidate.computeFitness();
        logger.info("fitness: "+logScore);

        return buildConditionalProbabilityTables(bestCandidate);
    }

    /**
     * Asks the learning engine for the CPT of every random variable in the
     * candidate's segments, using uniform priors derived from the candidate's
     * {@code getN_equiv()} value. Variables for which the engine throws
     * {@link CountsNotCompleteException} are logged and left out of the result.
     */
    private Map<RandomVariable, ConditionalProbabilityTable> buildConditionalProbabilityTables(BayesianNetworkCandidate candidate) {
        Map<RandomVariable, BNSegment> segments = candidate.getSegments();
        Map<RandomVariable, ConditionalProbabilityTable> finalNW = new HashMap<RandomVariable, ConditionalProbabilityTable>();

        logger.debug("Is RV iterator non-empty (should be the case)? "+!segments.isEmpty());

        for (Map.Entry<RandomVariable, BNSegment> entry : segments.entrySet()) {
            RandomVariable rv = entry.getKey();
            BNSegment seg = entry.getValue();

            // The counts are requested unconditionally, as before, and only
            // logged here.
            CountTable ct = this.learningEngine.getCounts(rv, seg.getOrderedParents());
            logger.debug("{}", ct);

            try {
                ConditionalProbabilityTable cpt = this.learningEngine.getCPT(
                        rv, seg.getOrderedParents(),
                        this.learningEngine.getUniformPriors(
                                candidate.getN_equiv(),
                                rv, seg.getOrderedParents()));
                finalNW.put(rv, cpt);
            } catch (CountsNotCompleteException e) {
                // Best effort: skip this variable but keep the rest.
                logger.error("Counts not complete for " + rv + "; no CPT computed for it", e);
            }
        }
        return finalNW;
    }
}