package com.alibaba.simpleimage.analyze.search.tree;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import com.alibaba.simpleimage.analyze.search.cluster.ClusterBuilder;
import com.alibaba.simpleimage.analyze.search.cluster.Clusterable;
import com.alibaba.simpleimage.analyze.search.cluster.impl.Cluster;
import com.alibaba.simpleimage.analyze.search.util.ClusterUtils;
import com.alibaba.simpleimage.analyze.search.util.TreeUtils;
/**
 * A node of a hierarchical k-means ("vocabulary") tree.
 *
 * <p>Construction recursively splits the supplied items with a {@link ClusterBuilder} until a
 * stopping condition is met. Every leaf receives a unique id (its "word") drawn from
 * {@code KMeansTree.idCount}; interior nodes keep the id {@code -1}. After construction,
 * {@link #getValueId(Clusterable)} descends the tree to map an item onto a word while counting
 * how many items passed through each node.
 *
 * <p>NOTE(review): leaf ids come from the static counter {@code KMeansTree.idCount}, which is
 * incremented here without any synchronization, and {@code currentItems} is a plain int. Building
 * or querying trees from several threads at once is therefore presumably unsafe — confirm against
 * {@code KMeansTree} before doing so.
 */
public class KMeansTreeNode implements Clusterable, Serializable {

    private static final long serialVersionUID = 1L;

    /** Child nodes; an empty list for leaves. */
    private List<KMeansTreeNode> subNodes;
    /** Whether this node is a leaf, i.e. carries a word id. */
    private boolean isLeafNode = false;
    /** Center vector given at construction (for child nodes: the mean of their cluster). */
    private float[] center;
    /** The depth of the node from the root. */
    private int height = 0;
    /** Total number of items with a path through this node at build time, or the "weight". */
    private int numSubItems;
    /** Number of items routed through this node by {@link #getValueId} since the last reset. */
    private int currentItems;
    /** The unique id for the node in the tree, AKA the "word"; {@code -1} for interior nodes. */
    private int id = -1;

    /**
     * Builds this node and, recursively, its whole subtree, without any distance-based stopping.
     *
     * <p>Equivalent to
     * {@link #KMeansTreeNode(float[], List, int, int, int, ClusterBuilder, float)} with a
     * {@code minMeanDistance} of {@code 0}.
     *
     * @param center the center vector of this node
     * @param items the items routed through this node; its size becomes the node's weight
     * @param branchFactor the number of clusters requested from {@code clusterBuilder} per split
     * @param maxHeight the depth at which a node becomes a leaf
     * @param height the depth of this node
     * @param clusterBuilder the clustering strategy used to split {@code items}
     */
    public KMeansTreeNode(float[] center, List<Clusterable> items,
            int branchFactor, int maxHeight, int height, ClusterBuilder clusterBuilder) {
        this(center, items, branchFactor, maxHeight, height, clusterBuilder, 0f);
    }

    /**
     * Builds this node and, recursively, its whole subtree.
     *
     * <p>The node becomes a leaf when {@code height == maxHeight}, when it holds fewer than
     * {@code branchFactor} items, or (only if {@code minMeanDistance > 0}) when the mean Euclidean
     * distance of its items to {@code center} is below {@code minMeanDistance}. Otherwise the items
     * are split into clusters and one child is built per non-empty {@link Cluster}. If the split
     * yields no usable cluster, the node is turned into a leaf so that it still maps items to a
     * valid id instead of {@code -1}.
     *
     * @param center the center vector of this node
     * @param items the items routed through this node; its size becomes the node's weight
     * @param branchFactor the number of clusters requested from {@code clusterBuilder} per split
     * @param maxHeight the depth at which a node becomes a leaf
     * @param height the depth of this node
     * @param clusterBuilder the clustering strategy used to split {@code items}
     * @param minMeanDistance stop splitting once the items are at least this tightly grouped
     *     around {@code center}; values {@code <= 0} disable this criterion
     */
    public KMeansTreeNode(float[] center, List<Clusterable> items,
            int branchFactor, int maxHeight, int height, ClusterBuilder clusterBuilder,
            float minMeanDistance) {
        this.height = height;
        this.center = center;
        this.numSubItems = items.size();

        // The distance pass is O(items * dims); only pay for it when the criterion is enabled.
        boolean stopSplitting = height == maxHeight
                || items.size() < branchFactor
                || (minMeanDistance > 0 && getMeanDist(items, center) < minMeanDistance);
        if (stopSplitting) {
            makeLeaf();
            return;
        }

        Clusterable[] clusters = clusterBuilder.collect(items, branchFactor);
        subNodes = new ArrayList<KMeansTreeNode>(branchFactor);
        for (Clusterable candidate : clusters) {
            // Only concrete Cluster results expose their member items; skip anything else.
            if (!(candidate instanceof Cluster)) {
                continue;
            }
            Cluster cluster = (Cluster) candidate;
            List<Clusterable> members = cluster.getItems();
            if (members.isEmpty()) {
                continue;
            }
            subNodes.add(new KMeansTreeNode(cluster.getClusterMean(), members,
                    branchFactor, maxHeight, height + 1, clusterBuilder, minMeanDistance));
        }

        if (subNodes.isEmpty()) {
            // Degenerate split: without this the node would be an interior node with no
            // children, and getValueId would answer -1 for every item routed here.
            makeLeaf();
        }
    }

    /** Marks this node as a leaf and assigns it the next word id from the global counter. */
    private void makeLeaf() {
        isLeafNode = true;
        subNodes = new ArrayList<KMeansTreeNode>(0);
        id = KMeansTree.idCount++;
    }

    /**
     * Returns the mean Euclidean distance of {@code items} to {@code center}, or {@code 0} when
     * {@code items} is empty (instead of the NaN a 0/0 division would produce).
     */
    private static float getMeanDist(List<Clusterable> items, float[] center) {
        if (items.isEmpty()) {
            return 0f;
        }
        float sum = 0;
        for (Clusterable clusterItem : items) {
            float dist = ClusterUtils.getEuclideanDistance(clusterItem.getLocation(), center);
            sum += dist;
        }
        return sum / items.size();
    }

    /** Returns whether this node is a leaf. */
    public boolean isLeafNode() {
        return isLeafNode;
    }

    /**
     * Returns the child nodes (empty for leaves). The internal list is returned directly, so
     * changes to it affect this node.
     */
    public List<KMeansTreeNode> getSubNodes() {
        return subNodes;
    }

    /** Returns the center vector of this node. */
    public float[] getLocation() {
        return center;
    }

    /** Returns the number of items that passed through this node at build time. */
    public int getNumSubItems() {
        return numSubItems;
    }

    /** Returns the depth of this node from the root. */
    public int getHeight() {
        return height;
    }

    /** Returns the word id of this node, or {@code -1} for interior nodes. */
    public int getId() {
        return id;
    }

    /**
     * Routes {@code c} down the tree to a word id, incrementing the item counter of every node it
     * passes through (see {@link #getCurrentItemCount()}).
     *
     * <p>At each node the child chosen by {@code TreeUtils.findNearestNodeIndex} is followed. When
     * that index is negative (presumably because the node has no children — confirm in
     * {@code TreeUtils}), this node's id is returned.
     *
     * @param c the item to map onto a word
     * @return the id of the node where the descent stopped
     */
    public int getValueId(Clusterable c) {
        currentItems++;
        int index = TreeUtils.findNearestNodeIndex(subNodes, c);
        if (index >= 0) {
            KMeansTreeNode node = subNodes.get(index);
            return node.getValueId(c);
        }
        return id;
    }

    /** Returns how many items {@link #getValueId} has routed through this node since the last reset. */
    public int getCurrentItemCount() {
        return currentItems;
    }

    /** Clears the routed-item counter of this node and, recursively, of its whole subtree. */
    public void reset() {
        currentItems = 0;
        for (KMeansTreeNode node : subNodes) {
            node.reset();
        }
    }

    @Override
    public String toString() {
        return "KMeansTreeNode [isLeafNode=" + isLeafNode + ", center="
                + Arrays.toString(center) + ", height=" + height
                + ", numSubItems=" + numSubItems + ", currentItems="
                + currentItems + ", id=" + id + "]";
    }
}