/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.lens.ml;

import java.io.File;
import java.net.URI;
import java.util.*;

import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.Application;
import javax.ws.rs.core.UriBuilder;

import org.apache.lens.client.LensClient;
import org.apache.lens.client.LensClientConfig;
import org.apache.lens.client.LensMLClient;
import org.apache.lens.ml.algo.spark.dt.DecisionTreeAlgo;
import org.apache.lens.ml.algo.spark.lr.LogisticRegressionAlgo;
import org.apache.lens.ml.algo.spark.nb.NaiveBayesAlgo;
import org.apache.lens.ml.algo.spark.svm.SVMAlgo;
import org.apache.lens.ml.impl.MLTask;
import org.apache.lens.ml.impl.MLUtils;
import org.apache.lens.ml.server.MLApp;
import org.apache.lens.ml.server.MLServiceResource;
//import org.apache.lens.server.LensJerseyTest;
import org.apache.lens.server.api.LensConfConstants;
import org.apache.lens.server.query.QueryServiceResource;
import org.apache.lens.server.session.SessionResource;

import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.metastore.api.Database;
import org.apache.hadoop.hive.ql.metadata.Hive;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.metadata.Partition;
import org.apache.hadoop.hive.ql.metadata.Table;

import org.glassfish.jersey.client.ClientConfig;
import org.glassfish.jersey.media.multipart.MultiPartFeature;

import org.testng.Assert;
//import org.testng.annotations.AfterTest;
//import org.testng.annotations.BeforeMethod;
//import org.testng.annotations.BeforeTest;
//import org.testng.annotations.Test;

import lombok.extern.slf4j.Slf4j;

@Slf4j
//@Test
public class TestMLResource { //extends LensJerseyTest {
  private static final String TEST_DB = "default";

  private WebTarget mlTarget;
  private LensMLClient mlClient;

  protected int getTestPort() {
    return 10058;
  }

  //@Override
  protected Application configure() {
    return new MLApp(SessionResource.class, QueryServiceResource.class);
  }

  //@Override
  protected void configureClient(ClientConfig config) {
    config.register(MultiPartFeature.class);
  }

  //@Override
  protected URI getBaseUri() {
    return UriBuilder.fromUri("http://localhost/").port(getTestPort()).path("/lensapi").build();
  }
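  /*
   * Note: the Jersey test harness (LensJerseyTest) and the TestNG annotations are
   * commented out throughout this class, so it is effectively disabled. The
   * configure(), configureClient() and getBaseUri() methods above are the overrides
   * that harness would invoke if the "extends LensJerseyTest" declaration and the
   * annotations were restored.
   */

  /**
   * Creates the test database in the Hive metastore and builds a {@link LensMLClient}
   * pointed at the local test server URL.
   */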
  //@BeforeTest
  public void setUp() throws Exception {
    //super.setUp();
    Hive hive = Hive.get(new HiveConf());
    Database db = new Database();
    db.setName(TEST_DB);
    hive.createDatabase(db, true);

    LensClientConfig lensClientConfig = new LensClientConfig();
    lensClientConfig.setLensDatabase(TEST_DB);
    lensClientConfig.set(LensConfConstants.SERVER_BASE_URL,
      "http://localhost:" + getTestPort() + "/lensapi");
    LensClient client = new LensClient(lensClientConfig);
    mlClient = new LensMLClient(client);
  }

  //@AfterTest
  public void tearDown() throws Exception {
    //super.tearDown();
    Hive hive = Hive.get(new HiveConf());
    try {
      hive.dropDatabase(TEST_DB);
    } catch (Exception exc) {
      // Ignore drop db exception
      log.error("Exception while dropping database.", exc);
    }
    mlClient.close();
  }

  //@BeforeMethod
  public void setMLTarget() {
    //mlTarget = target().path("ml");
  }

  //@Test
  public void testMLResourceUp() throws Exception {
    String mlUpMsg = mlTarget.request().get(String.class);
    Assert.assertEquals(mlUpMsg, MLServiceResource.ML_UP_MESSAGE);
  }

  //@Test
  public void testGetAlgos() throws Exception {
    List<String> algoNames = mlClient.getAlgorithms();
    Assert.assertNotNull(algoNames);
    Assert.assertTrue(algoNames.contains(MLUtils.getAlgoName(NaiveBayesAlgo.class)),
      MLUtils.getAlgoName(NaiveBayesAlgo.class));
    Assert.assertTrue(algoNames.contains(MLUtils.getAlgoName(SVMAlgo.class)),
      MLUtils.getAlgoName(SVMAlgo.class));
    Assert.assertTrue(algoNames.contains(MLUtils.getAlgoName(LogisticRegressionAlgo.class)),
      MLUtils.getAlgoName(LogisticRegressionAlgo.class));
    Assert.assertTrue(algoNames.contains(MLUtils.getAlgoName(DecisionTreeAlgo.class)),
      MLUtils.getAlgoName(DecisionTreeAlgo.class));
  }

  //@Test
  public void testGetAlgoParams() throws Exception {
    Map<String, String> params = mlClient.getAlgoParamDescription(MLUtils.getAlgoName(DecisionTreeAlgo.class));
    Assert.assertNotNull(params);
    Assert.assertFalse(params.isEmpty());
    for (String key : params.keySet()) {
      log.info("## Param " + key + " help = " + params.get(key));
    }
  }
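  /**
   * End-to-end train-and-evaluate flow: loads the sample Naive Bayes data into a
   * training table, runs two identical {@link MLTask}s against it, and verifies that
   * each run produced its own model ID, report ID, and a corresponding partition in
   * the output table.
   */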
  //@Test
  public void trainAndEval() throws Exception {
    log.info("Starting train & eval");
    final String algoName = MLUtils.getAlgoName(NaiveBayesAlgo.class);
    HiveConf conf = new HiveConf();
    String tableName = "naivebayes_training_table";
    String sampleDataFilePath = "data/naive_bayes/naive_bayes_train.data";
    File sampleDataFile = new File(sampleDataFilePath);
    URI sampleDataFileURI = sampleDataFile.toURI();

    String labelColumn = "label";
    String[] features = { "feature_1", "feature_2", "feature_3" };
    String outputTable = "naivebayes_eval_table";

    log.info("Creating training table from file " + sampleDataFileURI.toString());

    Map<String, String> tableParams = new HashMap<String, String>();
    try {
      ExampleUtils.createTable(conf, TEST_DB, tableName, sampleDataFileURI.toString(), labelColumn, tableParams,
        features);
    } catch (HiveException exc) {
      log.error("Hive exception encountered.", exc);
    }

    MLTask.Builder taskBuilder = new MLTask.Builder();
    taskBuilder.algorithm(algoName).hiveConf(conf).labelColumn(labelColumn).outputTable(outputTable)
      .client(mlClient).trainingTable(tableName);

    // Add features
    taskBuilder.addFeatureColumn("feature_1").addFeatureColumn("feature_2").addFeatureColumn("feature_3");

    MLTask task = taskBuilder.build();
    log.info("Created task " + task.toString());
    task.run();
    Assert.assertEquals(task.getTaskState(), MLTask.State.SUCCESSFUL);

    String firstModelID = task.getModelID();
    String firstReportID = task.getReportID();
    Assert.assertNotNull(firstReportID);
    Assert.assertNotNull(firstModelID);

    // Run an identical second task so each run gets its own model and report
    taskBuilder = new MLTask.Builder();
    taskBuilder.algorithm(algoName).hiveConf(conf).labelColumn(labelColumn).outputTable(outputTable)
      .client(mlClient).trainingTable(tableName);
    taskBuilder.addFeatureColumn("feature_1").addFeatureColumn("feature_2").addFeatureColumn("feature_3");

    MLTask anotherTask = taskBuilder.build();
    log.info("Created second task " + anotherTask.toString());
    anotherTask.run();

    String secondModelID = anotherTask.getModelID();
    String secondReportID = anotherTask.getReportID();
    Assert.assertNotNull(secondModelID);
    Assert.assertNotNull(secondReportID);

    Hive metastoreClient = Hive.get(conf);
    Table outputHiveTable = metastoreClient.getTable(outputTable);
    List<Partition> partitions = metastoreClient.getPartitions(outputHiveTable);

    Assert.assertNotNull(partitions);

    int i = 0;
    Set<String> partReports = new HashSet<String>();
    for (Partition part : partitions) {
      log.info("@@PART#" + i + " " + part.getSpec().toString());
      partReports.add(part.getSpec().get("part_testid"));
      i++;
    }

    // Verify partitions created for each run
    Assert.assertTrue(partReports.contains(firstReportID), firstReportID + " first partition not there");
    Assert.assertTrue(partReports.contains(secondReportID), secondReportID + " second partition not there");

    log.info("Completed task run");
  }
}