package com.github.projectflink.hadoop;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.*;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import java.io.BufferedReader;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.*;
/*
* Code adapted from https://github.com/thomasjungblut/thomasjungblut-common/
*/
public class KMeansDriver implements Tool {
/** Logger for this driver. */
private static final Log LOG = LogFactory.getLog(KMeansDriver.class);
/**
 * Configuration key whose value is the path of the centroid sequence file.
 * It is set in {@code run} and read back by the mappers and the reducer.
 */
public static String CENTERS_CONF_KEY = "centroids.path";
/** Hadoop configuration; assigned via {@code setConf} or created on demand by {@code getConf}/{@code run}. */
private Configuration conf;
/**
 * Stores the Hadoop configuration this tool runs with.
 *
 * @param configuration the configuration to use; the reference is kept as-is
 */
@Override
public void setConf(Configuration configuration) {
    conf = configuration;
}
/**
 * Returns the configuration of this tool, creating a default one on first
 * use when none has been set.
 *
 * @return the (never {@code null}) configuration
 */
@Override
public Configuration getConf() {
    if (conf != null) {
        return conf;
    }
    conf = new Configuration();
    return conf;
}
/**
 * A point with {@code n} double coordinates that can be serialized by Hadoop.
 *
 * <p>Wire format (see {@link #write}): the dimension as an int followed by
 * {@code n} doubles. {@link #add} and {@link #div} mutate the receiver.
 */
public static class Point implements Writable {
    /** Coordinate values; {@code null} for a freshly default-constructed point. */
    double[] coord;
    /** Number of coordinates in use. */
    int n;

    /** Creates an empty point (dimension 0, no coordinate array) to be filled by {@link #readFields}. */
    public Point() {
        this.n = 0;
        this.coord = null;
    }

    /**
     * Wraps the given array without copying it.
     *
     * @param coord the coordinates; the array is stored by reference
     */
    public Point(double[] coord) {
        this.coord = coord;
        this.n = coord.length;
    }

    /**
     * Copy constructor. The coordinates are copied so that later in-place
     * operations ({@link #add}, {@link #div}) on either point cannot affect
     * the other one.
     *
     * @param p the point to copy
     */
    public Point(Point p) {
        this.n = p.n;
        this.coord = (p.coord == null) ? null : Arrays.copyOf(p.coord, p.coord.length);
    }

    /**
     * Adds {@code other} to this point coordinate-wise, in place.
     *
     * @param other the point to add; must have at least {@code n} coordinates
     * @return this point
     */
    public Point add(Point other) {
        for (int i = 0; i < n; i++) {
            coord[i] += other.coord[i];
        }
        return this;
    }

    /**
     * Divides every coordinate of this point by {@code d}, in place.
     *
     * @param d the divisor
     * @return this point
     */
    public Point div(double d) {
        for (int i = 0; i < n; i++) {
            coord[i] /= d;
        }
        return this;
    }

    /** Writes the dimension followed by each coordinate. */
    @Override
    public void write(DataOutput dataOutput) throws IOException {
        dataOutput.writeInt(n);
        for (int i = 0; i < n; i++) {
            dataOutput.writeDouble(coord[i]);
        }
    }

    /** Reads the dimension and the coordinates into a newly allocated array. */
    @Override
    public void readFields(DataInput dataInput) throws IOException {
        this.n = dataInput.readInt();
        this.coord = new double[n];
        for (int i = 0; i < n; i++) {
            this.coord[i] = dataInput.readDouble();
        }
    }

    /**
     * Computes the Euclidean distance between this point and {@code other}
     * over the first {@code n} coordinates.
     *
     * @param other the other point; must have at least {@code n} coordinates
     * @return the distance
     */
    public double euclideanDistance(Point other) {
        double sumOfSquares = 0.0;
        for (int i = 0; i < n; i++) {
            double diff = coord[i] - other.coord[i];
            sumOfSquares += diff * diff;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Returns a new {@code Point} with its own copy of the coordinates.
     *
     * @return an independent copy (always a plain {@code Point})
     */
    public Point deepCopy() {
        Point out = new Point();
        out.n = this.n;
        out.coord = new double[out.n];
        for (int i = 0; i < out.n; i++) {
            out.coord[i] = this.coord[i];
        }
        return out;
    }

    /** Returns the coordinates separated by single spaces. */
    @Override
    public String toString() {
        StringBuilder out = new StringBuilder();
        for (int i = 0; i < n; i++) {
            if (i > 0) {
                out.append(' ');
            }
            out.append(coord[i]);
        }
        return out.toString();
    }
}
/**
 * A cluster center: a {@link Point} tagged with an integer cluster id.
 *
 * <p>Ordering, equality and hashing are all based on the id only, so that a
 * centroid used as a MapReduce key is sorted, grouped and partitioned
 * consistently regardless of its current coordinates or of the JVM that
 * created it.
 */
public static class Centroid extends Point implements WritableComparable<Centroid> {
    private int id;

    /** Creates an empty centroid to be filled by {@link #readFields}. */
    public Centroid() {
    }

    /**
     * Creates a centroid at the position of {@code p} with the id {@code -1}.
     *
     * @param p the position
     */
    public Centroid(Point p) {
        super(p);
        this.id = -1;
    }

    /**
     * Copy constructor.
     *
     * @param c the centroid to copy (position and id)
     */
    public Centroid(Centroid c) {
        super(c);
        this.id = c.id;
    }

    /**
     * Creates a centroid with the given id at the position of {@code p}.
     *
     * @param id the cluster id
     * @param p  the position
     */
    public Centroid(int id, Point p) {
        super(p);
        this.id = id;
    }

    /** Writes the point data followed by the id. */
    @Override
    public void write(DataOutput dataOutput) throws IOException {
        super.write(dataOutput);
        dataOutput.writeInt(id);
    }

    /** Reads the point data followed by the id. */
    @Override
    public void readFields(DataInput dataInput) throws IOException {
        super.readFields(dataInput);
        this.id = dataInput.readInt();
    }

    /** Returns only the cluster id as text (the coordinates are not included). */
    @Override
    public String toString() {
        return Integer.toString(id);
    }

    /** Orders centroids by id. */
    @Override
    public int compareTo(Centroid o) {
        return Integer.compare(this.id, o.id);
    }

    /** Two centroids are equal when their ids are equal, consistent with {@link #compareTo}. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof Centroid)) {
            return false;
        }
        return this.id == ((Centroid) obj).id;
    }

    /**
     * Hash based on the id. The default identity hash differs between objects
     * and JVMs, which would make hash-based partitioning of this key type
     * send the same cluster to different partitions.
     */
    @Override
    public int hashCode() {
        return id;
    }
}
/**
 * Assigns every input point to its nearest centroid.
 *
 * <p>The centroids are loaded once per task from the sequence file whose path
 * is stored under {@code CENTERS_CONF_KEY}. The incoming key is ignored; the
 * output key is the nearest loaded centroid.
 */
public static class KMeansMapper extends Mapper<Centroid, Point, Centroid, Point> {
    private final List<Centroid> centers = new ArrayList<Centroid>();

    /**
     * Loads all centroids from the centroid sequence file.
     *
     * @throws IOException if the file cannot be opened or read
     */
    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
        super.setup(context);
        Configuration conf = context.getConfiguration();
        Path centroids = new Path(conf.get(CENTERS_CONF_KEY));
        FileSystem fs = FileSystem.get(conf);
        Centroid key = new Centroid();
        IntWritable value = new IntWritable();
        // try-with-resources: the reader is closed even when reading fails.
        try (SequenceFile.Reader reader = new SequenceFile.Reader(fs, centroids, conf)) {
            while (reader.next(key, value)) {
                // 'key' is reused by the reader, so keep a separate Centroid per record.
                // Point.readFields allocates a fresh array on every call, so the
                // stored centroids never share coordinates with a later record.
                centers.add(new Centroid(key));
            }
        }
    }

    /**
     * Emits {@code value} keyed by the centroid with the smallest Euclidean
     * distance; on ties the centroid loaded first wins.
     *
     * @throws IllegalStateException if no centroids were loaded
     */
    @Override
    protected void map(Centroid key, Point value, Context context) throws IOException, InterruptedException {
        Centroid nearest = null;
        double nearestDistance = Double.MAX_VALUE;
        for (Centroid c : centers) {
            double dist = value.euclideanDistance(c);
            // The first centroid is always accepted, later ones only if strictly closer.
            if (nearest == null || dist < nearestDistance) {
                nearest = c;
                nearestDistance = dist;
            }
        }
        if (nearest == null) {
            // Fail with a clear message instead of writing a null key.
            throw new IllegalStateException("No centroids loaded from "
                    + context.getConfiguration().get(CENTERS_CONF_KEY));
        }
        context.write(nearest, value);
    }
}
/**
 * Re-keys all points of one cluster with a centroid that carries the cluster
 * id and the coordinate-wise sum of those points.
 *
 * <p>Every incoming point is emitted again unchanged; only the key differs.
 */
public static class KMeansCombiner extends Reducer<Centroid, Point, Centroid, Point> {
    @Override
    protected void reduce(Centroid key, Iterable<Point> values, Context context) throws IOException, InterruptedException {
        // Capture the id before iterating over the values.
        final int clusterId = key.id;
        List<Point> buffered = new ArrayList<Point>();
        Point sum = null;
        Iterator<Point> it = values.iterator();
        while (it.hasNext()) {
            // Store a private copy of every value that is kept beyond this iteration.
            Point snapshot = it.next().deepCopy();
            buffered.add(snapshot);
            if (sum != null) {
                sum.add(snapshot);
            } else {
                // The accumulator needs its own array so that adding to it
                // does not modify the buffered point.
                sum = snapshot.deepCopy();
            }
        }
        Centroid summedKey = new Centroid(clusterId, sum);
        for (int i = 0; i < buffered.size(); i++) {
            context.write(summedKey, buffered.get(i));
        }
    }
}
/**
 * Computes the new center of every cluster as the mean of its points and
 * re-emits the points keyed by that new center.
 *
 * <p>All centers computed by this task are collected and written to the
 * centroid sequence file (path under {@code CENTERS_CONF_KEY}) in
 * {@link #cleanup}.
 */
public static class KMeansReducer extends KMeansCombiner {
    /** Centers computed by this task, in the order the clusters were reduced. */
    final List<Centroid> centers = new ArrayList<Centroid>();

    /**
     * Averages the points of one cluster, records the resulting center and
     * writes every point with the new center as key.
     */
    @Override
    protected void reduce(Centroid key, Iterable<Point> values, Context context) throws IOException, InterruptedException {
        List<Point> points = new ArrayList<Point>();
        int clusterId = key.id;
        Point newCenter = null;
        for (Point p : values) {
            // Keep a private copy of each value; they are needed again after the loop.
            Point copy = p.deepCopy();
            points.add(copy);
            if (newCenter == null) {
                // Separate array: the accumulator must not alias the buffered point.
                newCenter = copy.deepCopy();
            } else {
                newCenter.add(copy);
            }
        }
        if (newCenter == null) {
            // No values for this key: nothing to average or emit.
            return;
        }
        newCenter.div(points.size());
        Centroid center = new Centroid(clusterId, newCenter);
        centers.add(center);
        for (Point p : points) {
            context.write(center, p);
        }
    }

    /**
     * Writes the collected centers to the centroid sequence file, each paired
     * with a placeholder {@code IntWritable(0)} value.
     *
     * @throws IOException if the file cannot be created or written
     */
    @Override
    protected void cleanup(Context context) throws IOException, InterruptedException {
        super.cleanup(context);
        Configuration conf = context.getConfiguration();
        Path outPath = new Path(conf.get(CENTERS_CONF_KEY));
        FileSystem fs = FileSystem.get(conf);
        final IntWritable mockValue = new IntWritable(0);
        // try-with-resources: the writer is closed (and flushed) even if an append fails.
        try (SequenceFile.Writer writer =
                     SequenceFile.createWriter(fs, conf, outPath, Centroid.class, IntWritable.class)) {
            for (Centroid center : centers) {
                writer.append(center, mockValue);
            }
        }
    }
}
/**
 * Parses text lines of space-separated coordinates into points and assigns
 * each point to its nearest centroid.
 *
 * <p>The centroids are loaded once per task from the sequence file whose path
 * is stored under {@code CENTERS_CONF_KEY}.
 */
public static class CenterInitializer extends Mapper<LongWritable,Text,Centroid,Point> {
    private final List<Centroid> centers = new ArrayList<Centroid>();

    /**
     * Loads all centroids from the centroid sequence file.
     *
     * @throws IOException if the file cannot be opened or read
     */
    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
        super.setup(context);
        Configuration conf = context.getConfiguration();
        Path centroids = new Path(conf.get(CENTERS_CONF_KEY));
        FileSystem fs = FileSystem.get(conf);
        Centroid key = new Centroid();
        IntWritable value = new IntWritable();
        // try-with-resources: the reader is closed even when reading fails.
        try (SequenceFile.Reader reader = new SequenceFile.Reader(fs, centroids, conf)) {
            while (reader.next(key, value)) {
                // 'key' is reused by the reader; store a separate Centroid per record.
                centers.add(new Centroid(key));
            }
        }
    }

    /**
     * Converts one text line into a point and emits it keyed by the nearest
     * centroid. Lines without any token are skipped.
     *
     * @throws NumberFormatException if a token is not a valid double
     * @throws IllegalStateException if no centroids were loaded
     */
    @Override
    protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
        StringTokenizer tokenizer = new StringTokenizer(value.toString(), " ");
        int dim = tokenizer.countTokens();
        if (dim == 0) {
            // A blank line would otherwise become a zero-dimensional point.
            return;
        }
        double[] coords = new double[dim];
        for (int i = 0; i < dim; i++) {
            coords[i] = Double.parseDouble(tokenizer.nextToken());
        }
        Point point = new Point(coords);

        Centroid nearest = null;
        double nearestDistance = Double.MAX_VALUE;
        for (Centroid c : centers) {
            double dist = point.euclideanDistance(c);
            // The first centroid is always accepted, later ones only if strictly closer.
            if (nearest == null || dist < nearestDistance) {
                nearest = c;
                nearestDistance = dist;
            }
        }
        if (nearest == null) {
            // Fail with a clear message instead of writing a null key.
            throw new IllegalStateException("No centroids loaded from "
                    + context.getConfiguration().get(CENTERS_CONF_KEY));
        }
        context.write(nearest, point);
    }
}
/**
 * Parses text lines of space-separated coordinates into points and assigns
 * each point to a centroid picked uniformly at random from the loaded ones.
 *
 * <p>The centroids are loaded once per task from the sequence file whose path
 * is stored under {@code CENTERS_CONF_KEY}.
 */
public static class RandomCenterInitializer extends Mapper<LongWritable,Text,Centroid,Point> {
    private final List<Centroid> centers = new ArrayList<Centroid>();
    private final Random rand = new Random();

    /**
     * Loads all centroids from the centroid sequence file.
     *
     * @throws IOException if the file cannot be opened or read
     */
    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
        super.setup(context);
        Configuration conf = context.getConfiguration();
        Path centroids = new Path(conf.get(CENTERS_CONF_KEY));
        FileSystem fs = FileSystem.get(conf);
        Centroid key = new Centroid();
        IntWritable value = new IntWritable();
        // try-with-resources: the reader is closed even when reading fails.
        try (SequenceFile.Reader reader = new SequenceFile.Reader(fs, centroids, conf)) {
            while (reader.next(key, value)) {
                // 'key' is reused by the reader; store a separate Centroid per record.
                centers.add(new Centroid(key));
            }
        }
    }

    /**
     * Converts one text line into a point and emits it keyed by a randomly
     * chosen centroid. Lines without any token are skipped.
     *
     * @throws NumberFormatException if a token is not a valid double
     * @throws IllegalStateException if no centroids were loaded
     */
    @Override
    protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
        StringTokenizer tokenizer = new StringTokenizer(value.toString(), " ");
        int dim = tokenizer.countTokens();
        if (dim == 0) {
            // A blank line would otherwise become a zero-dimensional point.
            return;
        }
        double[] coords = new double[dim];
        for (int i = 0; i < dim; i++) {
            coords[i] = Double.parseDouble(tokenizer.nextToken());
        }
        if (centers.isEmpty()) {
            // Random.nextInt(0) would fail with an unrelated-looking message.
            throw new IllegalStateException("No centroids loaded from "
                    + context.getConfiguration().get(CENTERS_CONF_KEY));
        }
        Centroid center = centers.get(rand.nextInt(centers.size()));
        context.write(center, new Point(coords));
    }
}
/**
 * Turns a (centroid, point) record into one text record whose key is the
 * centroid's string form, a space, and the point's string form; the value is
 * an empty text.
 */
public static class PointSequenceToTextConverter extends Mapper<Centroid, Point, Text, Text> {
    @Override
    protected void map(Centroid key, Point value, Context context) throws IOException, InterruptedException {
        StringBuilder line = new StringBuilder();
        line.append(key.toString());
        line.append(" ");
        line.append(value.toString());
        Text outKey = new Text(line.toString());
        Text outValue = new Text("");
        context.write(outKey, outValue);
    }
}
/**
 * Turns a (centroid, int) record into a text record made of the string forms
 * of the key and of the value.
 */
public static class CenterSequenceToTextConverter extends Mapper<Centroid, IntWritable, Text, Text> {
    @Override
    protected void map(Centroid key, IntWritable value, Context context) throws IOException, InterruptedException {
        Text centroidText = new Text(key.toString());
        Text valueText = new Text(value.toString());
        context.write(centroidText, valueText);
    }
}
/**
 * Runs a map-only job that converts a (Centroid, Point) sequence file into
 * text using {@link PointSequenceToTextConverter}.
 *
 * @param conf        configuration used to create the job
 * @param fs          file system on which {@code outputPath} is cleared
 * @param seqFilePath input sequence file (or directory)
 * @param outputPath  output directory; deleted first if it already exists
 * @throws RuntimeException if the job does not complete successfully
 * @throws Exception        on any file system or job submission error
 */
public static void convertPointsSequenceFileToText(Configuration conf, FileSystem fs, String seqFilePath, String outputPath) throws Exception {
    Path seqFile = new Path(seqFilePath);
    Path output = new Path(outputPath);
    if (fs.exists(output)) {
        fs.delete(output, true);
    }
    Job job = Job.getInstance(conf);
    job.setMapperClass(PointSequenceToTextConverter.class);
    job.setReducerClass(Reducer.class);
    job.setNumReduceTasks(0);
    // The mapper emits Text keys and Text values; declare exactly that.
    job.setMapOutputKeyClass(Text.class);
    job.setMapOutputValueClass(Text.class);
    job.setOutputKeyClass(Text.class);
    job.setOutputValueClass(Text.class);
    job.setOutputFormatClass(TextOutputFormat.class);
    job.setInputFormatClass(SequenceFileInputFormat.class);
    FileInputFormat.addInputPath(job, seqFile);
    FileOutputFormat.setOutputPath(job, output);
    // Report a failed job instead of silently continuing.
    if (!job.waitForCompletion(true)) {
        throw new RuntimeException("Conversion of " + seqFilePath + " to text failed");
    }
}
/**
 * Runs a map-only job that converts a (Centroid, IntWritable) sequence file
 * into text using {@link CenterSequenceToTextConverter}.
 *
 * @param conf        configuration used to create the job
 * @param fs          file system on which {@code outputPath} is cleared
 * @param seqFilePath input sequence file (or directory)
 * @param outputPath  output directory; deleted first if it already exists
 * @throws RuntimeException if the job does not complete successfully
 * @throws Exception        on any file system or job submission error
 */
public static void convertCentersSequenceFileToText(Configuration conf, FileSystem fs, String seqFilePath, String outputPath) throws Exception {
    Path seqFile = new Path(seqFilePath);
    Path output = new Path(outputPath);
    if (fs.exists(output)) {
        fs.delete(output, true);
    }
    Job job = Job.getInstance(conf);
    job.setMapperClass(CenterSequenceToTextConverter.class);
    job.setReducerClass(Reducer.class);
    job.setNumReduceTasks(0);
    // The mapper emits Text keys and Text values; declare exactly that.
    job.setMapOutputKeyClass(Text.class);
    job.setMapOutputValueClass(Text.class);
    job.setOutputKeyClass(Text.class);
    job.setOutputValueClass(Text.class);
    job.setOutputFormatClass(TextOutputFormat.class);
    job.setInputFormatClass(SequenceFileInputFormat.class);
    FileInputFormat.addInputPath(job, seqFile);
    FileOutputFormat.setOutputPath(job, output);
    // Report a failed job instead of silently continuing.
    if (!job.waitForCompletion(true)) {
        throw new RuntimeException("Conversion of " + seqFilePath + " to text failed");
    }
}
/**
 * Converts a text file of centroids into a (Centroid, IntWritable) sequence file.
 *
 * <p>Each non-blank input line has the form
 * {@code <clusterId> <coord1> <coord2> ...}, separated by single spaces.
 * Every centroid is written with the placeholder value {@code 0}.
 *
 * @param conf             configuration used to create the writer
 * @param fs               file system holding both files
 * @param centroidsPath    text file to read
 * @param sequenceFilePath sequence file to write; deleted first if it exists
 * @throws NumberFormatException if an id or coordinate cannot be parsed
 * @throws Exception             on any I/O error
 */
public static void createCentersSequenceFile(Configuration conf, FileSystem fs, String centroidsPath, String sequenceFilePath) throws Exception {
    Path seqFile = new Path(sequenceFilePath);
    if (fs.exists(seqFile)) {
        fs.delete(seqFile, true);
    }
    IntWritable value = new IntWritable(0);
    // Both resources are closed in reverse order even on failure, and a failure
    // while closing the writer is reported rather than swallowed.
    try (BufferedReader reader = new BufferedReader(
                 new InputStreamReader(fs.open(new Path(centroidsPath)), StandardCharsets.UTF_8));
         SequenceFile.Writer writer =
                 SequenceFile.createWriter(fs, conf, seqFile, Centroid.class, IntWritable.class)) {
        String line;
        // readLine() returning null is the reliable end-of-file signal;
        // InputStream.available() is only an estimate.
        while ((line = reader.readLine()) != null) {
            StringTokenizer tokenizer = new StringTokenizer(line, " ");
            if (!tokenizer.hasMoreTokens()) {
                // Skip blank lines (e.g. a trailing empty line).
                continue;
            }
            // The first token is the cluster id, the rest are coordinates.
            int dim = tokenizer.countTokens() - 1;
            int clusterId = Integer.parseInt(tokenizer.nextToken());
            double[] coords = new double[dim];
            for (int i = 0; i < dim; i++) {
                coords[i] = Double.parseDouble(tokenizer.nextToken());
            }
            writer.append(new Centroid(clusterId, new Point(coords)), value);
        }
    }
}
/**
 * Runs a map-only job that parses the text points and writes them as a
 * (Centroid, Point) sequence file, each point keyed by the centroid chosen
 * by {@link CenterInitializer}.
 *
 * @param conf        configuration used to create the job
 * @param fs          file system on which {@code seqFilePath} is cleared
 * @param pointsPath  text input with one point per line
 * @param seqFilePath output directory; deleted first if it already exists
 * @throws RuntimeException if the job does not complete successfully
 * @throws Exception        on any file system or job submission error
 */
public static void initializeCenters(Configuration conf, FileSystem fs, String pointsPath, String seqFilePath) throws Exception {
    Path points = new Path(pointsPath);
    Path seqFile = new Path(seqFilePath);
    if (fs.exists(seqFile)) {
        fs.delete(seqFile, true);
    }
    Job job = Job.getInstance(conf);
    job.setMapperClass(CenterInitializer.class);
    job.setReducerClass(Reducer.class);
    job.setNumReduceTasks(0);
    job.setMapOutputKeyClass(Centroid.class);
    job.setMapOutputValueClass(Point.class);
    job.setOutputKeyClass(Centroid.class);
    job.setOutputValueClass(Point.class);
    job.setOutputFormatClass(SequenceFileOutputFormat.class);
    job.setInputFormatClass(TextInputFormat.class);
    FileInputFormat.addInputPath(job, points);
    FileOutputFormat.setOutputPath(job, seqFile);
    // Report a failed job instead of silently continuing.
    if (!job.waitForCompletion(true)) {
        throw new RuntimeException("Initial assignment of points from " + pointsPath + " failed");
    }
}
/**
 * Runs {@code maxIterations} K-Means jobs in sequence.
 *
 * <p>Iteration 0 reads {@code pointsPath}; every later iteration reads the
 * output of the previous one ({@code resultsPath + "_iteration_" + i}). The
 * last iteration writes to {@code resultsPath}. The centroid file under
 * {@code CENTERS_CONF_KEY} is read by the mappers and rewritten by the
 * reducer on every iteration, so it must exist before this method is called.
 * With {@code maxIterations <= 0} no job is run.
 *
 * @param conf          configuration used to create the jobs; must contain {@code CENTERS_CONF_KEY}
 * @param fs            file system on which existing output directories are cleared
 * @param pointsPath    (Centroid, Point) sequence file input of the first iteration
 * @param resultsPath   output directory of the last iteration
 * @param maxIterations number of iterations to run
 * @throws RuntimeException if an iteration's job does not complete successfully
 * @throws Exception        on any file system or job submission error
 */
public static void kmeans(Configuration conf, FileSystem fs, String pointsPath, String resultsPath, int maxIterations) throws Exception {
    Path points = new Path(pointsPath);
    for (int iteration = 0; iteration < maxIterations; iteration++) {
        Job job = Job.getInstance(conf);
        job.setMapperClass(KMeansMapper.class);
        job.setReducerClass(KMeansReducer.class);
        // The combiner must not have side effects: KMeansReducer.cleanup rewrites
        // the centroid file, which would clobber it from map-side tasks while
        // other mappers may still be reading it. KMeansCombiner only re-keys points.
        job.setCombinerClass(KMeansCombiner.class);
        job.setMapOutputKeyClass(Centroid.class);
        job.setMapOutputValueClass(Point.class);
        job.setOutputKeyClass(Centroid.class);
        job.setOutputValueClass(Point.class);
        job.setOutputFormatClass(SequenceFileOutputFormat.class);
        job.setInputFormatClass(SequenceFileInputFormat.class);

        Path inPath = (iteration == 0)
                ? points
                : new Path(resultsPath + "_iteration_" + (iteration - 1));
        Path outPath = (iteration == maxIterations - 1)
                ? new Path(resultsPath)
                : new Path(resultsPath + "_iteration_" + iteration);
        if (fs.exists(outPath)) {
            fs.delete(outPath, true);
        }
        FileInputFormat.addInputPath(job, inPath);
        FileOutputFormat.setOutputPath(job, outPath);
        // The centroid file is intentionally NOT deleted here: KMeansMapper.setup
        // reads it at the start of the job, and the reducer replaces it at the end.
        if (!job.waitForCompletion(true)) {
            throw new RuntimeException("K-Means iteration " + iteration + " failed");
        }
    }
}
/**
 * Runs the complete K-Means pipeline: converts the text centroids to a
 * sequence file, converts and assigns the text points, iterates K-Means and
 * finally writes the clustered points as text to {@code <result>_text}.
 *
 * @param args {@code <points> <centers> <result> <maxIterations>}
 * @return {@code 0} on success, {@code -1} if the arguments are missing or
 *         the iteration count is not positive
 * @throws NumberFormatException if {@code maxIterations} is not an integer
 * @throws Exception             if any stage of the pipeline fails
 */
@Override
public int run(String[] args) throws Exception {
    if (args == null || args.length < 4) {
        System.err.println("Usage: KMeansDriver <points> <centers> <result> <maxIterations>");
        return -1;
    }
    String points = args[0];
    String centers = args[1];
    String result = args[2];
    int maxIterations = Integer.parseInt(args[3]);
    if (maxIterations < 1) {
        // With no iteration the result directory would never be produced.
        System.err.println("maxIterations must be at least 1 but was " + maxIterations);
        return -1;
    }

    // getConf() creates a default configuration when none has been set.
    Configuration configuration = getConf();
    String centersSeqFile = centers + "_seq";
    String pointsSeqFile = points + "_seq";
    configuration.set(CENTERS_CONF_KEY, centersSeqFile);

    FileSystem fs = FileSystem.get(configuration);
    createCentersSequenceFile(configuration, fs, centers, centersSeqFile);
    initializeCenters(configuration, fs, points, pointsSeqFile);
    kmeans(configuration, fs, pointsSeqFile, result, maxIterations);
    convertPointsSequenceFileToText(configuration, fs, result, result + "_text");
    fs.close();
    return 0;
}
/**
 * Command-line entry point: runs the driver through {@code ToolRunner} with
 * the MapReduce framework set to {@code local} and exits with the tool's
 * return code.
 *
 * @param args command-line arguments passed on to the tool
 */
public static void main(String[] args) throws Exception {
    KMeansDriver driver = new KMeansDriver();
    Configuration driverConf = driver.getConf();
    driverConf.set("mapreduce.framework.name", "local");
    System.exit(ToolRunner.run(driver, args));
}
}