/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.classifier;
import static org.junit.Assert.assertEquals;
import hivemall.model.PredictionResult;
import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.junit.Test;
/**
 * Unit tests for {@link PassiveAggressiveUDTF} and its {@code PA1} / {@code PA2} variants.
 *
 * <p>The tests cover three things: the output schema reported by {@code initialize()}, the
 * learning rate returned by {@code eta()}, and the weights stored in {@code model} after
 * {@code train()} has been called.
 */
public class PassiveAggressiveUDTFTest {

    /** Absolute tolerance applied to every float comparison in this class. */
    private static final float DELTA = 1e-5f;

    /** Score handed to every {@link PredictionResult} built by the eta tests. */
    private static final float SCORE = 0.5f;

    /** Inspector for int values: the int feature element type and the second initialize() argument. */
    private static final ObjectInspector INT_OI =
            PrimitiveObjectInspectorFactory.javaIntObjectInspector;

    /** Inspector for a list of int features (first initialize() argument in most tests). */
    private static final ListObjectInspector INT_LIST_OI = listOf(INT_OI);

    /** Wraps {@code elementOI} in a standard list inspector. */
    private static ListObjectInspector listOf(ObjectInspector elementOI) {
        return ObjectInspectorFactory.getStandardListObjectInspector(elementOI);
    }

    /** Builds the constant string inspector that carries an option string such as {@code "-c 3.0"}. */
    private static ObjectInspector optionsOI(String options) {
        return ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);
    }

    /** Initializes {@code udtf} for features of type {@code elementOI} and checks the output type name. */
    private static void assertOutputType(String expectedTypeName, PassiveAggressiveUDTF udtf,
            ObjectInspector elementOI) throws UDFArgumentException {
        StructObjectInspector outputOI =
                udtf.initialize(new ObjectInspector[] {listOf(elementOI), INT_OI});
        assertEquals(expectedTypeName, outputOI.getTypeName());
    }

    /** Asserts the weight currently stored in {@code udtf.model} for {@code feature}. */
    private static void assertWeight(float expected, PassiveAggressiveUDTF udtf, int feature) {
        assertEquals(expected, udtf.model.get(feature).get(), DELTA);
    }

    /** Asserts the value of {@code udtf.eta()} for the given loss and squared norm. */
    private static void assertEta(float expected, PassiveAggressiveUDTF udtf, float loss,
            float squaredNorm) {
        PredictionResult margin = new PredictionResult(SCORE).squaredNorm(squaredNorm);
        assertEquals(expected, udtf.eta(loss, margin), DELTA);
    }

    /**
     * The output struct's {@code feature} column must follow the element type of the feature
     * list, while {@code weight} is always {@code float}. The same UDTF instance is
     * re-initialized for each element type.
     */
    @Test
    public void testInitialize() throws UDFArgumentException {
        PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF();

        /* int features */
        assertOutputType("struct<feature:int,weight:float>", udtf, INT_OI);

        /* string features */
        assertOutputType("struct<feature:string,weight:float>", udtf,
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);

        /* bigint (Java long) features */
        assertOutputType("struct<feature:bigint,weight:float>", udtf,
            PrimitiveObjectInspectorFactory.javaLongObjectInspector);
    }

    /**
     * Trains the plain PA learner on two positive examples and checks the resulting weights.
     * Feature 3 appears in both examples, so its weight accumulates across the two updates.
     */
    @Test
    public void testTrain() throws HiveException {
        PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF();
        udtf.initialize(new ObjectInspector[] {INT_LIST_OI, INT_OI});

        /* first example: features supplied as a java.util.List<Integer> */
        List<Integer> firstExample = new ArrayList<Integer>();
        firstExample.add(1);
        firstExample.add(2);
        firstExample.add(3);
        udtf.train(firstExample, 1);

        assertWeight(0.3333333f, udtf, 1);
        assertWeight(0.3333333f, udtf, 2);
        assertWeight(0.3333333f, udtf, 3);

        /* second example: features supplied as an Object[] viewed through the list inspector */
        List<?> secondExample = INT_LIST_OI.getList(new Object[] {3, 4, 5});
        udtf.train(secondExample, 1);

        /* features 1 and 2 are untouched; 3 = 0.3333333 + 0.2222222 */
        assertWeight(0.3333333f, udtf, 1);
        assertWeight(0.3333333f, udtf, 2);
        assertWeight(0.5555555f, udtf, 3);
        assertWeight(0.2222222f, udtf, 4);
        assertWeight(0.2222222f, udtf, 5);
    }

    /**
     * Plain PA learning rate. The expected values equal {@code loss / squaredNorm}
     * (0.1 / 0.05 and 0.1 / 0.01).
     */
    @Test
    public void testEta() {
        PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF();
        float loss = 0.1f;

        assertEta(2.0f, udtf, loss, 0.05f);
        assertEta(10.0f, udtf, loss, 0.01f);
    }

    /**
     * PA1 learning rate with {@code -c 3.0}. The expected values are consistent with
     * {@code loss / squaredNorm} capped at the aggressiveness parameter: 2.0 passes through,
     * while 10.0 is limited to 3.0.
     */
    @Test
    public void testPA1Eta() throws UDFArgumentException {
        PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF.PA1();

        /* initialize() with an explicit aggressiveness parameter */
        udtf.initialize(new ObjectInspector[] {INT_LIST_OI, INT_OI, optionsOI("-c 3.0")});
        float loss = 0.1f;

        assertEta(2.0f, udtf, loss, 0.05f);
        assertEta(3.0f, udtf, loss, 0.01f);
    }

    /**
     * PA1 learning rate when no options are passed. The expected value of 1.0 (instead of the
     * uncapped 2.0) suggests the default aggressiveness is 1.0; confirm against
     * {@code PassiveAggressiveUDTF.PA1} if it matters.
     */
    @Test
    public void testPA1EtaDefaultParameter() throws UDFArgumentException {
        PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF.PA1();

        /* initialize() without an options argument */
        udtf.initialize(new ObjectInspector[] {INT_LIST_OI, INT_OI});
        float loss = 0.1f;

        assertEta(1.0f, udtf, loss, 0.05f);
    }

    /** PA1 training with default options yields the same weights as the first step of {@link #testTrain()}. */
    @Test
    public void testPA1TrainWithoutParameter() throws UDFArgumentException {
        PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF.PA1();

        /* initialize() without an options argument, so the default aggressiveness applies */
        udtf.initialize(new ObjectInspector[] {INT_LIST_OI, INT_OI});

        /* train on a single positive example */
        List<?> features = INT_LIST_OI.getList(new Object[] {1, 2, 3});
        udtf.train(features, 1);

        assertWeight(0.3333333f, udtf, 1);
        assertWeight(0.3333333f, udtf, 2);
        assertWeight(0.3333333f, udtf, 3);
    }

    /** PA1 training with {@code -c 0.1}: each weight ends up at 0.1 instead of 0.3333333. */
    @Test
    public void testPA1TrainWithParameter() throws UDFArgumentException {
        PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF.PA1();

        /* initialize() with an explicit aggressiveness parameter */
        udtf.initialize(new ObjectInspector[] {INT_LIST_OI, INT_OI, optionsOI("-c 0.1")});

        /* train on a single positive example */
        List<?> features = INT_LIST_OI.getList(new Object[] {1, 2, 3});
        udtf.train(features, 1);

        assertWeight(0.1000000f, udtf, 1);
        assertWeight(0.1000000f, udtf, 2);
        assertWeight(0.1000000f, udtf, 3);
    }

    /**
     * PA2 learning rate when no options are passed. The expected values are consistent with
     * {@code loss / (squaredNorm + 0.5 / C)} for {@code C = 1.0}: 0.1 / 0.55 and 0.1 / 0.51.
     */
    @Test
    public void testPA2EtaWithoutParameter() throws UDFArgumentException {
        PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF.PA2();

        /* initialize() without an options argument */
        udtf.initialize(new ObjectInspector[] {INT_LIST_OI, INT_OI});
        float loss = 0.1f;

        assertEta(0.1818181f, udtf, loss, 0.05f);
        assertEta(0.1960784f, udtf, loss, 0.01f);
    }

    /**
     * PA2 learning rate with {@code -c 3.0}. The expected values are consistent with
     * {@code loss / (squaredNorm + 0.5 / C)} for {@code C = 3.0}.
     */
    @Test
    public void testPA2EtaWithParameter() throws UDFArgumentException {
        PassiveAggressiveUDTF udtf = new PassiveAggressiveUDTF.PA2();

        /* initialize() with an explicit aggressiveness parameter */
        udtf.initialize(new ObjectInspector[] {INT_LIST_OI, INT_OI, optionsOI("-c 3.0")});
        float loss = 0.1f;

        assertEta(0.4615384f, udtf, loss, 0.05f);
        assertEta(0.5660377f, udtf, loss, 0.01f);
    }
}