package org.societies.context.user.refinement.test;
import java.io.File;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.exceptions.CountsNotCompleteException;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.impl.BNSegment;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.impl.BasicGreedyHillClimber;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.impl.BayesianNetworkCandidate;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.impl.BayesianNetworkCandidatesGenerator;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.impl.SimpleCountingLearningEngine;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.impl.SimpleJointMeasurement;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.BayesianLearningClient;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.BayesianProbabilitiesEstimator;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.Candidate;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.ConditionalProbabilityTable;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.CountTable;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.RandomVariable;
import org.societies.context.user.refinement.impl.bayesianLibrary.bayesianLearner.interfaces.SearchConsumer;
import org.societies.context.user.refinement.impl.bayesianLibrary.inference.structures.impl.DAG;
import org.societies.context.user.refinement.impl.tools.NetworkConverter;
/**
* Main class.
*/
public class ChainTester implements SearchConsumer {
private BayesianProbabilitiesEstimator learningEngine;
private BayesianNetworkCandidate startingPoint;
private BayesianNetworkCandidate bestCandidate;
private BayesianLearningClient learningClient;
private boolean debug=false;
public ChainTester(BayesianLearningClient learningClient) {
this.learningEngine = new SimpleCountingLearningEngine();
this.learningClient = learningClient;
}
/* only relevant for BayesianPreferences
public void setLearningData(List list){
Map<String, RandomVariable> rvMap = new HashMap<String, RandomVariable>();
SimpleJointMeasurement[] sjmFile = SimpleJointMeasurement.computeFromHistory(rvMap, list);
Set<RandomVariable>allnodesset = new HashSet();
allnodesset.clear();
this.learningEngine.resetTables();
this.learningEngine.clearMeasurements();
allnodesset.addAll(rvMap.values());
for (int r = 0; r < sjmFile.length; r++) {
// System.out.println("Adding sjm to Learning engine: " + sjmFile[r]);
this.learningEngine.addMeasurement(sjmFile[r]);
}
BayesianNetworkCandidate starterBN = new BayesianNetworkCandidate(
this.learningEngine, allnodesset);
startingPoint = starterBN;
}
*/
public void learnFromFile(String filename, int maxNumberParentsPerNode) {
learnFromFile(filename,60000, maxNumberParentsPerNode);
}
public Map learnFromFile(String filename, int milliseconds, int maxNumberParentsPerNode) {
Map<String, RandomVariable> rvMap = new HashMap<String, RandomVariable>();
//BayesianNetworkCandidate starterBN = importKidsData(rvMap);
//BayesianNetworkCandidate starterBN = importCmcdataData(rvMap);
//BayesianNetworkCandidate starterBN = importSprinklerData(rvMap);
//BayesianNetworkCandidate starterBN = importHousingData(rvMap);
//BayesianNetworkCandidate starterBN = importActivityLearningData(rvMap);
if(debug) System.out.println(".\\resources\\activityData\\"+filename);
// System.exit(0);
BayesianNetworkCandidate starterBN ;
if (filename==null || filename.equals(""))
starterBN = importData(rvMap,".\\resources\\sarah.txt");
else
starterBN = importData(rvMap,".\\resources\\activityData\\"+filename);
BasicGreedyHillClimber bghc = new BasicGreedyHillClimber(
new BayesianNetworkCandidatesGenerator(this.learningEngine,
starterBN.getNodes(), maxNumberParentsPerNode), this);
bghc.startSearch();
try {
Thread.sleep(milliseconds);
} catch (InterruptedException e) {
e.printStackTrace();
}
bghc.stopSearch();
if(debug) System.out.println("BGHC STOPPED!");
Map<RandomVariable, ConditionalProbabilityTable>finalNW = null;
if (this.bestCandidate != null) {
Map segments = this.bestCandidate.getSegments();
finalNW = new HashMap<RandomVariable, ConditionalProbabilityTable>();
Iterator it = segments.keySet().iterator();
while (it.hasNext()) {
RandomVariable rv = (RandomVariable) it.next();
BNSegment seg = (BNSegment) segments.get(rv);
CountTable ct = this.learningEngine.getCounts(rv, seg
.getOrderedParents());
if(debug) System.out.println(ct);
try {
ConditionalProbabilityTable cpt = this.learningEngine
.getCPT(rv, seg.getOrderedParents(),
this.learningEngine.getUniformPriors(
this.bestCandidate.getN_equiv(),
rv, seg.getOrderedParents()));
finalNW.put(rv, cpt);
/* Testing
if (rv.getName().equals("dow")) {
// RandomVariable appsRV = (RandomVariable) rvMap.get("apps");
RandomVariable volumeRV = (RandomVariable) rvMap.get("loc");
try {
// SimpleInstantiatedRV appsSIRV = new SimpleInstantiatedRV(appsRV, false, appsRV.getNodeValueFromText("pim"));
SimpleInstantiatedRV volumeSIRV = new SimpleInstantiatedRV(volumeRV, false, volumeRV.getNodeValueFromText("home"));
Set parents = new HashSet();
// parents.add(appsSIRV);
parents.add(volumeSIRV);
double cptvalue = cpt.getProbability(parents, rv.getNodeValueFromText("Fri"));
System.out.println("\n\n\n\n\n\n\nProbability for parents: " + //appsSIRV +
" UUUUU " + volumeSIRV + " : " + cptvalue+"\n\n\n\n\n\n\n");
// appsSIRV = new SimpleInstantiatedRV(appsRV, false, appsRV.getNodeValueFromText("media"));
volumeSIRV = new SimpleInstantiatedRV(volumeRV, false, volumeRV.getNodeValueFromText("canteen"));
parents = new HashSet();
// parents.add(appsSIRV);
parents.add(volumeSIRV);
cptvalue = cpt.getProbability(parents, rv.getNodeValueFromText("Sat"));
System.out.println("\n\n\n\n\n\n\nProbability for parents: " + //appsSIRV +
" UUUUU " + volumeSIRV + " : " + cptvalue+"\n\n\n\n\n\n\n");
// appsSIRV = new SimpleInstantiatedRV(appsRV, false, appsRV.getNodeValueFromText("pim"));
volumeSIRV = new SimpleInstantiatedRV(volumeRV, false, volumeRV.getNodeValueFromText("canteen"));
parents = new HashSet();
// parents.add(appsSIRV);
parents.add(volumeSIRV);
cptvalue = cpt.getProbability(parents, rv.getNodeValueFromText("Mo-Th"));
System.out.println("\n\n\n\n\n\n\nProbability for parents: " + //appsSIRV +
" UUUUU " + volumeSIRV + " : " + cptvalue+"\n\n\n\n\n\n\n");
// appsSIRV = new SimpleInstantiatedRV(appsRV, false, appsRV.getNodeValueFromText("none"));
volumeSIRV = new SimpleInstantiatedRV(volumeRV, false, volumeRV.getNodeValueFromText("office"));
parents = new HashSet();
// parents.add(appsSIRV);
parents.add(volumeSIRV);
cptvalue = cpt.getProbability(parents, rv.getNodeValueFromText("Mo-Th"));
System.out.println("\n\n\n\n\n\n\nProbability for parents: " + //appsSIRV +
" UUUUU " + volumeSIRV + " : " + cptvalue+"\n\n\n\n\n\n\n");
} catch (NodeValueIndexNotInNodeRangeException e) {
e.printStackTrace();
} catch (NodeValueTextNotInNodeRangeException e) {
e.printStackTrace();
} catch (ParentsNotContainedException e) {
e.printStackTrace();
} catch (PriorAndCountTablesMismatchException e) {
e.printStackTrace();
} catch (RVNotInstantiatedException e) {
e.printStackTrace();
}
}
/**/
} catch (CountsNotCompleteException e) {
e.printStackTrace();
}
}
if(debug) System.out.println("vor setNetwork, finalNW="+finalNW);
this.learningClient.setNetwork("Tester BN", finalNW);
if(debug) System.out.println("about to finish");
}
return finalNW;
}
/**
* @param rvMap
* @return
*/
private BayesianNetworkCandidate importCmcdataData(Map rvMap) {
SimpleJointMeasurement[] sjmFile = SimpleJointMeasurement
.computeFromDataFile(rvMap, ".\\resources\\cmcdata.txt");
Set<RandomVariable>allnodesset = new HashSet();
allnodesset.clear();
this.learningEngine.resetTables();
this.learningEngine.clearMeasurements();
allnodesset.addAll(rvMap.values());
for (int r = 0; r < sjmFile.length; r++) {
// System.out.println("Adding sjm to Learning engine: " + sjmFile[r]);
this.learningEngine.addMeasurement(sjmFile[r]);
}
BayesianNetworkCandidate starterBN = new BayesianNetworkCandidate(
this.learningEngine, allnodesset, allnodesset.size());
return starterBN;
}
/**
* @param rvMap
* @return
*/
private BayesianNetworkCandidate importActivityLearningData(Map rvMap) {
SimpleJointMeasurement[] sjmFile = SimpleJointMeasurement
.computeFromDataFile(rvMap, ".\\resources\\activityLearning simulation.txt");
Set<RandomVariable>allnodesset = new HashSet();
allnodesset.clear();
this.learningEngine.resetTables();
this.learningEngine.clearMeasurements();
allnodesset.addAll(rvMap.values());
for (int r = 0; r < sjmFile.length; r++) {
// System.out.println("Adding sjm to Learning engine: " + sjmFile[r]);
this.learningEngine.addMeasurement(sjmFile[r]);
}
BayesianNetworkCandidate starterBN = new BayesianNetworkCandidate(
this.learningEngine, allnodesset, allnodesset.size());
return starterBN;
}
/**
* @param rvMap
* @return
*/
private BayesianNetworkCandidate importHousingData(Map rvMap) {
SimpleJointMeasurement[] sjmFile = SimpleJointMeasurement
.computeFromDataFile(rvMap, ".\\resources\\housingdata.txt");
Set<RandomVariable>allnodesset = new HashSet();
allnodesset.clear();
this.learningEngine.resetTables();
this.learningEngine.clearMeasurements();
allnodesset.addAll(rvMap.values());
for (int r = 0; r < sjmFile.length; r++) {
// System.out.println("Adding sjm to Learning engine: " + sjmFile[r]);
this.learningEngine.addMeasurement(sjmFile[r]);
}
BayesianNetworkCandidate starterBN = new BayesianNetworkCandidate(
this.learningEngine, allnodesset, allnodesset.size());
return starterBN;
}
/**
* @param rvMap
* @return
*/
private BayesianNetworkCandidate importData(Map rvMap, String filename) {
SimpleJointMeasurement[] sjmFile = SimpleJointMeasurement
.computeFromDataFile(rvMap, filename);
Set<RandomVariable>allnodesset = new HashSet();
allnodesset.clear();
this.learningEngine.resetTables();
this.learningEngine.clearMeasurements();
allnodesset.addAll(rvMap.values());
for (int r = 0; r < sjmFile.length; r++) {
// System.out.println("Adding sjm to Learning engine: " + sjmFile[r]);
this.learningEngine.addMeasurement(sjmFile[r]);
}
BayesianNetworkCandidate starterBN = new BayesianNetworkCandidate(
this.learningEngine, allnodesset, allnodesset.size());
return starterBN;
}
/**
* @param rvMap
* @return
*/
private BayesianNetworkCandidate importDataFiles(Map rvMap, File[] files) {
SimpleJointMeasurement[] sjmFile = SimpleJointMeasurement
.computeFromDataFiles(rvMap, files);
Set<RandomVariable>allnodesset = new HashSet();
allnodesset.clear();
this.learningEngine.resetTables();
this.learningEngine.clearMeasurements();
allnodesset.addAll(rvMap.values());
for (int r = 0; r < sjmFile.length; r++) {
// System.out.println("Adding sjm to Learning engine: " + sjmFile[r]);
this.learningEngine.addMeasurement(sjmFile[r]);
}
BayesianNetworkCandidate starterBN = new BayesianNetworkCandidate(
this.learningEngine, allnodesset, allnodesset.size());
return starterBN;
}
/**
* @param rvMap
* @return
*/
private BayesianNetworkCandidate importSprinklerData(Map rvMap) {
SimpleJointMeasurement[] sjmFile = SimpleJointMeasurement
.computeFromDataFile(rvMap, ".\\resources\\b-courseSprinkler1.txt");
Set<RandomVariable>allnodesset = new HashSet();
allnodesset.clear();
this.learningEngine.resetTables();
this.learningEngine.clearMeasurements();
allnodesset.addAll(rvMap.values());
for (int r = 0; r < sjmFile.length; r++) {
System.out.println("Adding sjm to Learning engine: " + sjmFile[r]);
this.learningEngine.addMeasurement(sjmFile[r]);
}
BayesianNetworkCandidate starterBN = new BayesianNetworkCandidate(
this.learningEngine, allnodesset, allnodesset.size());
RandomVariable weather = (RandomVariable) rvMap.get("weather");
RandomVariable grass = (RandomVariable) rvMap.get("grass");
RandomVariable neighbourGrass = (RandomVariable) rvMap.get("ng");
RandomVariable mysprinkler = (RandomVariable) rvMap.get("mysprinkler");
// starterBN.addArc(grass, weather);
// starterBN.addArc(grass, mysprinkler);
// starterBN.addArc(neighbourGrass, weather);
return starterBN;
}
/**
* @param rvMap
* @return
*/
private BayesianNetworkCandidate importKidsData(Map rvMap) {
SimpleJointMeasurement[] sjmFile = SimpleJointMeasurement
.computeFromDataFile(rvMap, ".\\resources\\popkidsdata.txt");
Set<RandomVariable>allnodesset = new HashSet();
allnodesset.clear();
this.learningEngine.resetTables();
this.learningEngine.clearMeasurements();
allnodesset.addAll(rvMap.values());
for (int r = 0; r < sjmFile.length; r++) {
// System.out.println("Adding sjm to Learning engine: " + sjmFile[r]);
this.learningEngine.addMeasurement(sjmFile[r]);
}
BayesianNetworkCandidate starterBN = new BayesianNetworkCandidate(
this.learningEngine, allnodesset, allnodesset.size());
RandomVariable Goals = (RandomVariable) rvMap.get("Goals");
RandomVariable Grades = (RandomVariable) rvMap.get("Grades");
RandomVariable Grade = (RandomVariable) rvMap.get("Grade");
RandomVariable Race = (RandomVariable) rvMap.get("Race");
RandomVariable Gender = (RandomVariable) rvMap.get("Gender");
RandomVariable Looks = (RandomVariable) rvMap.get("Looks");
RandomVariable School = (RandomVariable) rvMap.get("School");
RandomVariable Sports = (RandomVariable) rvMap.get("Sports");
RandomVariable Urban_Rural = (RandomVariable) rvMap.get("Urban/Rural");
RandomVariable Age = (RandomVariable) rvMap.get("Age");
RandomVariable Money = (RandomVariable) rvMap.get("Money");
// starterBN.addArc(Race, Urban_Rural);
// starterBN.addArc(Urban_Rural, School);
// starterBN.addArc(School, Grade);
// starterBN.addArc(Age, Grade);
//
// starterBN.addArc(Grade, Grades);
// starterBN.addArc(Grades, Money);
// starterBN.addArc(Looks, Grades);
// starterBN.addArc(Looks, Money);
// starterBN.addArc(Gender, Sports);
// starterBN.addArc(Sports, Grades);
// starterBN.addArc(Sports, Looks);
// starterBN.addArc(Sports, Money);
// starterBN.addArc(Goals, Gender);
return starterBN;
}
public void notifyNewSearchResult(Candidate newCandidate,
double oldBestscore, boolean stoppedExternally, int counter,
long genCounter, long randomRestartsCounter, boolean isAbsoluteBest,
boolean foundSignificantlyBetter) {
String filename = null;
if (isAbsoluteBest) {
filename = ".\\resources\\bestfound.txt";
BasicGreedyHillClimber.updateResultFiles(filename, oldBestscore,
newCandidate, counter, genCounter, randomRestartsCounter,
true);
this.bestCandidate = (BayesianNetworkCandidate) newCandidate
.cloneCandidate();
}
filename = ".\\resources\\closefound.txt";
BasicGreedyHillClimber.updateResultFiles(filename, oldBestscore,
newCandidate, counter, genCounter, randomRestartsCounter,
!foundSignificantlyBetter);
}
public DAG learnDAGFromFiles(File[] files, int milliseconds, boolean naiveBayes, int maxNumberParentsPerNode) {
Map<String, RandomVariable>rvMap = new HashMap<String, RandomVariable>();
if(debug) if (files==null) System.out.println("files==null"); else System.out.println("number of files="+files.length);
BayesianNetworkCandidate starterBN ;
if (files==null || files.length==0)
starterBN = importDataFiles(rvMap,new File[]{new File(".\\resources\\sarah.txt")});
else
starterBN = importDataFiles(rvMap,files);
if (naiveBayes){
String cause = "activity";
RandomVariable causeRV = rvMap.get(cause);
for(RandomVariable effect:rvMap.values()){
if (effect!=causeRV) starterBN.addArc(effect, causeRV);
}
double fitness = starterBN.computeFitness();
boolean foundAbsoluteBest = true;
boolean foundSignificantBetter = true;
boolean stoppedExternally = false;
notifyNewSearchResult(starterBN, 0, stoppedExternally,
1, 1, 0,
foundAbsoluteBest, foundSignificantBetter);
}
else{
BasicGreedyHillClimber bghc = new BasicGreedyHillClimber(
new BayesianNetworkCandidatesGenerator(this.learningEngine,
starterBN.getNodes(), maxNumberParentsPerNode), this);
bghc.startSearch();
try {
Thread.sleep(milliseconds);
} catch (InterruptedException e) {
e.printStackTrace();
}
bghc.stopSearch();
if(debug) System.out.println("BGHC STOPPED!");
}
Map<RandomVariable, ConditionalProbabilityTable>finalNW = null;
DAG newNetwork = null;
if (this.bestCandidate != null) {
Map segments = this.bestCandidate.getSegments();
finalNW = new HashMap<RandomVariable, ConditionalProbabilityTable>();
Iterator it = segments.keySet().iterator();
while (it.hasNext()) {
RandomVariable rv = (RandomVariable) it.next();
BNSegment seg = (BNSegment) segments.get(rv);
CountTable ct = this.learningEngine.getCounts(rv, seg
.getOrderedParents());
if(debug) System.out.println(ct);
try {
ConditionalProbabilityTable cpt = this.learningEngine
.getCPT(rv, seg.getOrderedParents(),
this.learningEngine.getUniformPriors(
this.bestCandidate.getN_equiv(),
rv, seg.getOrderedParents()));
finalNW.put(rv, cpt);
} catch (CountsNotCompleteException e) {
e.printStackTrace();
}
}
newNetwork = NetworkConverter.convertStructures(finalNW);
}
return newNetwork;
}
}