/** * KMeansBSP.java */ package com.chinamobile.bcbsp.examples.kmeans; import java.util.ArrayList; import java.util.Iterator; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import com.chinamobile.bcbsp.api.BSP; import com.chinamobile.bcbsp.api.Edge; import com.chinamobile.bcbsp.bspstaff.BSPStaffContextInterface; import com.chinamobile.bcbsp.bspstaff.SuperStepContextInterface; import com.chinamobile.bcbsp.comm.BSPMessage; import com.chinamobile.bcbsp.util.BSPJob; /** * KMeansBSP * This is the user-defined arithmetic which implements {@link BSP}. * Implements the basic k-means algorithm. * * @author Bai Qiushi * @version 0.2 2012-2-28 */ public class KMeansBSP extends BSP { public static final Log LOG = LogFactory.getLog(KMeansBSP.class); public static final String KMEANS_K = "kmeans.k"; public static final String KMEANS_CENTERS = "kmeans.centers"; public static final String AGGREGATE_KCENTERS = "aggregate.kcenters"; private BSPJob jobconf; private int superStepCount; private int K; private int dimension; private ArrayList<ArrayList<Float>> kCenters = new ArrayList<ArrayList<Float>>(); //The threshold for average error between the new k centers and the last k centers. private final double errors_Threshold = 0.01; //The real average error between the new k centers and the last k centers. private double errors = Double.MAX_VALUE; @SuppressWarnings("unchecked") @Override public void compute(Iterator<BSPMessage> messages, BSPStaffContextInterface context) throws Exception { jobconf = context.getJobConf(); superStepCount = context.getCurrentSuperStepCounter(); ArrayList<Float> thisPoint = new ArrayList<Float>(); KMVertex thisVertex = (KMVertex) context.getVertex(); Iterator<Edge> outgoingEdges = context.getOutgoingEdges(); //Init this point while (outgoingEdges.hasNext()) { KMEdge edge = (KMEdge) outgoingEdges.next(); thisPoint.add( Float.valueOf(edge.getEdgeValue()) ); } //Calculate the class tag of this vertex. byte tag = 0; double minDistance = Double.MAX_VALUE; //Find the shortest distance of this point with the kCenters. for (byte i = 0; i < kCenters.size(); i ++) { ArrayList<Float> center = kCenters.get(i); double dist = distanceOf(thisPoint, center); if (dist < minDistance) { tag = i; minDistance = dist; } } //Write the vertex's class tag into the vertex value. thisVertex.setVertexValue(tag); context.updateVertex(thisVertex); if (this.errors < this.errors_Threshold) { context.voltToHalt(); } }//end-compute private double distanceOf(ArrayList<Float> p1, ArrayList<Float> p2) { double dist = 0.0; // dist = (x1-y1)^2 + (x2-y2)^2 + ... + (xn-yn)^2 for (int i = 0; i < p1.size(); i ++) { dist = dist + (p1.get(i) - p2.get(i)) * (p1.get(i) - p2.get(i)); } dist = Math.sqrt(dist); return dist; } @Override public void initBeforeSuperStep(SuperStepContextInterface context) { this.superStepCount = context.getCurrentSuperStepCounter(); jobconf = context.getJobConf(); if (superStepCount == 0) { this.K = Integer.valueOf(jobconf.get(KMeansBSP.KMEANS_K)); //Init the k original centers from job conf. String originalCenters = jobconf.get(KMeansBSP.KMEANS_CENTERS); String[] centers = originalCenters.split("\\|"); for (int i = 0; i < centers.length; i ++) { ArrayList<Float> center = new ArrayList<Float>(); String[] values = centers[i].split("-"); for (int j = 0; j < values.length; j ++) { center.add(Float.valueOf(values[j])); } kCenters.add(center); } this.dimension = kCenters.get(0).size(); LOG.info("[KMeansBSP] K = " + K); LOG.info("[KMeansBSP] dimension = " + dimension); LOG.info("[KMeansBSP] k centers: "); for (int i = 0; i < K; i ++) { String tmpCenter = ""; for (int j = 0; j < dimension; j ++) { tmpCenter = tmpCenter + " " + kCenters.get(i).get(j); } LOG.info("[KMeansBSP] <" + tmpCenter + " >"); } } else { KCentersAggregateValue kCentersAgg = (KCentersAggregateValue) context.getAggregateValue(KMeansBSP.AGGREGATE_KCENTERS); ArrayList<ArrayList<Float>> newKCenters = new ArrayList<ArrayList<Float>>(); //Calculate the new k centers and save them to newKCenters. ArrayList<ArrayList<Float>> contents = kCentersAgg.getValue(); ArrayList<Float> nums = contents.get(K); for (int i = 0; i < K; i ++) { ArrayList<Float> center = new ArrayList<Float>(); //Get the sum of coordinates of points in class i. ArrayList<Float> sum = contents.get(i); //Get the number of points in class i. float num = nums.get(i); for (int j = 0; j < dimension; j ++) { //the center's coordinate value. center.add(sum.get(j)/num); } //The i center. newKCenters.add(center); } this.errors = 0.0; //Calculate the errors sum between the new k centers and the last k centers. for (int i = 0; i < K; i ++) { for (int j = 0; j < dimension; j ++) { this.errors = this.errors + Math.abs(kCenters.get(i).get(j) - newKCenters.get(i).get(j)); } } this.errors = this.errors / (K * dimension); this.kCenters.clear(); this.kCenters = newKCenters; LOG.info("[KMeansBSP] k centers: "); for (int i = 0; i < K; i ++) { String tmpCenter = "[" + nums.get(i) + "]"; for (int j = 0; j < dimension; j ++) { tmpCenter = tmpCenter + " " + kCenters.get(i).get(j); } LOG.info("[KMeansBSP] <" + tmpCenter + " >"); } } LOG.info("[KMeansBSP]******* Error = " + errors + " ********"); } }