/**
 * Licensed to the Apache Software Foundation (ASF) under one or more contributor license
 * agreements. See the NOTICE file distributed with this work for additional information regarding
 * copyright ownership. The ASF licenses this file to You under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance with the License. You may obtain a
 * copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 */

package org.apache.mahout.math;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Iterator;

import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.math.Vector.Element;

import com.google.common.base.Preconditions;

/**
 * Hadoop {@link Writable} wrapper that serializes a {@link org.apache.mahout.math.Vector}.
 *
 * <p>Serialized layout, as produced by {@link #writeVector(DataOutput, Vector, boolean)} and
 * consumed by {@link #readFields(DataInput)}:
 * <ol>
 *   <li>one byte of flags (bitwise OR of the {@code FLAG_*} constants);</li>
 *   <li>the vector size as an unsigned varint;</li>
 *   <li>the contents:
 *     <ul>
 *       <li>dense: one value per position, {@code size} values in total;</li>
 *       <li>sparse: an unsigned varint count followed by that many (index, value) pairs, where
 *           the index is an unsigned varint. When {@link #FLAG_SEQUENTIAL} is set each index is
 *           stored as the difference from the previously written index;</li>
 *     </ul>
 *   </li>
 *   <li>if {@link #FLAG_NAMED} is set, the name written with {@link DataOutput#writeUTF(String)}.</li>
 * </ol>
 * Values are written as 4-byte floats when {@link #FLAG_LAX_PRECISION} is set and as 8-byte
 * doubles otherwise.
 */
public final class VectorWritable extends Configured implements Writable {

  /** Flag bit: the vector contents are stored densely (one value per position). */
  public static final int FLAG_DENSE = 0x01;
  /** Flag bit: the vector is sequential-access; sparse indices are delta-coded. */
  public static final int FLAG_SEQUENTIAL = 0x02;
  /** Flag bit: the vector is a {@link NamedVector} and its name follows the contents. */
  public static final int FLAG_NAMED = 0x04;
  /** Flag bit: values are stored as floats rather than doubles. */
  public static final int FLAG_LAX_PRECISION = 0x08;
  /** Number of flag bits defined above; any bit set beyond these is rejected when reading. */
  public static final int NUM_FLAGS = 4;

  private Vector vector;
  private boolean writesLaxPrecision;

  /** Creates an instance holding no vector; one is supplied by {@link #set} or {@link #readFields}. */
  public VectorWritable() {}

  /**
   * Creates an instance holding no vector.
   *
   * @param writesLaxPrecision whether values may be written as floats instead of doubles
   */
  public VectorWritable(boolean writesLaxPrecision) {
    setWritesLaxPrecision(writesLaxPrecision);
  }

  /**
   * Creates an instance wrapping the given vector, writing full-precision doubles.
   *
   * @param vector the vector to write
   */
  public VectorWritable(Vector vector) {
    this.vector = vector;
  }

  /**
   * Creates an instance wrapping the given vector.
   *
   * @param vector the vector to write
   * @param writesLaxPrecision whether values may be written as floats instead of doubles
   */
  public VectorWritable(Vector vector, boolean writesLaxPrecision) {
    this(vector);
    setWritesLaxPrecision(writesLaxPrecision);
  }

  /**
   * @return {@link org.apache.mahout.math.Vector} that this is to write, or has
   *  just read
   */
  public Vector get() {
    return vector;
  }

  /**
   * Replaces the wrapped vector.
   *
   * @param vector the vector that subsequent calls to {@link #write} will serialize
   */
  public void set(Vector vector) {
    this.vector = vector;
  }

  /**
   * @return true if this is allowed to encode {@link org.apache.mahout.math.Vector}
   *  values using fewer bytes, possibly losing precision. In particular this means
   *  that floating point values will be encoded as floats, not doubles.
   */
  public boolean isWritesLaxPrecision() {
    return writesLaxPrecision;
  }

  /**
   * @param writesLaxPrecision true to encode values as floats, false to encode them as doubles
   */
  public void setWritesLaxPrecision(boolean writesLaxPrecision) {
    this.writesLaxPrecision = writesLaxPrecision;
  }

  /**
   * Serializes the wrapped vector using this instance's precision setting.
   *
   * @param out destination of the serialized bytes
   * @throws IOException if writing to {@code out} fails
   */
  @Override
  public void write(DataOutput out) throws IOException {
    writeVector(out, this.vector, this.writesLaxPrecision);
  }

  /**
   * Reads the flags byte and the size, then the vector contents, replacing the wrapped vector.
   *
   * @param in source of the serialized bytes
   * @throws IOException if reading from {@code in} fails
   * @throws IllegalArgumentException if the flags byte has bits set outside the known flags
   */
  @Override
  public void readFields(DataInput in) throws IOException {
    int flags = in.readByte();
    int size = Varint.readUnsignedVarInt(in);
    readFields(in, (byte) flags, size);
  }

  /**
   * Reads the vector contents that follow an already-consumed flags byte and size, and stores
   * the resulting vector in this instance.
   */
  private void readFields(DataInput in, byte flags, int size) throws IOException {
    // Guava's Preconditions templates only substitute %s; any other specifier is left verbatim.
    Preconditions.checkArgument(flags >> NUM_FLAGS == 0,
                                "Unknown flags set: %s", Integer.toString(flags, 2));
    boolean dense = (flags & FLAG_DENSE) != 0;
    boolean sequential = (flags & FLAG_SEQUENTIAL) != 0;
    boolean named = (flags & FLAG_NAMED) != 0;
    boolean laxPrecision = (flags & FLAG_LAX_PRECISION) != 0;

    Vector v;
    if (dense) {
      double[] values = new double[size];
      for (int i = 0; i < size; i++) {
        values[i] = readValue(in, laxPrecision);
      }
      v = new DenseVector(values);
    } else {
      int numNonDefaultElements = Varint.readUnsignedVarInt(in);
      if (sequential) {
        v = new SequentialAccessSparseVector(size, numNonDefaultElements);
        // Indices were delta-coded by the writer: each stored number is the distance from the
        // previously read index, starting from zero.
        int lastIndex = 0;
        for (int i = 0; i < numNonDefaultElements; i++) {
          int index = lastIndex + Varint.readUnsignedVarInt(in);
          lastIndex = index;
          v.setQuick(index, readValue(in, laxPrecision));
        }
      } else {
        v = new RandomAccessSparseVector(size, numNonDefaultElements);
        for (int i = 0; i < numNonDefaultElements; i++) {
          int index = Varint.readUnsignedVarInt(in);
          v.setQuick(index, readValue(in, laxPrecision));
        }
      }
    }

    if (named) {
      String name = in.readUTF();
      v = new NamedVector(v, name);
    }
    vector = v;
  }

  /** Reads one element value: a float widened to double when lax precision is on, else a double. */
  private static double readValue(DataInput in, boolean laxPrecision) throws IOException {
    return laxPrecision ? in.readFloat() : in.readDouble();
  }

  /** Writes one element value: narrowed to float when lax precision is on, else as a double. */
  private static void writeValue(DataOutput out, double value, boolean laxPrecision)
    throws IOException {
    if (laxPrecision) {
      out.writeFloat((float) value);
    } else {
      out.writeDouble(value);
    }
  }

  /** Write the vector to the output */
  public static void writeVector(DataOutput out, Vector vector) throws IOException {
    writeVector(out, vector, false);
  }

  /**
   * Computes the flags byte describing how the given vector will be serialized.
   *
   * @param vector the vector to describe
   * @param laxPrecision whether values will be written as floats
   * @return bitwise OR of the applicable {@code FLAG_*} constants
   */
  public static byte flags(Vector vector, boolean laxPrecision) {
    boolean dense = vector.isDense();
    boolean sequential = vector.isSequentialAccess();
    boolean named = vector instanceof NamedVector;

    return (byte) ((dense ? FLAG_DENSE : 0)
        | (sequential ? FLAG_SEQUENTIAL : 0)
        | (named ? FLAG_NAMED : 0)
        | (laxPrecision ? FLAG_LAX_PRECISION : 0));
  }

  /** Write out type information and size of the vector */
  public static void writeVectorFlagsAndSize(DataOutput out, byte flags, int size)
    throws IOException {
    out.writeByte(flags);
    Varint.writeUnsignedVarInt(size, out);
  }

  /**
   * Writes the flags, size and contents of the given vector.
   *
   * @param out destination of the serialized bytes
   * @param vector the vector to write
   * @param laxPrecision whether values are written as floats instead of doubles
   * @throws IOException if writing to {@code out} fails
   */
  public static void writeVector(DataOutput out, Vector vector, boolean laxPrecision)
    throws IOException {
    byte flags = flags(vector, laxPrecision);
    writeVectorFlagsAndSize(out, flags, vector.size());
    writeVectorContents(out, vector, flags);
  }

  /** Write out contents of the vector */
  public static void writeVectorContents(DataOutput out, Vector vector, byte flags)
    throws IOException {
    boolean dense = (flags & FLAG_DENSE) != 0;
    boolean sequential = (flags & FLAG_SEQUENTIAL) != 0;
    boolean named = (flags & FLAG_NAMED) != 0;
    boolean laxPrecision = (flags & FLAG_LAX_PRECISION) != 0;

    if (dense) {
      for (Element element : vector.all()) {
        writeValue(out, element.get(), laxPrecision);
      }
    } else {
      // NOTE(review): the count written here is assumed to equal the number of elements that
      // survive the zero-value filter below; confirm against the Vector implementations.
      Varint.writeUnsignedVarInt(vector.getNumNonZeroElements(), out);
      if (sequential) {
        int lastIndex = 0;
        for (Element element : vector.nonZeroes()) {
          double value = element.get();
          if (value == 0) {
            continue;
          }
          int thisIndex = element.index();
          // Delta-code indices:
          Varint.writeUnsignedVarInt(thisIndex - lastIndex, out);
          lastIndex = thisIndex;
          writeValue(out, value, laxPrecision);
        }
      } else {
        for (Element element : vector.nonZeroes()) {
          double value = element.get();
          if (value == 0) {
            // TODO(robinanil): Fix the damn iterator for the zero element.
            continue;
          }
          Varint.writeUnsignedVarInt(element.index(), out);
          writeValue(out, value, laxPrecision);
        }
      }
    }

    if (named) {
      String name = ((NamedVector) vector).getName();
      // writeUTF cannot take null, so a missing name is written as the empty string.
      out.writeUTF(name == null ? "" : name);
    }
  }

  /**
   * Reads a complete serialized vector (flags, size and contents).
   *
   * @param in source of the serialized bytes
   * @return the vector that was read
   * @throws IOException if reading from {@code in} fails
   */
  public static Vector readVector(DataInput in) throws IOException {
    VectorWritable v = new VectorWritable();
    v.readFields(in);
    return v.get();
  }

  /**
   * Reads vector contents whose flags byte and size have already been consumed by the caller.
   *
   * @param in source of the serialized bytes, positioned at the start of the contents
   * @param vectorFlags the flags byte that preceded the contents
   * @param size the vector size that preceded the contents
   * @return the vector that was read
   * @throws IOException if reading from {@code in} fails
   */
  public static Vector readVector(DataInput in, byte vectorFlags, int size) throws IOException {
    VectorWritable v = new VectorWritable();
    v.readFields(in, vectorFlags, size);
    return v.get();
  }

  /**
   * Merges the given vectors as {@link #mergeToVector(Iterator)} does and wraps the result.
   *
   * @param vectors the vectors to merge; must yield at least one element
   * @return a new writable wrapping the merged vector
   */
  public static VectorWritable merge(Iterator<VectorWritable> vectors) {
    return new VectorWritable(mergeToVector(vectors));
  }

  /**
   * Merges the given vectors into the first one. The vector held by the first writable is used
   * as the accumulator and is modified in place: the non-zero elements of each later, non-null
   * writable are copied into it, overwriting whatever value it held at the same index.
   *
   * @param vectors the vectors to merge; {@link Iterator#next()} is called on it unconditionally,
   *  so it must yield at least one element
   * @return the vector of the first writable, after merging
   */
  public static Vector mergeToVector(Iterator<VectorWritable> vectors) {
    Vector accumulator = vectors.next().get();
    while (vectors.hasNext()) {
      VectorWritable v = vectors.next();
      if (v != null) {
        for (Element nonZeroElement : v.get().nonZeroes()) {
          accumulator.setQuick(nonZeroElement.index(), nonZeroElement.get());
        }
      }
    }
    return accumulator;
  }

  /**
   * Two writables are equal when their wrapped vectors are equal, or when neither holds a
   * vector. The precision setting is not part of the comparison.
   */
  @Override
  public boolean equals(Object o) {
    if (!(o instanceof VectorWritable)) {
      return false;
    }
    Vector other = ((VectorWritable) o).get();
    // A writable created by the no-arg constructor holds no vector until set/readFields is called.
    return vector == null ? other == null : vector.equals(other);
  }

  /** Hash code of the wrapped vector, or 0 when no vector is held. */
  @Override
  public int hashCode() {
    return vector == null ? 0 : vector.hashCode();
  }

  /** String form of the wrapped vector, or {@code "null"} when no vector is held. */
  @Override
  public String toString() {
    return String.valueOf(vector);
  }
}