/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.sgd;
import com.google.common.base.CharMatcher;
import com.google.common.base.Charsets;
import com.google.common.base.Splitter;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.CharStreams;
import com.google.common.io.Resources;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.List;
import java.util.Map;
import java.util.Random;
public abstract class OnlineBaseTest extends MahoutTestCase {
private Matrix input;
Matrix getInput() {
return input;
}
Vector readStandardData() throws IOException {
// 60 test samples. First column is constant. Second and third are normally distributed from
// either N([2,2], 1) (rows 0...29) or N([-2,-2], 1) (rows 30...59). The first 30 rows have a
// target variable of 0, the last 30 a target of 1. The remaining columns are are random noise.
input = readCsv("sgd.csv");
// regenerate the target variable
Vector target = new DenseVector(60);
target.assign(0);
target.viewPart(30, 30).assign(1);
return target;
}
static void train(Matrix input, Vector target, OnlineLearner lr) {
RandomUtils.useTestSeed();
Random gen = RandomUtils.getRandom();
// train on samples in random order (but only one pass)
for (int row : permute(gen, 60)) {
lr.train((int) target.get(row), input.viewRow(row));
}
lr.close();
}
static void test(Matrix input, Vector target, AbstractVectorClassifier lr,
double expected_mean_error, double expected_absolute_error) {
// now test the accuracy
Matrix tmp = lr.classify(input);
// mean(abs(tmp - target))
double meanAbsoluteError = tmp.viewColumn(0).minus(target).aggregate(Functions.PLUS, Functions.ABS) / 60;
// max(abs(tmp - target)
double maxAbsoluteError = tmp.viewColumn(0).minus(target).aggregate(Functions.MAX, Functions.ABS);
System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError, maxAbsoluteError);
assertEquals(0, meanAbsoluteError , expected_mean_error);
assertEquals(0, maxAbsoluteError, expected_absolute_error);
// convenience methods should give the same results
Vector v = lr.classifyScalar(input);
assertEquals(0, v.minus(tmp.viewColumn(0)).norm(1), 1.0e-5);
v = lr.classifyFull(input).viewColumn(1);
assertEquals(0, v.minus(tmp.viewColumn(0)).norm(1), 1.0e-4);
}
/**
* Permute the integers from 0 ... max-1
*
* @param gen The random number generator to use.
* @param max The number of integers to permute
* @return An array of jumbled integer values
*/
static int[] permute(Random gen, int max) {
int[] permutation = new int[max];
permutation[0] = 0;
for (int i = 1; i < max; i++) {
int n = gen.nextInt(i + 1);
if (n == i) {
permutation[i] = i;
} else {
permutation[i] = permutation[n];
permutation[n] = i;
}
}
return permutation;
}
/**
* Reads a file containing CSV data. This isn't implemented quite the way you might like for a
* real program, but does the job for reading test data. Most notably, it will only read numbers,
* not quoted strings.
*
* @param resourceName Where to get the data.
* @return A matrix of the results.
* @throws IOException If there is an error reading the data
*/
static Matrix readCsv(String resourceName) throws IOException {
Splitter onCommas = Splitter.on(',').trimResults(CharMatcher.anyOf(" \""));
Readable isr = new InputStreamReader(Resources.getResource(resourceName).openStream(), Charsets.UTF_8);
List<String> data = CharStreams.readLines(isr);
String first = data.get(0);
data = data.subList(1, data.size());
List<String> values = Lists.newArrayList(onCommas.split(first));
Matrix r = new DenseMatrix(data.size(), values.size());
int column = 0;
Map<String, Integer> labels = Maps.newHashMap();
for (String value : values) {
labels.put(value, column);
column++;
}
r.setColumnLabelBindings(labels);
int row = 0;
for (String line : data) {
column = 0;
values = Lists.newArrayList(onCommas.split(line));
for (String value : values) {
r.set(row, column, Double.parseDouble(value));
column++;
}
row++;
}
return r;
}
}