/**
*
*/
package rampancy.util.data.segmentTree;
import java.util.Arrays;
import rampancy.util.*;
import rampancy.util.data.RSegmentFunction;
/**
 * A node of the segment tree.
 *
 * <p>Each node carries an array of guess factors, a visit counter, its depth in the tree
 * and, once it has branched, up to {@code segmentSize} child nodes. The child for a given
 * robot state is chosen by the node's {@link RSegmentFunction}.
 *
 * <p>Nodes built with the two-argument constructor have no root reference, no branch array
 * and no guess factors; within this class they are only used as a source of
 * {@code segmentFunction}/{@code segmentSize} for {@link #newInstance}.
 *
 * @author Matthew Chun-Lum
 */
public class RSTNode {

    /**
     * Records a visit on {@code node} and updates every entry of its guess-factor array
     * through {@code RUtil.rollingAvg}.
     *
     * <p>The entry at {@code guessFactorIndex} is rolled toward {@code 1.0}; an entry
     * {@code d} positions away is rolled toward {@code 1 / (d*d + 1)}. The depth argument
     * handed to {@code rollingAvg} is {@code min(visit count after this visit, rollDepth)}.
     * NOTE(review): the exact averaging formula lives in {@code RUtil.rollingAvg}, which is
     * not visible here - confirm its semantics there.
     *
     * @param node             the node whose guess factors are updated in place
     * @param guessFactorIndex index of the entry that receives the full target value
     * @param weight           weight forwarded unchanged to {@code RUtil.rollingAvg}
     * @param rollDepth        upper bound on the depth passed to {@code RUtil.rollingAvg}
     */
    public static void updateGuessFactors(RSTNode node, int guessFactorIndex, double weight, int rollDepth) {
        node.noteVisit();
        double[] guessFactors = node.getGuessFactors();
        // The depth is identical for every entry, so compute it once.
        int depth = Math.min(node.getVisitCount(), rollDepth);

        guessFactors[guessFactorIndex] = RUtil.rollingAvg(guessFactors[guessFactorIndex], 1.0, depth, weight);
        for (int i = 0; i < guessFactors.length; i++) {
            if (i == guessFactorIndex) {
                continue; // already updated above
            }
            // Target falls off with the squared distance from the selected index.
            double target = 1.0 / (Math.pow(guessFactorIndex - i, 2) + 1.0);
            guessFactors[i] = RUtil.rollingAvg(guessFactors[i], target, depth, weight);
        }
    }

    private RSegmentTree rootReference;
    private final RSegmentFunction segmentFunction;
    private RSTNode[] branches;
    private double[] guessFactors;
    private boolean hasBranched;
    private final int segmentSize;
    private int visits;
    private int depth;

    /**
     * Creates a node that only carries a segment function and a segment size.
     * The root reference, branch array and guess factors are left unset.
     *
     * @param segmentFunction function used to pick a branch index for a state
     * @param segmentSize     number of branches a node of this kind can have
     */
    public RSTNode(RSegmentFunction segmentFunction, int segmentSize) {
        this.segmentFunction = segmentFunction;
        this.segmentSize = segmentSize;
    }

    /**
     * Creates a fully initialised tree node. This constructor is only used internally.
     *
     * @param segmentFunction function used to pick a branch index for a state
     * @param segmentSize     number of branch slots to allocate
     * @param seedFactors     initial guess factors; the array is copied, not shared
     * @param depth           depth of this node in the tree
     * @param visits          initial visit count
     * @param rootReference   the tree this node belongs to
     */
    public RSTNode(RSegmentFunction segmentFunction, int segmentSize, double[] seedFactors, int depth, int visits, RSegmentTree rootReference) {
        this.rootReference = rootReference;
        this.segmentFunction = segmentFunction;
        this.segmentSize = segmentSize;
        this.branches = new RSTNode[segmentSize];
        this.guessFactors = Arrays.copyOf(seedFactors, seedFactors.length);
        this.visits = visits;
        this.depth = depth;
    }

    /**
     * Builds a new node that reuses this node's segment function, segment size and
     * current visit count.
     *
     * @param seedFactors initial guess factors for the new node (copied)
     * @param depth       depth of the new node
     * @param reference   the tree the new node belongs to
     * @return the newly created node
     */
    public RSTNode newInstance(double[] seedFactors, int depth, RSegmentTree reference) {
        return new RSTNode(segmentFunction, segmentSize, seedFactors, depth, visits, reference);
    }

    /**
     * Walks down the tree and returns the deepest node that has not branched for
     * {@code state}. Missing children along the way are created on demand from the
     * root's template node for the next depth, seeded with this node's guess factors.
     *
     * @param state the robot state to classify
     * @return this node if it has not branched, otherwise the matching descendant
     */
    public RSTNode getSegmentForState(RRobotState state) {
        if (!hasBranched()) {
            return this;
        }
        int index = segmentFunction.getIndexForState(state, segmentSize);
        RSTNode child = branches[index];
        if (child == null) {
            int childDepth = depth + 1;
            RSTNode template = rootReference.getTemplateNodeForDepth(childDepth);
            child = template.newInstance(guessFactors, childDepth, rootReference);
            branches[index] = child;
        }
        return child.getSegmentForState(state);
    }

    /**
     * @return the number of branch slots of this node
     */
    public int getSegmentSize() {
        return segmentSize;
    }

    /**
     * Forces this node into the branched state.
     */
    public void setBranched() {
        hasBranched = true;
    }

    /**
     * Reports whether this node has branched. A node that has not branched yet becomes
     * branched once it has a segment function and its visit count reaches the root's
     * visits-before-branch threshold; the result is then remembered.
     *
     * <p>A node without a root reference cannot evaluate the threshold and therefore
     * only reports {@code true} after {@link #setBranched()}.
     *
     * @return {@code true} if this node has branched
     */
    public boolean hasBranched() {
        if (!hasBranched && segmentFunction != null && rootReference != null) {
            hasBranched = visits >= rootReference.getVisitsBeforeBranch();
        }
        return hasBranched;
    }

    /**
     * Increments the visit counter by one.
     */
    public void noteVisit() {
        visits++;
    }

    /**
     * @return the number of visits recorded on this node
     */
    public int getVisitCount() {
        return visits;
    }

    /**
     * @return the live guess-factor array of this node (not a copy)
     */
    public double[] getGuessFactors() {
        return guessFactors;
    }

    /**
     * Counts this node plus every existing node below it.
     *
     * @return the size of the subtree rooted at this node, at least 1
     */
    public int getBranchCount() {
        int count = 1;
        if (branches == null) {
            // Node built without a branch array: it has no children.
            return count;
        }
        for (RSTNode branch : branches) {
            if (branch != null) {
                count += branch.getBranchCount();
            }
        }
        return count;
    }

    /**
     * Counts the existing nodes in this subtree whose depth equals {@code maxDepth}.
     * Nodes that stop above {@code maxDepth} without children contribute nothing.
     *
     * @param maxDepth the depth at which a node is counted
     * @return the number of nodes at {@code maxDepth} reachable from this node
     */
    public int getTerminalBranchCount(int maxDepth) {
        if (depth == maxDepth) {
            return 1;
        }
        int count = 0;
        if (branches == null) {
            // Node built without a branch array: nothing below it to count.
            return count;
        }
        for (RSTNode branch : branches) {
            if (branch != null) {
                count += branch.getTerminalBranchCount(maxDepth);
            }
        }
        return count;
    }
}