/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.sgd;
import com.google.common.base.Charsets;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import com.google.common.io.Closeables;
import com.google.common.io.Resources;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.examples.MahoutTestCase;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.junit.Test;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.List;
import java.util.Map;
import java.util.Set;
public class TrainLogisticTest extends MahoutTestCase {
@Test
public void example13_1() throws Exception {
String outputFile = getTestTempFile("model").getAbsolutePath();
StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw, true);
TrainLogistic.mainToOutput(new String[]{
"--input", "donut.csv",
"--output", outputFile,
"--target", "color", "--categories", "2",
"--predictors", "x", "y",
"--types", "numeric",
"--features", "20",
"--passes", "100",
"--rate", "50"
}, pw);
String trainOut = sw.toString();
assertTrue(trainOut.contains("x -0.7"));
assertTrue(trainOut.contains("y -0.4"));
LogisticModelParameters lmp = TrainLogistic.getParameters();
assertEquals(1.0e-4, lmp.getLambda(), 1.0e-9);
assertEquals(20, lmp.getNumFeatures());
assertTrue(lmp.useBias());
assertEquals("color", lmp.getTargetVariable());
CsvRecordFactory csv = lmp.getCsvRecordFactory();
assertEquals("[1, 2]", Sets.newTreeSet(csv.getTargetCategories()).toString());
assertEquals("[Intercept Term, x, y]", Sets.newTreeSet(csv.getPredictors()).toString());
// verify model by building dissector
AbstractVectorClassifier model = TrainLogistic.getModel();
List<String> data = Resources.readLines(Resources.getResource("donut.csv"), Charsets.UTF_8);
Map<String, Double> expectedValues = ImmutableMap.of("x", -0.7, "y", -0.43, "Intercept Term", -0.15);
verifyModel(lmp, csv, data, model, expectedValues);
// test saved model
InputStream in = new FileInputStream(new File(outputFile));
try {
LogisticModelParameters lmpOut = LogisticModelParameters.loadFrom(in);
CsvRecordFactory csvOut = lmpOut.getCsvRecordFactory();
csvOut.firstLine(data.get(0));
OnlineLogisticRegression lrOut = lmpOut.createRegression();
verifyModel(lmpOut, csvOut, data, lrOut, expectedValues);
} finally {
Closeables.closeQuietly(in);
}
sw = new StringWriter();
pw = new PrintWriter(sw, true);
RunLogistic.mainToOutput(new String[]{
"--input", "donut.csv",
"--model", outputFile,
"--auc",
"--confusion"
}, pw);
trainOut = sw.toString();
assertTrue(trainOut.contains("AUC = 0.57"));
assertTrue(trainOut.contains("confusion: [[27.0, 13.0], [0.0, 0.0]]"));
}
@Test
public void example13_2() throws Exception {
String outputFile = getTestTempFile("model").getAbsolutePath();
StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw, true);
TrainLogistic.mainToOutput(new String[]{
"--input", "donut.csv",
"--output", outputFile,
"--target", "color",
"--categories", "2",
"--predictors", "x", "y", "a", "b", "c",
"--types", "numeric",
"--features", "20",
"--passes", "100",
"--rate", "50"
}, pw);
String trainOut = sw.toString();
assertTrue(trainOut.contains("a 0."));
assertTrue(trainOut.contains("b -1."));
assertTrue(trainOut.contains("c -25."));
sw = new StringWriter();
pw = new PrintWriter(sw, true);
RunLogistic.mainToOutput(new String[]{
"--input", "donut.csv",
"--model", outputFile,
"--auc",
"--confusion"
}, pw);
trainOut = sw.toString();
assertTrue(trainOut.contains("AUC = 1.00"));
sw = new StringWriter();
pw = new PrintWriter(sw, true);
RunLogistic.mainToOutput(new String[]{
"--input", "donut-test.csv",
"--model", outputFile,
"--auc",
"--confusion"
}, pw);
trainOut = sw.toString();
assertTrue(trainOut.contains("AUC = 0.9"));
}
private static void verifyModel(LogisticModelParameters lmp,
RecordFactory csv,
List<String> data,
AbstractVectorClassifier model,
Map<String, Double> expectedValues) {
ModelDissector md = new ModelDissector();
for (String line : data.subList(1, data.size())) {
Vector v = new DenseVector(lmp.getNumFeatures());
csv.getTraceDictionary().clear();
csv.processLine(line, v);
md.update(v, csv.getTraceDictionary(), model);
}
// check right variables are present
List<ModelDissector.Weight> weights = md.summary(10);
Set<String> expected = Sets.newHashSet(expectedValues.keySet());
for (ModelDissector.Weight weight : weights) {
assertTrue(expected.remove(weight.getFeature()));
assertEquals(expectedValues.get(weight.getFeature()), weight.getWeight(), 0.1);
}
assertEquals(0, expected.size());
}
}