/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package hivemall.fm; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.util.ArrayList; import org.apache.commons.lang.StringUtils; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.junit.Assert; import org.junit.Test; public class FieldAwareFactorizationMachineUDTFTest { private static final boolean DEBUG = false; private static final int ITERATIONS = 50; private static final int MAX_LINES = 200; @Test public void testSGD() throws HiveException, IOException { runTest("Pure SGD test", "-classification -factors 10 -w0 -seed 43 -disable_adagrad -disable_ftrl", 0.60f); } @Test public void testSGDWithFTRL() throws HiveException, IOException { runTest("SGD w/ FTRL test", "-classification -factors 10 -w0 -seed 43 -disable_adagrad", 0.60f); } @Test public void testAdaGradNoCoeff() throws HiveException, IOException { runTest("AdaGrad No Coeff test", "-classification -factors 10 -w0 -seed 43 -no_coeff", 0.30f); } @Test public void testAdaGradNoFTRL() throws HiveException, IOException { runTest("AdaGrad w/o FTRL test", "-classification -factors 10 -w0 -seed 43 -disable_ftrl", 0.30f); } @Test public void testAdaGradDefault() throws HiveException, IOException { runTest("AdaGrad DEFAULT (adagrad for V + FTRL for W)", "-classification -factors 10 -w0 -seed 43", 0.30f); } private static void runTest(String testName, String testOptions, float lossThreshold) throws IOException, HiveException { println(testName); FieldAwareFactorizationMachineUDTF udtf = new FieldAwareFactorizationMachineUDTF(); ObjectInspector[] argOIs = new ObjectInspector[] { ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector), PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, ObjectInspectorUtils.getConstantObjectInspector( PrimitiveObjectInspectorFactory.javaStringObjectInspector, testOptions)}; udtf.initialize(argOIs); FieldAwareFactorizationMachineModel model = udtf.initModel(udtf._params); Assert.assertTrue("Actual class: " + model.getClass().getName(), model instanceof FFMStringFeatureMapModel); double loss = 0.d; double cumul = 0.d; for (int trainingIteration = 1; trainingIteration <= ITERATIONS; ++trainingIteration) { BufferedReader data = new BufferedReader(new InputStreamReader( FieldAwareFactorizationMachineUDTFTest.class.getResourceAsStream("bigdata.tr.txt"))); loss = udtf._cvState.getCumulativeLoss(); int lines = 0; for (int lineNumber = 0; lineNumber < MAX_LINES; ++lineNumber, ++lines) { //gather features in current line final String input = data.readLine(); if (input == null) { break; } ArrayList<String> featureStrings = new ArrayList<String>(); ArrayList<StringFeature> features = new ArrayList<StringFeature>(); //make StringFeature for each word = data point String remaining = input; int wordCut = remaining.indexOf(' '); while (wordCut != -1) { featureStrings.add(remaining.substring(0, wordCut)); remaining = remaining.substring(wordCut + 1); wordCut = remaining.indexOf(' '); } int end = featureStrings.size(); double y = Double.parseDouble(featureStrings.get(0)); if (y == 0) { y = -1;//LibFFM data uses {0, 1}; Hivemall uses {-1, 1} } for (int wordNumber = 1; wordNumber < end; ++wordNumber) { String entireFeature = featureStrings.get(wordNumber); int featureCut = StringUtils.ordinalIndexOf(entireFeature, ":", 2); String feature = entireFeature.substring(0, featureCut); double value = Double.parseDouble(entireFeature.substring(featureCut + 1)); features.add(new StringFeature(feature, value)); } udtf.process(new Object[] {toStringArray(features), y}); } cumul = udtf._cvState.getCumulativeLoss(); loss = (cumul - loss) / lines; println(trainingIteration + " " + loss + " " + cumul / (trainingIteration * lines)); data.close(); } println("model size=" + udtf._model.getSize()); Assert.assertTrue("Last loss was greater than expected: " + loss, loss < lossThreshold); } private static String[] toStringArray(ArrayList<StringFeature> x) { final int size = x.size(); final String[] ret = new String[size]; for (int i = 0; i < size; i++) { ret[i] = x.get(i).toString(); } return ret; } private static void println(String line) { if (DEBUG) { System.out.println(line); } } }