/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.clustering.evaluation; import java.io.IOException; import java.util.List; import java.util.Map; import java.util.Map.Entry; import com.google.common.collect.Maps; import com.google.common.io.Closeables; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.SequenceFile; import org.apache.hadoop.mapreduce.Job; import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; import org.apache.hadoop.util.ToolRunner; import org.apache.mahout.clustering.AbstractCluster; import org.apache.mahout.clustering.Cluster; import org.apache.mahout.clustering.WeightedVectorWritable; import org.apache.mahout.common.AbstractJob; import org.apache.mahout.common.ClassUtils; import org.apache.mahout.common.Pair; import org.apache.mahout.common.commandline.DefaultOptionCreator; import 
org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Driver that iteratively computes representative points for a set of clusters.
 * <p>
 * The initial state, written to {@code <output>/representativePoints-0}, contains each input
 * cluster's center keyed by its cluster id. Each iteration reads the previous state directory,
 * scans the clustered points and writes a new state directory
 * {@code <output>/representativePoints-<iteration + 1>}. This runs either in-process
 * (sequential method) or as a Map/Reduce job. When all iterations are done, the path of the
 * final state and the distance measure class are recorded in the caller's {@link Configuration}
 * under {@link #STATE_IN_KEY} and {@link #DISTANCE_MEASURE_KEY}.
 */
public final class RepresentativePointsDriver extends AbstractJob {

  /** Configuration key holding the path of the representative-points state directory to read. */
  public static final String STATE_IN_KEY = "org.apache.mahout.clustering.stateIn";

  /** Configuration key holding the fully-qualified class name of the {@link DistanceMeasure}. */
  public static final String DISTANCE_MEASURE_KEY = "org.apache.mahout.clustering.measure";

  private static final Logger log = LoggerFactory.getLogger(RepresentativePointsDriver.class);

  private RepresentativePointsDriver() {
  }

  /** Command-line entry point; delegates to {@link ToolRunner} with a default Configuration. */
  public static void main(String[] args) throws Exception {
    ToolRunner.run(new Configuration(), new RepresentativePointsDriver(), args);
  }

  /**
   * Parses the command line and runs the representative-points computation.
   *
   * @param args command-line arguments (input, output, clusteredPoints, distance measure,
   *          max iterations and execution method)
   * @return 0 on success, -1 if the arguments could not be parsed
   */
  @Override
  public int run(String[] args) throws ClassNotFoundException, IOException, InterruptedException {
    addInputOption();
    addOutputOption();
    addOption("clusteredPoints", "cp", "The path to the clustered points", true);
    addOption(DefaultOptionCreator.distanceMeasureOption().create());
    addOption(DefaultOptionCreator.maxIterationsOption().create());
    addOption(DefaultOptionCreator.methodOption().create());
    if (parseArguments(args) == null) {
      return -1;
    }

    Path input = getInputPath();
    Path output = getOutputPath();
    String distanceMeasureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
    int maxIterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION));
    boolean runSequential = getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase(
        DefaultOptionCreator.SEQUENTIAL_METHOD);
    DistanceMeasure measure = ClassUtils.instantiateAs(distanceMeasureClass, DistanceMeasure.class);
    Path clusteredPoints = new Path(getOption("clusteredPoints"));

    run(getConf(), input, clusteredPoints, output, measure, maxIterations, runSequential);
    return 0;
  }

  /**
   * Runs the representative-points computation.
   *
   * @param conf the Configuration to use; on return it holds the final state path under
   *          {@link #STATE_IN_KEY} and the measure class name under {@link #DISTANCE_MEASURE_KEY}
   * @param clustersIn the directory containing the input clusters (their centers seed the state)
   * @param clusteredPointsIn the directory containing the clustered points
   * @param output the directory under which the per-iteration state directories are written
   * @param measure the DistanceMeasure to use
   * @param numIterations the number of refinement iterations to run (0 keeps only the centers)
   * @param runSequential true to run in-process, false to run each iteration as a Map/Reduce job
   * @throws IllegalStateException if a Map/Reduce iteration does not complete successfully
   */
  public static void run(Configuration conf,
                         Path clustersIn,
                         Path clusteredPointsIn,
                         Path output,
                         DistanceMeasure measure,
                         int numIterations,
                         boolean runSequential)
    throws IOException, InterruptedException, ClassNotFoundException {
    Path stateIn = new Path(output, "representativePoints-0");
    writeInitialState(conf, stateIn, clustersIn);

    for (int iteration = 0; iteration < numIterations; iteration++) {
      log.info("Representative Points Iteration {}", iteration);
      // point the output to a new directory per iteration
      Path stateOut = new Path(output, "representativePoints-" + (iteration + 1));
      runIteration(conf, clusteredPointsIn, stateIn, stateOut, measure, runSequential);
      // now point the input to the old output directory
      stateIn = stateOut;
    }

    conf.set(STATE_IN_KEY, stateIn.toString());
    conf.set(DISTANCE_MEASURE_KEY, measure.getClass().getName());
  }

  /**
   * Writes the initial state: for every part file under {@code clustersIn}, a sequence file of
   * the same name under {@code output} mapping each cluster id to that cluster's center.
   * <p>
   * The caller's Configuration is used (rather than a fresh one) so that programmatic settings
   * such as the default filesystem apply here as in every other step. The input and output
   * paths are each resolved against their own filesystem.
   */
  private static void writeInitialState(Configuration conf, Path output, Path clustersIn)
    throws IOException {
    FileSystem inFs = FileSystem.get(clustersIn.toUri(), conf);
    FileSystem outFs = FileSystem.get(output.toUri(), conf);
    for (FileStatus part : inFs.listStatus(clustersIn, PathFilters.logsCRCFilter())) {
      Path inPart = part.getPath();
      Path path = new Path(output, inPart.getName());
      SequenceFile.Writer writer =
          new SequenceFile.Writer(outFs, conf, path, IntWritable.class, VectorWritable.class);
      try {
        for (Cluster value : new SequenceFileValueIterable<Cluster>(inPart, true, conf)) {
          if (log.isDebugEnabled()) {
            log.debug("C-{}: {}", value.getId(), AbstractCluster.formatVector(value.getCenter(), null));
          }
          writer.append(new IntWritable(value.getId()), new VectorWritable(value.getCenter()));
        }
      } finally {
        Closeables.closeQuietly(writer);
      }
    }
  }

  /** Runs a single iteration, either sequentially or as a Map/Reduce job. */
  private static void runIteration(Configuration conf,
                                   Path clusteredPointsIn,
                                   Path stateIn,
                                   Path stateOut,
                                   DistanceMeasure measure,
                                   boolean runSequential)
    throws IOException, InterruptedException, ClassNotFoundException {
    if (runSequential) {
      runIterationSeq(conf, clusteredPointsIn, stateIn, stateOut, measure);
    } else {
      runIterationMR(conf, clusteredPointsIn, stateIn, stateOut, measure);
    }
  }

  /**
   * Run the job using supplied arguments as a sequential process.
   * <p>
   * Writes two files to {@code stateOut}:
   * <ul>
   *   <li>{@code part-m-0}: the representative points read from {@code stateIn};</li>
   *   <li>{@code part-m-1}: one point per cluster id collected by
   *       {@code RepresentativePointsMapper.mapPoint}. Presumably this is the point most distant
   *       from the current representatives; confirm against the mapper.</li>
   * </ul>
   *
   * @param conf
   *          the Configuration to use
   * @param clusteredPointsIn
   *          the directory pathname for input points
   * @param stateIn
   *          the directory pathname for input state
   * @param stateOut
   *          the directory pathname for output state
   * @param measure
   *          the DistanceMeasure to use
   */
  private static void runIterationSeq(Configuration conf,
                                      Path clusteredPointsIn,
                                      Path stateIn,
                                      Path stateOut,
                                      DistanceMeasure measure) throws IOException {

    Map<Integer, List<VectorWritable>> repPoints = RepresentativePointsMapper.getRepresentativePoints(conf, stateIn);
    Map<Integer, WeightedVectorWritable> mostDistantPoints = Maps.newHashMap();
    for (Pair<IntWritable,WeightedVectorWritable> record
         : new SequenceFileDirIterable<IntWritable,WeightedVectorWritable>(
             clusteredPointsIn, PathType.LIST, PathFilters.logsCRCFilter(), null, true, conf)) {
      RepresentativePointsMapper.mapPoint(
          record.getFirst(), record.getSecond(), measure, repPoints, mostDistantPoints);
    }

    // The output files live under stateOut, so resolve the filesystem from that path
    // rather than from the input points' path.
    FileSystem fs = FileSystem.get(stateOut.toUri(), conf);
    int part = 0;
    SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, new Path(stateOut, "part-m-" + part++),
        IntWritable.class, VectorWritable.class);
    try {
      for (Entry<Integer, List<VectorWritable>> entry : repPoints.entrySet()) {
        for (VectorWritable vw : entry.getValue()) {
          writer.append(new IntWritable(entry.getKey()), vw);
        }
      }
    } finally {
      Closeables.closeQuietly(writer);
    }

    writer = new SequenceFile.Writer(fs, conf, new Path(stateOut, "part-m-" + part++),
        IntWritable.class, VectorWritable.class);
    try {
      for (Map.Entry<Integer, WeightedVectorWritable> entry : mostDistantPoints.entrySet()) {
        writer.append(new IntWritable(entry.getKey()), new VectorWritable(entry.getValue().getVector()));
      }
    } finally {
      Closeables.closeQuietly(writer);
    }
  }

  /**
   * Run the job using supplied arguments as a Map/Reduce process.
   * <p>
   * Before the job is created, the caller's Configuration is updated with the state path and the
   * measure class name. The job sees them through its copy of that configuration.
   *
   * @param conf
   *          the Configuration to use
   * @param input
   *          the directory pathname for input points
   * @param stateIn
   *          the directory pathname for input state
   * @param stateOut
   *          the directory pathname for output state
   * @param measure
   *          the DistanceMeasure to use
   * @throws IllegalStateException if the Map/Reduce job does not complete successfully
   */
  private static void runIterationMR(Configuration conf,
                                     Path input,
                                     Path stateIn,
                                     Path stateOut,
                                     DistanceMeasure measure)
    throws IOException, InterruptedException, ClassNotFoundException {
    conf.set(STATE_IN_KEY, stateIn.toString());
    conf.set(DISTANCE_MEASURE_KEY, measure.getClass().getName());
    Job job = new Job(conf, "Representative Points Driver running over input: " + input);
    job.setJarByClass(RepresentativePointsDriver.class);
    job.setOutputKeyClass(IntWritable.class);
    job.setOutputValueClass(VectorWritable.class);
    job.setMapOutputKeyClass(IntWritable.class);
    job.setMapOutputValueClass(WeightedVectorWritable.class);

    FileInputFormat.setInputPaths(job, input);
    FileOutputFormat.setOutputPath(job, stateOut);

    job.setMapperClass(RepresentativePointsMapper.class);
    job.setReducerClass(RepresentativePointsReducer.class);
    job.setInputFormatClass(SequenceFileInputFormat.class);
    job.setOutputFormatClass(SequenceFileOutputFormat.class);

    // Fail fast: otherwise the next iteration would silently read a missing or partial state.
    if (!job.waitForCompletion(true)) {
      throw new IllegalStateException("Representative Points Driver failed processing " + input);
    }
  }
}