/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.clustering; import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; import java.util.Collection; import java.util.Collections; import java.util.Iterator; import java.util.Locale; import org.apache.hadoop.conf.Configuration; import org.apache.mahout.common.parameters.Parameter; import org.apache.mahout.math.NamedVector; import org.apache.mahout.math.RandomAccessSparseVector; import org.apache.mahout.math.Vector; import org.apache.mahout.math.VectorWritable; import org.apache.mahout.math.function.Functions; import org.apache.mahout.math.function.SquareRootFunction; public abstract class AbstractCluster implements Cluster { // cluster persistent state private int id; private long numPoints; private Vector center; private Vector radius; protected AbstractCluster() {} protected AbstractCluster(Vector point, int id2) { this.setNumPoints(0); this.setCenter(new RandomAccessSparseVector(point)); this.setRadius(point.like()); this.id = id2; } protected AbstractCluster(Vector center2, Vector radius2, int id2) { this.setNumPoints(0); this.setCenter(new RandomAccessSparseVector(center2)); this.setRadius(new RandomAccessSparseVector(radius2)); this.id = id2; } @Override public void configure(Configuration job) { // nothing to do } @Override public Collection<Parameter<?>> getParameters() { return Collections.emptyList(); } @Override public void createParameters(String prefix, Configuration jobConf) { // nothing to do } /** * @param id * the id to set */ protected void setId(int id) { this.id = id; } /** * @param l * the numPoints to set */ protected void setNumPoints(long l) { this.numPoints = l; } /** * @param center * the center to set */ protected void setCenter(Vector center) { this.center = center; } /** * @param radius * the radius to set */ protected void setRadius(Vector radius) { this.radius = radius; } // the observation statistics, initialized by the first observation private double s0; private Vector s1; private Vector s2; /** * @return the s0 */ protected double getS0() { return s0; } /** * @return the s1 */ protected Vector getS1() { return s1; } /** * @return the s2 */ protected Vector getS2() { return s2; } @Override public void observe(Model<VectorWritable> x) { AbstractCluster cl = (AbstractCluster) x; setS0(getS0() + cl.getS0()); setS1(getS1().plus(cl.getS1())); setS2(getS2().plus(cl.getS2())); } public void observe(ClusterObservations observations) { setS0(getS0() + observations.getS0()); if (getS1() == null) { setS1(observations.getS1().clone()); } else { getS1().assign(observations.getS1(), Functions.PLUS); } if (getS2() == null) { setS2(observations.getS2().clone()); } else { getS2().assign(observations.getS2(), Functions.PLUS); } } @Override public void observe(VectorWritable x) { observe(x.get()); } @Override public void observe(VectorWritable x, double weight) { observe(x.get(), weight); } public void observe(Vector x, double weight) { if (weight == 1.0) { observe(x); } else { setS0(getS0() + weight); Vector weightedX = x.times(weight); if (getS1() == null) { setS1(weightedX); } else { getS1().assign(weightedX, Functions.PLUS); } Vector x2 = x.times(x).times(weight); if (getS2() == null) { setS2(x2); } else { getS2().assign(x2, Functions.PLUS); } } } public void observe(Vector x) { setS0(getS0() + 1); if (getS1() == null) { setS1(x.clone()); } else { getS1().assign(x, Functions.PLUS); } Vector x2 = x.times(x); if (getS2() == null) { setS2(x2); } else { getS2().assign(x2, Functions.PLUS); } } @Override public long getNumPoints() { return numPoints; } public ClusterObservations getObservations() { return new ClusterObservations(getS0(), getS1(), getS2()); } @Override public void computeParameters() { if (getS0() == 0) { return; } setNumPoints((int) getS0()); setCenter(getS1().divide(getS0())); // compute the component stds if (getS0() > 1) { setRadius(getS2().times(getS0()).minus(getS1().times(getS1())) .assign(new SquareRootFunction()).divide(getS0())); } setS0(0); setS1(null); setS2(null); } @Override public void readFields(DataInput in) throws IOException { this.id = in.readInt(); this.setNumPoints(in.readLong()); VectorWritable temp = new VectorWritable(); temp.readFields(in); this.setCenter(temp.get()); temp.readFields(in); this.setRadius(temp.get()); } @Override public void write(DataOutput out) throws IOException { out.writeInt(id); out.writeLong(getNumPoints()); VectorWritable.writeVector(out, getCenter()); VectorWritable.writeVector(out, getRadius()); } @Override public String asFormatString(String[] bindings) { StringBuilder buf = new StringBuilder(50); buf.append(getIdentifier()).append("{n=").append(getNumPoints()); if (getCenter() != null) { buf.append(" c=").append(formatVector(getCenter(), bindings)); } if (getRadius() != null) { buf.append(" r=").append(formatVector(getRadius(), bindings)); } buf.append('}'); return buf.toString(); } public abstract String getIdentifier(); @Override public Vector getCenter() { return center; } @Override public int getId() { return id; } @Override public Vector getRadius() { return radius; } /** * Compute the centroid by averaging the pointTotals * * @return the new centroid */ public Vector computeCentroid() { return getS0() == 0 ? getCenter() : getS1().divide(getS0()); } /** * Return a human-readable formatted string representation of the vector, not * intended to be complete nor usable as an input/output representation */ public static String formatVector(Vector v, String[] bindings) { StringBuilder buf = new StringBuilder(); if (v instanceof NamedVector) { buf.append(((NamedVector) v).getName()).append(" = "); } int nzero = 0; Iterator<Vector.Element> iterateNonZero = v.iterateNonZero(); while (iterateNonZero.hasNext()) { iterateNonZero.next(); nzero++; } // if vector is sparse or if we have bindings, use sparse notation if (nzero < v.size() || bindings != null) { buf.append('['); for (int i = 0; i < v.size(); i++) { double elem = v.get(i); if (elem == 0.0) { continue; } String label; if (bindings != null && (label = bindings[i]) != null) { buf.append(label).append(':'); } else { buf.append(i).append(':'); } buf.append(String.format(Locale.ENGLISH, "%.3f", elem)).append(", "); } } else { buf.append('['); for (int i = 0; i < v.size(); i++) { double elem = v.get(i); buf.append(String.format(Locale.ENGLISH, "%.3f", elem)).append(", "); } } if (buf.length() > 1) { buf.setLength(buf.length() - 2); } buf.append(']'); return buf.toString(); } @Override public long count() { return getNumPoints(); } @Override public boolean isConverged() { // Convergence has no meaning yet, perhaps in subclasses return false; } protected void setS0(double s0) { this.s0 = s0; } protected void setS1(Vector s1) { this.s1 = s1; } protected void setS2(Vector s2) { this.s2 = s2; } }