/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.math.hadoop;
import java.io.IOException;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.Iterator;
import java.util.Locale;
import com.google.common.io.Closeables;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.DistributedRowMatrix.MatrixEntryWritable;
import org.easymock.IArgumentMatcher;
import org.easymock.EasyMock;
import org.junit.Assert;
/**
* a collection of small helper methods useful for unit-testing mathematical operations
*/
public final class MathHelper {

  private MathHelper() {}

  /**
   * Registers an EasyMock {@link IArgumentMatcher} that accepts a {@link MatrixEntryWritable} whose row and column
   * equal {@code row} and {@code col} and whose value lies within {@link MahoutTestCase#EPSILON} of {@code value}.
   * Any argument that is not a {@link MatrixEntryWritable} is rejected.
   *
   * @param row expected row index
   * @param col expected column index
   * @param value expected value (compared with tolerance {@link MahoutTestCase#EPSILON})
   * @return always {@code null}; the method is called for its side effect of reporting the matcher to EasyMock,
   *         and the return value only fills the argument slot in the recorded expectation
   */
  public static MatrixEntryWritable matrixEntryMatches(final int row, final int col, final double value) {
    EasyMock.reportMatcher(new IArgumentMatcher() {
      @Override
      public boolean matches(Object argument) {
        if (argument instanceof MatrixEntryWritable) {
          MatrixEntryWritable entry = (MatrixEntryWritable) argument;
          return row == entry.getRow()
              && col == entry.getCol()
              && Math.abs(value - entry.getVal()) <= MahoutTestCase.EPSILON;
        }
        return false;
      }

      @Override
      public void appendTo(StringBuffer buffer) {
        buffer.append("MatrixEntry[row=").append(row)
            .append(",col=").append(col)
            .append(",value=").append(value).append(']');
      }
    });
    return null;
  }

  /**
   * Convenience method to create a {@link MatrixEntryWritable} holding the given row, column and value.
   */
  public static MatrixEntryWritable matrixEntry(int row, int col, double value) {
    MatrixEntryWritable entry = new MatrixEntryWritable();
    entry.setRow(row);
    entry.setCol(col);
    entry.setVal(value);
    return entry;
  }

  /**
   * Convenience method to create a standalone {@link Vector.Element} with the given index and value,
   * e.g. for use with {@link #vectorMatches(Vector.Element...)} or {@link #consistsOf(Vector, Vector.Element...)}.
   */
  public static Vector.Element elem(int index, double value) {
    return new ElementToCheck(index, value);
  }

  /**
   * A minimal {@link Vector.Element} that is not backed by any {@link Vector}; it simply stores an index and a
   * (mutable) value.
   */
  static class ElementToCheck implements Vector.Element {
    private final int index;
    private double value;

    ElementToCheck(int index, double value) {
      this.index = index;
      this.value = value;
    }

    @Override
    public double get() {
      return value;
    }

    @Override
    public int index() {
      return index;
    }

    @Override
    public void set(double value) {
      this.value = value;
    }
  }

  /**
   * Registers an EasyMock {@link IArgumentMatcher} that accepts a {@link VectorWritable} whose vector satisfies
   * {@link #consistsOf(Vector, Vector.Element...)} for the given elements. Any other argument is rejected.
   *
   * @param elements the exact set of non-zero, non-NaN elements the vector is expected to contain
   * @return always {@code null}; the return value only fills the argument slot in the recorded expectation
   */
  public static VectorWritable vectorMatches(final Vector.Element... elements) {
    EasyMock.reportMatcher(new IArgumentMatcher() {
      @Override
      public boolean matches(Object argument) {
        if (argument instanceof VectorWritable) {
          Vector v = ((VectorWritable) argument).get();
          return consistsOf(v, elements);
        }
        return false;
      }

      @Override
      public void appendTo(StringBuffer buffer) {
        // Describe the expectation so EasyMock failure messages show what was expected.
        buffer.append("Vector[");
        String separator = "";
        for (Vector.Element element : elements) {
          buffer.append(separator).append(element.index()).append(':').append(element.get());
          separator = ",";
        }
        buffer.append(']');
      }
    });
    return null;
  }

  /**
   * Checks whether the {@link Vector} is equivalent to the given set of {@link Vector.Element}s: the number of
   * non-zero, non-NaN entries of {@code vector} must equal {@code elements.length}, and every given element's
   * value must lie within {@link MahoutTestCase#EPSILON} of the vector's value at that element's index.
   */
  public static boolean consistsOf(Vector vector, Vector.Element... elements) {
    if (elements.length != numberOfNoNZeroNonNaNElements(vector)) {
      return false;
    }
    for (Vector.Element element : elements) {
      if (Math.abs(element.get() - vector.get(element.index())) > MahoutTestCase.EPSILON) {
        return false;
      }
    }
    return true;
  }

  /**
   * Returns the number of elements yielded by {@link Vector#iterateNonZero()} whose value is not NaN.
   */
  public static int numberOfNoNZeroNonNaNElements(Vector vector) {
    int elementsInVector = 0;
    Iterator<Vector.Element> vectorIterator = vector.iterateNonZero();
    while (vectorIterator.hasNext()) {
      Vector.Element currentElement = vectorIterator.next();
      if (!Double.isNaN(currentElement.get())) {
        elementsInVector++;
      }
    }
    return elementsInVector;
  }

  /**
   * Reads a {@link Matrix} from a SequenceFile&lt;IntWritable,VectorWritable&gt;. Each record's key is used as the
   * row index, and the entries reported by {@link Vector#iterateNonZero()} of its value are copied into a
   * {@link DenseMatrix} of the given dimensions; all other cells stay zero.
   *
   * @throws IllegalStateException if the file contains no records at all
   */
  public static Matrix readMatrix(Configuration conf, Path path, int rows, int columns) {
    boolean readOneRow = false;
    Matrix matrix = new DenseMatrix(rows, columns);
    for (Pair<IntWritable,VectorWritable> record :
        new SequenceFileIterable<IntWritable,VectorWritable>(path, true, conf)) {
      IntWritable key = record.getFirst();
      VectorWritable value = record.getSecond();
      readOneRow = true;
      int row = key.get();
      Iterator<Vector.Element> elementsIterator = value.get().iterateNonZero();
      while (elementsIterator.hasNext()) {
        Vector.Element element = elementsIterator.next();
        matrix.set(row, element.index(), element.get());
      }
    }
    if (!readOneRow) {
      throw new IllegalStateException("Not a single row read!");
    }
    return matrix;
  }

  /**
   * Writes a two-dimensional double array to a SequenceFile&lt;IntWritable,VectorWritable&gt;. Row {@code n} of
   * {@code entries} becomes one record keyed by {@code IntWritable(n)} whose value is a
   * {@link RandomAccessSparseVector} of that row's length.
   *
   * @throws IOException if creating, appending to, or closing the writer fails
   */
  public static void writeDistributedRowMatrix(double[][] entries, FileSystem fs, Configuration conf, Path path)
      throws IOException {
    SequenceFile.Writer writer = null;
    // Guava idiom: propagate close() failures (which may mean an incompletely flushed file) unless an
    // exception from the body is already in flight, in which case that original exception wins.
    boolean threw = true;
    try {
      writer = new SequenceFile.Writer(fs, conf, path, IntWritable.class, VectorWritable.class);
      for (int n = 0; n < entries.length; n++) {
        Vector v = new RandomAccessSparseVector(entries[n].length);
        for (int m = 0; m < entries[n].length; m++) {
          v.setQuick(m, entries[n][m]);
        }
        writer.append(new IntWritable(n), new VectorWritable(v));
      }
      threw = false;
    } finally {
      Closeables.close(writer, threw);
    }
  }

  /**
   * Asserts that both matrices have the same dimensions and that every cell of {@code actual} lies within
   * {@link MahoutTestCase#EPSILON} of the corresponding cell of {@code expected}.
   */
  public static void assertMatrixEquals(Matrix expected, Matrix actual) {
    Assert.assertEquals("Non-matching number of rows", expected.numRows(), actual.numRows());
    // Previously compared actual.numCols() with itself, which never detected a column-count mismatch.
    Assert.assertEquals("Non-matching number of columns", expected.numCols(), actual.numCols());
    for (int row = 0; row < expected.numRows(); row++) {
      for (int col = 0; col < expected.numCols(); col++) {
        Assert.assertEquals("Non-matching values in [" + row + ',' + col + ']',
            expected.get(row, col), actual.get(row, col), MahoutTestCase.EPSILON);
      }
    }
  }

  /**
   * Formats a {@link Vector} for human-readable debugging output. Non-sequential-access vectors are first copied
   * into a {@link DenseVector}. Each value is printed with two decimals (English locale), non-negative values get a
   * leading space, NaN is printed as {@code " - "}, entries are tab-separated, and the result ends with {@code " ]"}.
   */
  public static String nice(Vector v) {
    if (!v.isSequentialAccess()) {
      v = new DenseVector(v);
    }
    DecimalFormat df = new DecimalFormat("0.00", DecimalFormatSymbols.getInstance(Locale.ENGLISH));
    StringBuilder buffer = new StringBuilder("[");
    String separator = "";
    for (Vector.Element e : v) {
      buffer.append(separator);
      if (Double.isNaN(e.get())) {
        buffer.append(" - ");
      } else {
        if (e.get() >= 0) {
          buffer.append(' ');
        }
        buffer.append(df.format(e.get()));
      }
      separator = "\t";
    }
    buffer.append(" ]");
    return buffer.toString();
  }

  /**
   * Formats a {@link Matrix} for debugging output: one line per row, each produced by {@link #nice(Vector)}
   * and terminated by a newline.
   */
  public static String nice(Matrix matrix) {
    StringBuilder info = new StringBuilder();
    for (int n = 0; n < matrix.numRows(); n++) {
      info.append(nice(matrix.viewRow(n))).append('\n');
    }
    return info.toString();
  }
}